[enhancement](nereids) make SSB works (#10659)

enhancement
- refactor compute output expression on root fragment in nereids planner
- refactor aggregate plan translator
- refactor aggregate disassemble rule
- slightly refactor sort plan translator
- add exchange node on the top of plan node tree if it is needed
- slightly refactor PhysicalPlanTranslator#translatePlan

fix
- slotDescriptor should not reuse between TupleDescriptors
- expression's nullable now works fine
- remove quotes when parse string literal
- set resolvedTupleExprs in SortNode to control output
- remove the extra column in sortTupleSlotExprs in SortInfo

known issues
- aggregate function must be the top expression in output expression (need project in ExecNode in BE)
- first phase aggregate could not convert to stream mode.
- OlapScanNode do not set data partition
- Sort could not process expression like 'order by a + 1' and SortInfo generated in a trick way and should be refactor when we want to support 'order by a + 1'
- column prune do not work as expected
This commit is contained in:
morrySnow
2022-07-11 11:33:17 +08:00
committed by GitHub
parent a044b5dcc5
commit 1dccfa3d84
59 changed files with 895 additions and 434 deletions

View File

@ -1267,6 +1267,20 @@ public class FunctionCallExpr extends Expr {
@Override
public void finalizeImplForNereids() throws AnalysisException {
super.finalizeImplForNereids();
// TODO: support other functions
if (fnName.getFunction().equalsIgnoreCase("sum")) {
// Prevent the cast type in vector exec engine
Type childType = getChild(0).type.getMaxResolutionType();
fn = getBuiltinFunction(fnName.getFunction(), new Type[]{childType},
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
type = fn.getReturnType();
}
}
/**
* NOTICE: This function only used for Nereids, should not call it if u don't know what it is mean.
*/
public void setMergeForNereids(boolean isMergeAggFn) {
this.isMergeAggFn = isMergeAggFn;
}
}

View File

@ -35,4 +35,6 @@ public interface Queriable {
List<Expr> getResultExprs();
ArrayList<String> getColLabels();
String toDigest();
}

View File

@ -18,21 +18,17 @@
package org.apache.doris.nereids;
import org.apache.doris.analysis.DescriptorTable;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.StatementBase;
import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.analysis.TupleId;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.UserException;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator;
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
import org.apache.doris.nereids.jobs.AnalyzeRulesJob;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.OptimizeRulesJob;
import org.apache.doris.nereids.jobs.PredicatePushDownRulesJob;
import org.apache.doris.nereids.jobs.batch.AnalyzeRulesJob;
import org.apache.doris.nereids.jobs.batch.DisassembleRulesJob;
import org.apache.doris.nereids.jobs.batch.OptimizeRulesJob;
import org.apache.doris.nereids.jobs.batch.PredicatePushDownRulesJob;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.memo.Memo;
@ -47,12 +43,9 @@ import org.apache.doris.planner.ScanNode;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
@ -81,33 +74,14 @@ public class NereidsPlanner extends Planner {
PhysicalPlanTranslator physicalPlanTranslator = new PhysicalPlanTranslator();
PlanTranslatorContext planTranslatorContext = new PlanTranslatorContext();
physicalPlanTranslator.translatePlan(physicalPlan, planTranslatorContext);
PlanFragment root = physicalPlanTranslator.translatePlan(physicalPlan, planTranslatorContext);
scanNodeList = planTranslatorContext.getScanNodeList();
descTable = planTranslatorContext.getDescTable();
fragments = new ArrayList<>(planTranslatorContext.getPlanFragmentList());
for (PlanFragment fragment : fragments) {
fragment.finalize(queryStmt);
}
Collections.reverse(fragments);
PlanFragment root = fragments.get(0);
// compute output exprs
Map<Integer, Expr> outputCandidates = Maps.newHashMap();
List<Expr> outputExprs = Lists.newArrayList();
for (TupleId tupleId : root.getPlanRoot().getTupleIds()) {
TupleDescriptor tupleDescriptor = descTable.getTupleDesc(tupleId);
for (SlotDescriptor slotDescriptor : tupleDescriptor.getSlots()) {
SlotRef slotRef = new SlotRef(slotDescriptor);
outputCandidates.put(slotDescriptor.getId().asInt(), slotRef);
}
}
physicalPlan.getOutput().stream()
.forEach(i -> outputExprs.add(planTranslatorContext.findExpr(i)));
root.setOutputExprs(outputExprs);
root.getPlanRoot().convertToVectoriezd();
logicalPlanAdapter.setResultExprs(outputExprs);
// set output exprs
logicalPlanAdapter.setResultExprs(root.getOutputExprs());
ArrayList<String> columnLabelList = physicalPlan.getOutput().stream()
.map(NamedExpression::getName).collect(Collectors.toCollection(ArrayList::new));
logicalPlanAdapter.setColLabels(columnLabelList);
@ -147,6 +121,9 @@ public class NereidsPlanner extends Planner {
PredicatePushDownRulesJob predicatePushDownRulesJob = new PredicatePushDownRulesJob(plannerContext);
predicatePushDownRulesJob.execute();
DisassembleRulesJob disassembleRulesJob = new DisassembleRulesJob(plannerContext);
disassembleRulesJob.execute();
OptimizeRulesJob optimizeRulesJob = new OptimizeRulesJob(plannerContext);
optimizeRulesJob.execute();

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.analyzer;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.NodeType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
@ -25,6 +26,7 @@ import com.google.common.base.Joiner;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* Expression for unbound function.
@ -52,6 +54,14 @@ public class UnboundFunction extends Expression implements Unbound {
return children();
}
@Override
public String toSql() throws UnboundException {
String params = children.stream()
.map(Expression::toSql)
.collect(Collectors.joining(", "));
return name + "(" + (isDistinct ? "DISTINCT " : "") + params + ")";
}
@Override
public String toString() {
String params = Joiner.on(", ").join(children);

View File

@ -53,7 +53,7 @@ public class UnboundSlot extends Slot implements Unbound {
}
@Override
public String sql() {
public String toSql() {
return nameParts.stream().map(Utils::quoteIfNeeded).reduce((left, right) -> left + "." + right).orElse("");
}

View File

@ -40,7 +40,7 @@ public class UnboundStar extends NamedExpression implements LeafExpression, Unbo
}
@Override
public String sql() {
public String toSql() {
String qualified = qualifier.stream().map(Utils::quoteIfNeeded).reduce((t1, t2) -> t1 + "." + t2).orElse("");
if (StringUtils.isNotEmpty(qualified)) {
return qualified + ".*";
@ -56,7 +56,7 @@ public class UnboundStar extends NamedExpression implements LeafExpression, Unbo
@Override
public String toString() {
return sql();
return toSql();
}
@Override

View File

@ -22,7 +22,7 @@ import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.operators.Operator;
import org.apache.doris.nereids.operators.OperatorVisitor;
import org.apache.doris.nereids.operators.plans.physical.PhysicalAggregation;
import org.apache.doris.nereids.operators.plans.physical.PhysicalAggregate;
import org.apache.doris.nereids.operators.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.operators.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.operators.plans.physical.PhysicalProject;
@ -62,7 +62,7 @@ public class CostCalculator {
}
@Override
public CostEstimate visitPhysicalAggregation(PhysicalAggregation physicalAggregation, PlanContext context) {
public CostEstimate visitPhysicalAggregation(PhysicalAggregate physicalAggregate, PlanContext context) {
StatsDeriveResult statistics = context.getStatisticsWithCheck();
return CostEstimate.ofCpu(statistics.computeSize());
}

View File

@ -77,4 +77,9 @@ public class LogicalPlanAdapter extends StatementBase implements Queriable {
public void setColLabels(ArrayList<String> colLabels) {
this.colLabels = colLabels;
}
public String toDigest() {
// TODO: generate real digest
return "";
}
}

View File

@ -30,6 +30,7 @@ import org.apache.doris.analysis.StringLiteral;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.NodeType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Arithmetic;
import org.apache.doris.nereids.trees.expressions.Between;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
@ -83,8 +84,8 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
}
@Override
public Expr visitSlotReference(SlotReference slotReference, PlanTranslatorContext context) {
return context.findExpr(slotReference);
public Expr visitAlias(Alias alias, PlanTranslatorContext context) {
return alias.child().accept(this, context);
}
@Override
@ -122,6 +123,13 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
lessThanEqual.child(1).accept(this, context));
}
@Override
public Expr visitNullSafeEqual(NullSafeEqual nullSafeEqual, PlanTranslatorContext context) {
return new BinaryPredicate(Operator.EQ_FOR_NULL,
nullSafeEqual.child(0).accept(this, context),
nullSafeEqual.child(1).accept(this, context));
}
@Override
public Expr visitNot(Not not, PlanTranslatorContext context) {
return new org.apache.doris.analysis.CompoundPredicate(
@ -131,10 +139,8 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
}
@Override
public Expr visitNullSafeEqual(NullSafeEqual nullSafeEqual, PlanTranslatorContext context) {
return new BinaryPredicate(Operator.EQ_FOR_NULL,
nullSafeEqual.child(0).accept(this, context),
nullSafeEqual.child(1).accept(this, context));
public Expr visitSlotReference(SlotReference slotReference, PlanTranslatorContext context) {
return context.findSlotRef(slotReference.getExprId());
}
/**
@ -157,16 +163,6 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
throw new RuntimeException(String.format("Unsupported data type: %s", dataType.toString()));
}
// TODO: Supports for `distinct`
@Override
public Expr visitBoundFunction(BoundFunction function, PlanTranslatorContext context) {
List<Expr> paramList = new ArrayList<>();
for (Expression expr : function.getArguments()) {
paramList.add(expr.accept(this, context));
}
return new FunctionCallExpr(function.getName(), paramList);
}
@Override
public Expr visitBetween(Between between, PlanTranslatorContext context) {
throw new RuntimeException("Unexpected invocation");
@ -194,6 +190,16 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
compoundPredicate.child(1).accept(this, context));
}
// TODO: Supports for `distinct`
@Override
public Expr visitBoundFunction(BoundFunction function, PlanTranslatorContext context) {
List<Expr> paramList = new ArrayList<>();
for (Expression expr : function.getArguments()) {
paramList.add(expr.accept(this, context));
}
return new FunctionCallExpr(function.getName(), paramList);
}
@Override
public Expr visitArithmetic(Arithmetic arithmetic, PlanTranslatorContext context) {
Arithmetic.ArithmeticOperator arithmeticOperator = arithmetic.getArithmeticOperator();
@ -201,5 +207,4 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
arithmetic.child(0).accept(this, context),
arithmeticOperator.isBinary() ? arithmetic.child(1).accept(this, context) : null);
}
}

View File

@ -29,8 +29,9 @@ import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Table;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.operators.plans.AggPhase;
import org.apache.doris.nereids.operators.plans.JoinType;
import org.apache.doris.nereids.operators.plans.physical.PhysicalAggregation;
import org.apache.doris.nereids.operators.plans.physical.PhysicalAggregate;
import org.apache.doris.nereids.operators.plans.physical.PhysicalFilter;
import org.apache.doris.nereids.operators.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.operators.plans.physical.PhysicalHeapSort;
@ -54,7 +55,6 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalUnaryPlan;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.planner.AggregationNode;
import org.apache.doris.planner.CrossJoinNode;
import org.apache.doris.planner.DataPartition;
import org.apache.doris.planner.ExchangeNode;
import org.apache.doris.planner.HashJoinNode;
@ -66,8 +66,11 @@ import org.apache.doris.planner.SortNode;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@ -100,8 +103,28 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
}
}
public void translatePlan(PhysicalPlan physicalPlan, PlanTranslatorContext context) {
visit(physicalPlan, context);
/**
* Translate Nereids Physical Plan tree to Stale Planner PlanFragment tree.
*
* @param physicalPlan Nereids Physical Plan tree
* @param context context to help translate
* @return Stale Planner PlanFragment tree
*/
public PlanFragment translatePlan(PhysicalPlan physicalPlan, PlanTranslatorContext context) {
PlanFragment rootFragment = visit(physicalPlan, context);
if (rootFragment.isPartitioned() && rootFragment.getPlanRoot().getNumInstances() > 1) {
rootFragment = exchangeToMergeFragment(rootFragment, context);
}
List<Expr> outputExprs = Lists.newArrayList();
physicalPlan.getOutput().stream().map(Slot::getExprId)
.forEach(exprId -> outputExprs.add(context.findSlotRef(exprId)));
rootFragment.setOutputExprs(outputExprs);
rootFragment.getPlanRoot().convertToVectoriezd();
for (PlanFragment fragment : context.getPlanFragmentList()) {
fragment.finalize(null);
}
Collections.reverse(context.getPlanFragmentList());
return rootFragment;
}
@Override
@ -112,62 +135,92 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
/**
* Translate Agg.
* todo: support DISTINCT
*/
@Override
public PlanFragment visitPhysicalAggregation(
PhysicalUnaryPlan<PhysicalAggregation, Plan> agg, PlanTranslatorContext context) {
public PlanFragment visitPhysicalAggregate(
PhysicalUnaryPlan<PhysicalAggregate, Plan> aggregate, PlanTranslatorContext context) {
PlanFragment inputPlanFragment = visit(aggregate.child(0), context);
PhysicalAggregate physicalAggregate = aggregate.getOperator();
PlanFragment inputPlanFragment = visit(agg.child(0), context);
// TODO: stale planner generate aggregate tuple in a special way. tuple include 2 parts:
// 1. group by expressions: removing duplicate expressions add to tuple
// 2. agg functions: only removing duplicate agg functions in output expression should appear in tuple.
// e.g. select sum(v1) + 1, sum(v1) + 2 from t1 should only generate one sum(v1) in tuple
// We need:
// 1. add a project after agg, if agg function is not the top output expression.
// 2. introduce canonicalized, semanticEquals and deterministic in Expression
// for removing duplicate.
List<Expression> groupByExpressionList = physicalAggregate.getGroupByExprList();
List<NamedExpression> outputExpressionList = physicalAggregate.getOutputExpressionList();
AggregationNode aggregationNode;
List<Slot> slotList = new ArrayList<>();
PhysicalAggregation physicalAggregation = agg.getOperator();
AggregateInfo.AggPhase phase = physicalAggregation.getAggPhase().toExec();
List<Expression> groupByExpressionList = physicalAggregation.getGroupByExprList();
// 1. generate slot reference for each group expression
List<SlotReference> groupSlotList = Lists.newArrayList();
for (Expression e : groupByExpressionList) {
if (e instanceof SlotReference && outputExpressionList.stream().anyMatch(o -> o.anyMatch(e::equals))) {
groupSlotList.add((SlotReference) e);
} else {
groupSlotList.add(new SlotReference(e.toSql(), e.getDataType(), e.nullable(), Collections.emptyList()));
}
}
ArrayList<Expr> execGroupingExpressions = groupByExpressionList.stream()
// Since output of plan doesn't contain the slots of groupBy, which is actually needed by
// the BE execution, so we have to collect them and add to the slotList to generate corresponding
// TupleDesc.
.peek(x -> slotList.addAll(x.collect(SlotReference.class::isInstance)))
.map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toCollection(ArrayList::new));
slotList.addAll(agg.getOutput());
TupleDescriptor outputTupleDesc = generateTupleDesc(slotList, context, null);
List<NamedExpression> outputExpressionList = physicalAggregation.getOutputExpressionList();
ArrayList<FunctionCallExpr> execAggExpressions = outputExpressionList.stream()
.map(e -> e.<List<AggregateFunction>>collect(AggregateFunction.class::isInstance))
// 2. collect agg functions and generate agg function to slot reference map
List<Slot> aggFunctionOutput = Lists.newArrayList();
List<AggregateFunction> aggregateFunctionList = outputExpressionList.stream()
.filter(o -> o.anyMatch(AggregateFunction.class::isInstance))
.peek(o -> aggFunctionOutput.add(o.toSlot()))
.map(o -> o.<List<AggregateFunction>>collect(AggregateFunction.class::isInstance))
.flatMap(List::stream)
.collect(Collectors.toList());
ArrayList<FunctionCallExpr> execAggregateFunctions = aggregateFunctionList.stream()
.map(x -> (FunctionCallExpr) ExpressionTranslator.translate(x, context))
.collect(Collectors.toCollection(ArrayList::new));
List<Expression> partitionExpressionList = physicalAggregation.getPartitionExprList();
// 3. generate output tuple
// TODO: currently, we only support sum(a), if we want to support sum(a) + 1, we need to
// split merge agg to project(agg) and generate tuple like what first phase agg do.
List<Slot> slotList = Lists.newArrayList();
TupleDescriptor outputTupleDesc;
if (physicalAggregate.getAggPhase() == AggPhase.GLOBAL) {
slotList.addAll(groupSlotList);
slotList.addAll(aggFunctionOutput);
outputTupleDesc = generateTupleDesc(slotList, null, context);
} else {
outputTupleDesc = generateTupleDesc(aggregate.getOutput(), null, context);
}
// process partition list
List<Expression> partitionExpressionList = physicalAggregate.getPartitionExprList();
List<Expr> execPartitionExpressions = partitionExpressionList.stream()
.map(e -> (FunctionCallExpr) ExpressionTranslator.translate(e, context)).collect(Collectors.toList());
// todo: support DISTINCT
AggregateInfo aggInfo;
switch (phase) {
case FIRST:
aggInfo = AggregateInfo.create(execGroupingExpressions, execAggExpressions, outputTupleDesc,
outputTupleDesc, AggregateInfo.AggPhase.FIRST);
aggregationNode = new AggregationNode(context.nextNodeId(), inputPlanFragment.getPlanRoot(), aggInfo);
.map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toList());
DataPartition mergePartition = DataPartition.UNPARTITIONED;
if (CollectionUtils.isNotEmpty(execPartitionExpressions)) {
mergePartition = DataPartition.hashPartitioned(execGroupingExpressions);
}
if (physicalAggregate.getAggPhase() == AggPhase.GLOBAL) {
for (FunctionCallExpr execAggregateFunction : execAggregateFunctions) {
execAggregateFunction.setMergeForNereids(true);
}
}
AggregateInfo aggInfo = AggregateInfo.create(execGroupingExpressions, execAggregateFunctions, outputTupleDesc,
outputTupleDesc, physicalAggregate.getAggPhase().toExec());
AggregationNode aggregationNode = new AggregationNode(context.nextPlanNodeId(),
inputPlanFragment.getPlanRoot(), aggInfo);
inputPlanFragment.setPlanRoot(aggregationNode);
switch (physicalAggregate.getAggPhase()) {
case LOCAL:
aggregationNode.unsetNeedsFinalize();
aggregationNode.setUseStreamingPreagg(physicalAggregation.isUsingStream());
aggregationNode.setUseStreamingPreagg(physicalAggregate.isUsingStream());
aggregationNode.setIntermediateTuple();
if (!partitionExpressionList.isEmpty()) {
inputPlanFragment.setOutputPartition(DataPartition.hashPartitioned(execPartitionExpressions));
}
break;
case FIRST_MERGE:
aggInfo = AggregateInfo.create(execGroupingExpressions, execAggExpressions, outputTupleDesc,
outputTupleDesc, AggregateInfo.AggPhase.FIRST_MERGE);
aggregationNode = new AggregationNode(context.nextNodeId(), inputPlanFragment.getPlanRoot(), aggInfo);
break;
return createParentFragment(inputPlanFragment, mergePartition, context);
case GLOBAL:
inputPlanFragment.updateDataPartition(mergePartition);
return inputPlanFragment;
default:
throw new RuntimeException("Unsupported yet");
}
inputPlanFragment.setPlanRoot(aggregationNode);
return inputPlanFragment;
}
@Override
@ -181,9 +234,9 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
.getExpressions()
.stream()
.map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toList());
TupleDescriptor tupleDescriptor = generateTupleDesc(slotList, context, olapTable);
TupleDescriptor tupleDescriptor = generateTupleDesc(slotList, olapTable, context);
tupleDescriptor.setTable(olapTable);
OlapScanNode olapScanNode = new OlapScanNode(context.nextNodeId(), tupleDescriptor, olapTable.getName());
OlapScanNode olapScanNode = new OlapScanNode(context.nextPlanNodeId(), tupleDescriptor, olapTable.getName());
// TODO: Do we really need tableName here?
TableName tableName = new TableName(null, "", "");
TableRef ref = new TableRef(tableName, null, null);
@ -199,6 +252,7 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
olapScanNode.addConjuncts(execConjunctsList);
context.addScanNode(olapScanNode);
// Create PlanFragment
// TODO: add data partition after we have physical properties
PlanFragment planFragment = new PlanFragment(context.nextFragmentId(), olapScanNode, DataPartition.RANDOM);
context.addPlanFragment(planFragment);
return planFragment;
@ -220,12 +274,12 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
* But eg:
* select a+1 from table order by a+1;
* the expressions of the two are inconsistent.
* The former will perform an additional Alisa.
* The former will perform an additional Alias.
* Currently we cannot test whether this will have any effect.
* After a+1 can be parsed , reprocessing.
*/
@Override
public PlanFragment visitPhysicalSort(PhysicalUnaryPlan<PhysicalHeapSort, Plan> sort,
public PlanFragment visitPhysicalHeapSort(PhysicalUnaryPlan<PhysicalHeapSort, Plan> sort,
PlanTranslatorContext context) {
PlanFragment childFragment = visit(sort.child(0), context);
PhysicalHeapSort physicalHeapSort = sort.getOperator();
@ -257,7 +311,7 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
SortInfo sortInfo = new SortInfo(newOrderingExprList, ascOrderList, nullsFirstParamList, tupleDesc);
PlanNode childNode = childFragment.getPlanRoot();
// TODO: notice topN
SortNode sortNode = new SortNode(context.nextNodeId(), childNode, sortInfo, true,
SortNode sortNode = new SortNode(context.nextPlanNodeId(), childNode, sortInfo, true,
physicalHeapSort.hasLimit(), physicalHeapSort.getOffset());
sortNode.finalizeForNereids(tupleDesc, sortTupleOutputList, oldOrderingExprList);
childFragment.addPlanRoot(sortNode);
@ -265,13 +319,13 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
return childFragment;
}
PlanFragment mergeFragment = createParentFragment(childFragment, DataPartition.UNPARTITIONED, context);
ExchangeNode exchNode = (ExchangeNode) mergeFragment.getPlanRoot();
exchNode.unsetLimit();
ExchangeNode exchangeNode = (ExchangeNode) mergeFragment.getPlanRoot();
exchangeNode.unsetLimit();
if (physicalHeapSort.hasLimit()) {
exchNode.setLimit(limit);
exchangeNode.setLimit(limit);
}
long offset = physicalHeapSort.getOffset();
exchNode.setMergeInfo(sortNode.getSortInfo(), offset);
exchangeNode.setMergeInfo(sortNode.getSortInfo(), offset);
// Child nodes should not process the offset. If there is a limit,
// the child nodes need only return (offset + limit) rows.
@ -294,48 +348,33 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
// NOTICE: We must visit from right to left, to ensure the last fragment is root fragment
PlanFragment rightFragment = visit(hashJoin.child(1), context);
PlanFragment leftFragment = visit(hashJoin.child(0), context);
PhysicalHashJoin physicalHashJoin = hashJoin.getOperator();
// Expression predicateExpr = physicalHashJoin.getCondition().get();
// List<Expression> eqExprList = Utils.getEqConjuncts(hashJoin.child(0).getOutput(),
// hashJoin.child(1).getOutput(), predicateExpr);
JoinType joinType = physicalHashJoin.getJoinType();
PlanNode leftFragmentPlanRoot = leftFragment.getPlanRoot();
PlanNode rightFragmentPlanRoot = rightFragment.getPlanRoot();
PhysicalHashJoin physicalHashJoin = hashJoin.getOperator();
JoinType joinType = physicalHashJoin.getJoinType();
if (joinType.equals(JoinType.CROSS_JOIN)
|| physicalHashJoin.getJoinType().equals(JoinType.INNER_JOIN)
&& !physicalHashJoin.getCondition().isPresent()) {
CrossJoinNode crossJoinNode = new CrossJoinNode(context.nextNodeId(), leftFragment.getPlanRoot(),
rightFragment.getPlanRoot(), null);
crossJoinNode.setLimit(physicalHashJoin.getLimit());
ExchangeNode exchangeNode = new ExchangeNode(context.nextNodeId(), rightFragment.getPlanRoot(), false);
exchangeNode.setNumInstances(rightFragmentPlanRoot.getNumInstances());
exchangeNode.setFragment(leftFragment);
leftFragmentPlanRoot.setChild(1, exchangeNode);
rightFragment.setDestination(exchangeNode);
crossJoinNode.setChild(0, leftFragment.getPlanRoot());
leftFragment.setPlanRoot(crossJoinNode);
context.addPlanFragment(leftFragment);
throw new RuntimeException("Physical hash join could not execute without equal join condition.");
} else {
Expression eqJoinExpression = physicalHashJoin.getCondition().get();
List<Expr> execEqConjunctList = ExpressionUtils.extractConjunct(eqJoinExpression).stream()
.map(EqualTo.class::cast)
.map(e -> swapEqualToForChildrenOrder(e, hashJoin.left().getOutput()))
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
HashJoinNode hashJoinNode = new HashJoinNode(context.nextPlanNodeId(), leftFragmentPlanRoot,
rightFragmentPlanRoot,
JoinType.toJoinOperator(physicalHashJoin.getJoinType()), execEqConjunctList, Lists.newArrayList());
hashJoinNode.setDistributionMode(DistributionMode.BROADCAST);
hashJoinNode.setChild(0, leftFragmentPlanRoot);
connectChildFragment(hashJoinNode, 1, leftFragment, rightFragment, context);
leftFragment.setPlanRoot(hashJoinNode);
return leftFragment;
}
Expression eqJoinExpression = physicalHashJoin.getCondition().get();
List<Expr> execEqConjunctList = ExpressionUtils.extractConjunct(eqJoinExpression).stream()
.map(EqualTo.class::cast)
.map(e -> swapEqualToForChildrenOrder(e, hashJoin.left().getOutput()))
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
HashJoinNode hashJoinNode = new HashJoinNode(context.nextNodeId(), leftFragmentPlanRoot, rightFragmentPlanRoot,
JoinType.toJoinOperator(physicalHashJoin.getJoinType()), execEqConjunctList, Lists.newArrayList());
hashJoinNode.setDistributionMode(DistributionMode.BROADCAST);
hashJoinNode.setChild(0, leftFragmentPlanRoot);
connectChildFragment(hashJoinNode, 1, leftFragment, rightFragment, context);
leftFragment.setPlanRoot(hashJoinNode);
return leftFragment;
}
// TODO: generate expression mapping when be project could do in ExecNode
@ -384,7 +423,7 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
}
}
private TupleDescriptor generateTupleDesc(List<Slot> slotList, PlanTranslatorContext context, Table table) {
private TupleDescriptor generateTupleDesc(List<Slot> slotList, Table table, PlanTranslatorContext context) {
TupleDescriptor tupleDescriptor = context.generateTupleDesc();
tupleDescriptor.setTable(table);
for (Slot slot : slotList) {
@ -397,35 +436,69 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
PlanTranslatorContext context, Table table) {
TupleDescriptor tupleDescriptor = context.generateTupleDesc();
tupleDescriptor.setTable(table);
for (Slot slot : slotList) {
context.createSlotDesc(tupleDescriptor, (SlotReference) slot);
}
Set<ExprId> alreadyExists = Sets.newHashSet();
for (OrderKey orderKey : orderKeyList) {
context.createSlotDesc(tupleDescriptor, orderKey.getExpr());
if (orderKey.getExpr() instanceof SlotReference) {
SlotReference slotReference = (SlotReference) orderKey.getExpr();
// TODO: trick here, we need semanticEquals to remove redundant expression
if (alreadyExists.contains(slotReference.getExprId())) {
continue;
}
context.createSlotDesc(tupleDescriptor, (SlotReference) orderKey.getExpr());
alreadyExists.add(slotReference.getExprId());
}
}
for (Slot slot : slotList) {
if (alreadyExists.contains(slot.getExprId())) {
continue;
}
context.createSlotDesc(tupleDescriptor, (SlotReference) slot);
alreadyExists.add(slot.getExprId());
}
return tupleDescriptor;
}
private PlanFragment createParentFragment(PlanFragment childFragment, DataPartition parentPartition,
PlanTranslatorContext ctx) {
ExchangeNode exchangeNode = new ExchangeNode(ctx.nextNodeId(), childFragment.getPlanRoot(), false);
PlanTranslatorContext context) {
ExchangeNode exchangeNode = new ExchangeNode(context.nextPlanNodeId(), childFragment.getPlanRoot(), false);
exchangeNode.setNumInstances(childFragment.getPlanRoot().getNumInstances());
PlanFragment parentFragment = new PlanFragment(ctx.nextFragmentId(), exchangeNode, parentPartition);
PlanFragment parentFragment = new PlanFragment(context.nextFragmentId(), exchangeNode, parentPartition);
childFragment.setDestination(exchangeNode);
childFragment.setOutputPartition(parentPartition);
context.addPlanFragment(parentFragment);
return parentFragment;
}
private void connectChildFragment(PlanNode node, int childIdx,
PlanFragment parentFragment, PlanFragment childFragment,
PlanTranslatorContext context) {
ExchangeNode exchangeNode = new ExchangeNode(context.nextNodeId(), childFragment.getPlanRoot(), false);
ExchangeNode exchangeNode = new ExchangeNode(context.nextPlanNodeId(), childFragment.getPlanRoot(), false);
exchangeNode.setNumInstances(childFragment.getPlanRoot().getNumInstances());
exchangeNode.setFragment(parentFragment);
node.setChild(childIdx, exchangeNode);
childFragment.setDestination(exchangeNode);
}
/**
* Return unpartitioned fragment that merges the input fragment's output via
* an ExchangeNode.
* Requires that input fragment be partitioned.
*/
private PlanFragment exchangeToMergeFragment(PlanFragment inputFragment, PlanTranslatorContext context) {
Preconditions.checkState(inputFragment.isPartitioned());
// exchange node clones the behavior of its input, aside from the conjuncts
ExchangeNode mergePlan =
new ExchangeNode(context.nextPlanNodeId(), inputFragment.getPlanRoot(), false);
mergePlan.setNumInstances(inputFragment.getPlanRoot().getNumInstances());
PlanFragment fragment = new PlanFragment(context.nextFragmentId(), mergePlan, DataPartition.UNPARTITIONED);
inputFragment.setDestination(mergePlan);
context.addPlanFragment(fragment);
return fragment;
}
/**
* Helper function to eliminate unnecessary checked exception caught requirement from the main logic of translator.
*
@ -437,7 +510,7 @@ public class PhysicalPlanTranslator extends PlanOperatorVisitor<PlanFragment, Pl
try {
f.exec();
} catch (Exception e) {
throw new RuntimeException("Unexpected Exception: ", e);
throw new RuntimeException(e.getMessage(), e);
}
}

View File

@ -18,13 +18,13 @@
package org.apache.doris.nereids.glue.translator;
import org.apache.doris.analysis.DescriptorTable;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.analysis.TupleId;
import org.apache.doris.catalog.Column;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.planner.PlanFragment;
@ -50,7 +50,7 @@ public class PlanTranslatorContext {
/**
* Map expressions of new optimizer to the stale expr.
*/
private Map<Expression, Expr> expressionToExecExpr = new HashMap<>();
private final Map<ExprId, SlotRef> exprIdSlotRefMap = new HashMap<>();
private final List<ScanNode> scanNodeList = new ArrayList<>();
@ -66,7 +66,11 @@ public class PlanTranslatorContext {
return descTable.createTupleDescriptor();
}
public PlanNodeId nextNodeId() {
public PlanFragmentId nextFragmentId() {
return fragmentIdGenerator.getNextId();
}
public PlanNodeId nextPlanNodeId() {
return nodeIdGenerator.getNextId();
}
@ -74,24 +78,16 @@ public class PlanTranslatorContext {
return descTable.addSlotDescriptor(t);
}
public SlotDescriptor addSlotDesc(TupleDescriptor t, int id) {
return descTable.addSlotDescriptor(t, id);
}
public PlanFragmentId nextFragmentId() {
return fragmentIdGenerator.getNextId();
}
public void addPlanFragment(PlanFragment planFragment) {
this.planFragmentList.add(planFragment);
}
public void addSlotRefMapping(Expression expression, Expr expr) {
expressionToExecExpr.put(expression, expr);
public void addExprIdPair(ExprId exprId, SlotRef slotRef) {
exprIdSlotRefMap.put(exprId, slotRef);
}
public Expr findExpr(Expression expression) {
return expressionToExecExpr.get(expression);
public SlotRef findSlotRef(ExprId exprId) {
return exprIdSlotRefMap.get(exprId);
}
public void addScanNode(ScanNode scanNode) {
@ -114,7 +110,7 @@ public class PlanTranslatorContext {
}
slotDescriptor.setType(slotReference.getDataType().toCatalogDataType());
slotDescriptor.setIsMaterialized(true);
this.addSlotRefMapping(slotReference, new SlotRef(slotDescriptor));
this.addExprIdPair(slotReference.getExprId(), new SlotRef(slotDescriptor));
return slotDescriptor;
}
@ -122,11 +118,8 @@ public class PlanTranslatorContext {
* Create slotDesc with Expression.
*/
public void createSlotDesc(TupleDescriptor tupleDesc, Expression expression) {
if (!expressionToExecExpr.containsKey(expression)) {
SlotDescriptor slotDescriptor = this.addSlotDesc(tupleDesc);
slotDescriptor.setType(expression.getDataType().toCatalogDataType());
this.addSlotRefMapping(expression, new SlotRef(slotDescriptor));
}
SlotDescriptor slotDescriptor = this.addSlotDesc(tupleDesc);
slotDescriptor.setType(expression.getDataType().toCatalogDataType());
}
public TupleDescriptor getTupleDesc(TupleId tupleId) {

View File

@ -15,17 +15,18 @@
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs;
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.rules.analysis.BindFunction;
import org.apache.doris.nereids.rules.analysis.BindRelation;
import org.apache.doris.nereids.rules.analysis.BindSlotReference;
import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate;
import com.google.common.collect.ImmutableList;
/**
* Execute the analysis job.
* Execute the analysis rules.
*/
public class AnalyzeRulesJob extends BatchRulesJob {
@ -39,7 +40,8 @@ public class AnalyzeRulesJob extends BatchRulesJob {
bottomUpBatch(ImmutableList.of(
new BindRelation(),
new BindSlotReference(),
new BindFunction())
new BindFunction(),
new ProjectToGlobalAggregate())
)));
}
}

View File

@ -15,9 +15,10 @@
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs;
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.cascades.OptimizeGroupJob;
import org.apache.doris.nereids.jobs.rewrite.RewriteBottomUpJob;
import org.apache.doris.nereids.jobs.rewrite.RewriteTopDownJob;

View File

@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
import com.google.common.collect.ImmutableList;
/**
* Execute the disassemble rules.
*/
public class DisassembleRulesJob extends BatchRulesJob {
public DisassembleRulesJob(PlannerContext plannerContext) {
super(plannerContext);
rulesJob.addAll(ImmutableList.of(
topDownBatch(ImmutableList.of(
new AggregateDisassemble())
)));
}
}

View File

@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs;
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.PlannerContext;

View File

@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs;
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.rules.rewrite.logical.PushPredicateThroughJoin;

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.operators;
import org.apache.doris.nereids.operators.plans.physical.PhysicalAggregation;
import org.apache.doris.nereids.operators.plans.physical.PhysicalAggregate;
import org.apache.doris.nereids.operators.plans.physical.PhysicalFilter;
import org.apache.doris.nereids.operators.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.operators.plans.physical.PhysicalHeapSort;
@ -35,7 +35,7 @@ public abstract class OperatorVisitor<R, C> {
public abstract R visitOperator(Operator operator, C context);
public R visitPhysicalAggregation(PhysicalAggregation physicalAggregation, C context) {
public R visitPhysicalAggregation(PhysicalAggregate physicalAggregate, C context) {
return null;
}

View File

@ -24,10 +24,10 @@ import org.apache.doris.analysis.AggregateInfo;
* enum of agg phase definition of stale optimizer.
*/
public enum AggPhase {
FIRST("FIRST", AggregateInfo.AggPhase.FIRST),
FIRST_MERGE("FIRST_MERGE", AggregateInfo.AggPhase.FIRST_MERGE),
SECOND("SECOND", AggregateInfo.AggPhase.SECOND),
SECOND_MERGE("SECOND_MERGE", AggregateInfo.AggPhase.SECOND_MERGE);
LOCAL("LOCAL", AggregateInfo.AggPhase.FIRST),
GLOBAL("GLOBAL", AggregateInfo.AggPhase.FIRST_MERGE),
DISTINCT_LOCAL("DISTINCT_LOCAL", AggregateInfo.AggPhase.SECOND),
DISTINCT_GLOBAL("DISTINCT_GLOBAL", AggregateInfo.AggPhase.SECOND_MERGE);
private final String name;
@ -38,8 +38,8 @@ public enum AggPhase {
this.execAggPhase = execAggPhase;
}
public boolean isMerge() {
return this == FIRST_MERGE || this == SECOND_MERGE;
public boolean isGlobal() {
return this == GLOBAL || this == DISTINCT_GLOBAL;
}
public AggregateInfo.AggPhase toExec() {

View File

@ -46,43 +46,45 @@ import java.util.Objects;
public class LogicalAggregate extends LogicalUnaryOperator {
private final boolean disassembled;
private final List<Expression> groupByExprList;
private final List<Expression> groupByExpressionList;
private final List<NamedExpression> outputExpressionList;
private List<Expression> partitionExprList;
private final List<Expression> partitionExprList;
private final AggPhase aggPhase;
/**
* Desc: Constructor for LogicalAggregation.
* Desc: Constructor for LogicalAggregate.
*/
public LogicalAggregate(List<Expression> groupByExprList, List<NamedExpression> outputExpressionList) {
super(OperatorType.LOGICAL_AGGREGATION);
this.groupByExprList = groupByExprList;
this.outputExpressionList = outputExpressionList;
this.disassembled = false;
this.aggPhase = AggPhase.FIRST;
public LogicalAggregate(List<Expression> groupByExpressionList, List<NamedExpression> outputExpressionList) {
this(groupByExpressionList, outputExpressionList, false, AggPhase.GLOBAL);
}
public LogicalAggregate(List<Expression> groupByExprList,
public LogicalAggregate(List<Expression> groupByExpressionList,
List<NamedExpression> outputExpressionList,
boolean disassembled, AggPhase aggPhase) {
this(groupByExpressionList, outputExpressionList, null, disassembled, aggPhase);
}
/**
* Whole parameters constructor for LogicalAggregate.
*/
public LogicalAggregate(List<Expression> groupByExpressionList,
List<NamedExpression> outputExpressionList,
List<Expression> partitionExprList,
boolean disassembled, AggPhase aggPhase) {
super(OperatorType.LOGICAL_AGGREGATION);
this.groupByExprList = groupByExprList;
this.groupByExpressionList = groupByExpressionList;
this.outputExpressionList = outputExpressionList;
this.partitionExprList = partitionExprList;
this.disassembled = disassembled;
this.aggPhase = aggPhase;
}
public List<Expression> getPartitionExprList() {
return partitionExprList == null ? groupByExprList : partitionExprList;
return partitionExprList == null ? groupByExpressionList : partitionExprList;
}
public void setPartitionExprList(List<Expression> partitionExprList) {
this.partitionExprList = partitionExprList;
}
public List<Expression> getGroupByExprList() {
return groupByExprList;
public List<Expression> getGroupByExpressionList() {
return groupByExpressionList;
}
public List<NamedExpression> getOutputExpressionList() {
@ -95,9 +97,9 @@ public class LogicalAggregate extends LogicalUnaryOperator {
@Override
public String toString() {
return "LogicalAggregate (" + "outputExpressionList: ["
return "LogicalAggregate (phase: [" + aggPhase.name() + "], outputExpressionList: ["
+ StringUtils.join(outputExpressionList, ", ")
+ "], groupByExprList: [" + StringUtils.join(groupByExprList, ", ") + "])";
+ "], groupByExprList: [" + StringUtils.join(groupByExpressionList, ", ") + "])";
}
@Override
@ -109,7 +111,10 @@ public class LogicalAggregate extends LogicalUnaryOperator {
@Override
public List<Expression> getExpressions() {
return new ImmutableList.Builder<Expression>().addAll(groupByExprList).addAll(outputExpressionList).build();
return new ImmutableList.Builder<Expression>()
.addAll(groupByExpressionList)
.addAll(outputExpressionList)
.build();
}
public boolean isDisassembled() {
@ -127,7 +132,7 @@ public class LogicalAggregate extends LogicalUnaryOperator {
return false;
}
LogicalAggregate that = (LogicalAggregate) o;
return Objects.equals(groupByExprList, that.groupByExprList)
return Objects.equals(groupByExpressionList, that.groupByExpressionList)
&& Objects.equals(outputExpressionList, that.outputExpressionList)
&& Objects.equals(partitionExprList, that.partitionExprList)
&& aggPhase == that.aggPhase;
@ -135,6 +140,11 @@ public class LogicalAggregate extends LogicalUnaryOperator {
@Override
public int hashCode() {
return Objects.hash(groupByExprList, outputExpressionList, partitionExprList, aggPhase);
return Objects.hash(groupByExpressionList, outputExpressionList, partitionExprList, aggPhase);
}
public LogicalAggregate withGroupByAndOutput(List<Expression> groupByExprList,
List<NamedExpression> outputExpressionList) {
return new LogicalAggregate(groupByExprList, outputExpressionList, partitionExprList, disassembled, aggPhase);
}
}

View File

@ -19,6 +19,8 @@ package org.apache.doris.nereids.operators.plans.logical;
import org.apache.doris.catalog.Table;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
/**
@ -35,4 +37,9 @@ public class LogicalOlapScan extends LogicalRelation {
public LogicalOlapScan(Table table, List<String> qualifier) {
super(table, qualifier);
}
@Override
public String toString() {
return "ScanOlapTable([" + StringUtils.join(qualifier, ".") + "." + table.getName() + "])";
}
}

View File

@ -34,8 +34,8 @@ import java.util.Objects;
*/
public abstract class LogicalRelation extends LogicalLeafOperator {
private final Table table;
private final List<String> qualifier;
protected final Table table;
protected final List<String> qualifier;
/**
* Constructor for LogicalRelationPlan.

View File

@ -32,7 +32,7 @@ import java.util.List;
/**
* Physical aggregation plan operator.
*/
public class PhysicalAggregation extends PhysicalUnaryOperator {
public class PhysicalAggregate extends PhysicalUnaryOperator {
private final List<Expression> groupByExprList;
@ -52,7 +52,7 @@ public class PhysicalAggregation extends PhysicalUnaryOperator {
* @param partitionExprList partition expr list, used for analytic agg.
* @param usingStream whether it's stream agg.
*/
public PhysicalAggregation(List<Expression> groupByExprList, List<NamedExpression> outputExpressionList,
public PhysicalAggregate(List<Expression> groupByExprList, List<NamedExpression> outputExpressionList,
List<Expression> partitionExprList, AggPhase aggPhase, boolean usingStream) {
super(OperatorType.PHYSICAL_AGGREGATION);
this.groupByExprList = groupByExprList;
@ -84,7 +84,7 @@ public class PhysicalAggregation extends PhysicalUnaryOperator {
@Override
public <R, C> R accept(PlanOperatorVisitor<R, C> visitor, Plan plan, C context) {
return visitor.visitPhysicalAggregation((PhysicalUnaryPlan<PhysicalAggregation, Plan>) plan, context);
return visitor.visitPhysicalAggregate((PhysicalUnaryPlan<PhysicalAggregate, Plan>) plan, context);
}
@Override
@ -93,4 +93,10 @@ public class PhysicalAggregation extends PhysicalUnaryOperator {
return new ImmutableList.Builder<Expression>().addAll(groupByExprList).addAll(outputExpressionList)
.addAll(partitionExprList).build();
}
@Override
public String toString() {
return "PhysicalAggregate([key=" + groupByExprList
+ "], [output=" + outputExpressionList + "])";
}
}

View File

@ -68,4 +68,15 @@ public class PhysicalHashJoin extends PhysicalBinaryOperator {
public List<Expression> getExpressions() {
return condition.<List<Expression>>map(ImmutableList::of).orElseGet(ImmutableList::of);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("PhysicalHashJoin ([").append(joinType).append("]");
if (condition.isPresent()) {
sb.append(", [").append(condition.get()).append("]");
}
sb.append(")");
return sb.toString();
}
}

View File

@ -61,7 +61,7 @@ public class PhysicalHeapSort extends PhysicalUnaryOperator {
@Override
public <R, C> R accept(PlanOperatorVisitor<R, C> visitor, Plan plan, C context) {
return visitor.visitPhysicalSort((PhysicalUnaryPlan<PhysicalHeapSort, Plan>) plan, context);
return visitor.visitPhysicalHeapSort((PhysicalUnaryPlan<PhysicalHeapSort, Plan>) plan, context);
}
@Override

View File

@ -76,9 +76,8 @@ public class PhysicalOlapScan extends PhysicalScan {
@Override
public String toString() {
return "Scan Olap Table " + StringUtils.join(qualifier, ".") + "." + olapTable.getName()
+ " (selected index id: " + selectedTabletId + ", selected partition ids: " + selectedPartitionId
+ ", selected tablet ids: " + selectedTabletId + ")";
return "PhysicalOlapScan([" + StringUtils.join(qualifier, ".") + "." + olapTable.getName()
+ "], [index id=" + selectedIndexId + "])";
}
@Override

View File

@ -398,8 +398,10 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
@Override
public Literal visitStringLiteral(StringLiteralContext ctx) {
// TODO: add unescapeSQLString.
String s = ctx.STRING().stream()
.map(ParseTree::getText)
.map(str -> str.substring(1, str.length() - 1))
.reduce((s1, s2) -> s1 + s2)
.orElse("");
return new Literal(s);

View File

@ -56,6 +56,6 @@ public class OrderKey {
@Override
public String toString() {
return expr.sql();
return expr.toSql();
}
}

View File

@ -38,6 +38,7 @@ public enum RuleType {
PROJECT_TO_GLOBAL_AGGREGATE(RuleTypeClass.REWRITE),
// rewrite rules
AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE),
COLUMN_PRUNE_PROJECTION(RuleTypeClass.REWRITE),
// predicate push down rules
PUSH_DOWN_PREDICATE_THROUGH_JOIN(RuleTypeClass.REWRITE),
// column prune rules,

View File

@ -49,9 +49,9 @@ public class BindFunction implements AnalysisRuleFactory {
),
RuleType.BINDING_AGGREGATE_FUNCTION.build(
logicalAggregate().then(agg -> {
List<Expression> groupBy = bind(agg.operator.getGroupByExprList());
List<Expression> groupBy = bind(agg.operator.getGroupByExpressionList());
List<NamedExpression> output = bind(agg.operator.getOutputExpressionList());
LogicalAggregate op = new LogicalAggregate(groupBy, output);
LogicalAggregate op = agg.operator.withGroupByAndOutput(groupBy, output);
return plan(op, agg.child());
})
)
@ -60,7 +60,7 @@ public class BindFunction implements AnalysisRuleFactory {
private <E extends Expression> List<E> bind(List<E> exprList) {
return exprList.stream()
.map(expr -> FunctionBinder.INSTANCE.bind(expr))
.map(FunctionBinder.INSTANCE::bind)
.collect(Collectors.toList());
}

View File

@ -78,10 +78,10 @@ public class BindSlotReference implements AnalysisRuleFactory {
RuleType.BINDING_AGGREGATE_SLOT.build(
logicalAggregate().then(agg -> {
List<Expression> groupBy = bind(
agg.operator.getGroupByExprList(), agg.children(), agg);
agg.operator.getGroupByExpressionList(), agg.children(), agg);
List<NamedExpression> output = bind(
agg.operator.getOutputExpressionList(), agg.children(), agg);
LogicalAggregate op = new LogicalAggregate(groupBy, output);
LogicalAggregate op = agg.operator.withGroupByAndOutput(groupBy, output);
return plan(op, agg.child());
})
),
@ -146,7 +146,7 @@ public class BindSlotReference implements AnalysisRuleFactory {
return new Alias(child, ((NamedExpression) child).getName());
} else {
// TODO: resolve aliases
return new Alias(child, child.sql());
return new Alias(child, child.toSql());
}
}

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.rules.implementation;
import org.apache.doris.nereids.operators.plans.physical.PhysicalAggregation;
import org.apache.doris.nereids.operators.plans.physical.PhysicalAggregate;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.plans.Plan;
@ -29,9 +29,9 @@ public class LogicalAggToPhysicalHashAgg extends OneImplementationRuleFactory {
@Override
public Rule<Plan> build() {
return logicalAggregate().then(agg -> plan(
new PhysicalAggregation(
new PhysicalAggregate(
// TODO: for use a function to judge whether use stream
agg.getOperator().getGroupByExprList(),
agg.getOperator().getGroupByExpressionList(),
agg.getOperator().getOutputExpressionList(),
agg.getOperator().getPartitionExprList(),
agg.getOperator().getAggPhase(),

View File

@ -17,144 +17,156 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.analysis.FunctionName;
import org.apache.doris.catalog.Catalog;
import org.apache.doris.catalog.Function;
import org.apache.doris.catalog.Function.CompareMode;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.operators.Operator;
import org.apache.doris.nereids.operators.plans.AggPhase;
import org.apache.doris.nereids.operators.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnaryPlan;
import com.clearspring.analytics.util.Lists;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.HashMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* TODO: if instance count is 1, shouldn't disassemble the agg operator
* Used to generate the merge agg node for distributed execution.
* Do this in following steps:
* 1. clone output expr list, find all agg function
* 2. set found agg function intermediaType
* 3. create new child plan rooted at new local agg
* 4. update the slot referenced by expr of merge agg
* 5. create plan rooted at merge agg, return it.
* NOTICE: GLOBAL output expressions' ExprId should SAME with ORIGIN output expressions' ExprId.
* If we have a query: SELECT SUM(v1 * v2) + 1 FROM t GROUP BY k + 1
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [Alias(k + 1) #1, Alias(SUM(v1 * v2) + 1) #2], groupByExpr: [k + 1])
* +-- childPlan
* we should rewrite to:
* Aggregate(phase: [GLOBAL], outputExpr: [Alias(b) #1, Alias(SUM(a) + 1) #2], groupByExpr: [b])
* +-- Aggregate(phase: [LOCAL], outputExpr: [SUM(v1 * v2) as a, (k + 1) as b], groupByExpr: [k + 1])
* +-- childPlan
*
* TODO:
* 1. use different class represent different phase aggregate
* 2. if instance count is 1, shouldn't disassemble the agg operator
* 3. we need another rule to removing duplicated expressions in group by expression list
*/
public class AggregateDisassemble extends OneRewriteRuleFactory {
@Override
public Rule<Plan> build() {
return logicalAggregate().when(p -> {
LogicalAggregate logicalAggregation = p.getOperator();
return !logicalAggregation.isDisassembled();
LogicalAggregate logicalAggregate = p.getOperator();
return !logicalAggregate.isDisassembled();
}).thenApply(ctx -> {
Plan plan = ctx.root;
Operator operator = plan.getOperator();
LogicalAggregate agg = (LogicalAggregate) operator;
List<NamedExpression> outputExpressionList = agg.getOutputExpressionList();
List<NamedExpression> intermediateAggExpressionList = Lists.newArrayList();
// TODO: shouldn't extract agg function from this field.
for (NamedExpression namedExpression : outputExpressionList) {
namedExpression = (NamedExpression) namedExpression.clone();
List<AggregateFunction> functionCallList =
namedExpression.collect(org.apache.doris.catalog.AggregateFunction.class::isInstance);
// TODO: we will have another mechanism to get corresponding stale agg func.
for (AggregateFunction functionCall : functionCallList) {
org.apache.doris.catalog.AggregateFunction staleAggFunc = findAggFunc(functionCall);
Type staleIntermediateType = staleAggFunc.getIntermediateType();
Type staleRetType = staleAggFunc.getReturnType();
if (staleIntermediateType != null && !staleIntermediateType.equals(staleRetType)) {
functionCall.setIntermediate(DataType.convertFromCatalogDataType(staleIntermediateType));
}
}
intermediateAggExpressionList.add(namedExpression);
}
LogicalAggregate localAgg = new LogicalAggregate(
agg.getGroupByExprList().stream().map(Expression::clone).collect(Collectors.toList()),
intermediateAggExpressionList,
true,
AggPhase.FIRST
);
LogicalUnaryPlan<LogicalAggregate, GroupPlan> plan = ctx.root;
LogicalAggregate aggregate = plan.getOperator();
List<NamedExpression> originOutputExprs = aggregate.getOutputExpressionList();
List<Expression> originGroupByExprs = aggregate.getGroupByExpressionList();
Plan childPlan = plan(localAgg, plan.child(0));
List<Slot> stalePlanOutputSlotList = plan.getOutput();
List<Slot> childOutputSlotList = childPlan.getOutput();
int childOutputSize = stalePlanOutputSlotList.size();
Preconditions.checkState(childOutputSize == childOutputSlotList.size());
Map<Slot, Slot> staleToNew = new HashMap<>();
for (int i = 0; i < stalePlanOutputSlotList.size(); i++) {
staleToNew.put(stalePlanOutputSlotList.get(i), childOutputSlotList.get(i));
// 1. generate a map from local aggregate output to global aggregate expr substitution.
// inputSubstitutionMap use for replacing expression in global aggregate
// replace rule is:
// a: Expression is a group by key and is a slot reference. e.g. group by k1
// b. Expression is a group by key and is an expression. e.g. group by k1 + 1
// c. Expression is an aggregate function. e.g. sum(v1) in select list
// +-----------+---------------------+-------------------------+--------------------------------+
// | situation | origin expression | local output expression | expression in global aggregate |
// +-----------+---------------------+-------------------------+--------------------------------+
// | a | Ref(k1)#1 | Ref(k1)#1 | Ref(k1)#1 |
// +-----------+---------------------+-------------------------+--------------------------------+
// | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2 | Ref(key)#2 |
// +-----------+---------------------+-------------------------+--------------------------------+
// | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3 | AF(af#3) |
// +-----------+---------------------+-------------------------+--------------------------------+
// NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: ExprId x
// 2. collect local aggregate output expressions and local aggregate group by expression list
Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap();
List<Expression> localGroupByExprs = aggregate.getGroupByExpressionList();
List<NamedExpression> localOutputExprs = Lists.newArrayList();
for (Expression originGroupByExpr : originGroupByExprs) {
if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
continue;
}
if (originGroupByExpr instanceof SlotReference) {
inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
localOutputExprs.add((SlotReference) originGroupByExpr);
} else {
NamedExpression localOutputExpr = new Alias<>(originGroupByExpr, originGroupByExpr.toSql());
inputSubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot());
localOutputExprs.add(localOutputExpr);
}
}
List<Expression> groupByExpressionList = agg.getGroupByExprList();
for (int i = 0; i < groupByExpressionList.size(); i++) {
replaceSlot(staleToNew, groupByExpressionList, groupByExpressionList.get(i), i);
for (NamedExpression originOutputExpr : originOutputExprs) {
List<AggregateFunction> aggregateFunctions
= originOutputExpr.collect(AggregateFunction.class::isInstance);
for (AggregateFunction aggregateFunction : aggregateFunctions) {
if (inputSubstitutionMap.containsKey(aggregateFunction)) {
continue;
}
NamedExpression localOutputExpr = new Alias<>(aggregateFunction, aggregateFunction.toSql());
Expression substitutionValue = aggregateFunction.withChildren(
Lists.newArrayList(localOutputExpr.toSlot()));
inputSubstitutionMap.put(aggregateFunction, substitutionValue);
localOutputExprs.add(localOutputExpr);
}
}
List<NamedExpression> mergeOutputExpressionList = agg.getOutputExpressionList();
for (int i = 0; i < mergeOutputExpressionList.size(); i++) {
replaceSlot(staleToNew, mergeOutputExpressionList, mergeOutputExpressionList.get(i), i);
}
LogicalAggregate mergeAgg = new LogicalAggregate(
groupByExpressionList,
mergeOutputExpressionList,
// 3. replace expression in globalOutputExprs and globalGroupByExprs
List<NamedExpression> globalOutputExprs = aggregate.getOutputExpressionList().stream()
.map(e -> ExpressionReplacer.INSTANCE.visit(e, inputSubstitutionMap))
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
List<Expression> globalGroupByExprs = localGroupByExprs.stream()
.map(e -> ExpressionReplacer.INSTANCE.visit(e, inputSubstitutionMap)).collect(Collectors.toList());
// 4. generate new plan
LogicalAggregate localAggregate = new LogicalAggregate(
localGroupByExprs,
localOutputExprs,
true,
AggPhase.FIRST_MERGE
AggPhase.LOCAL
);
return plan(mergeAgg, childPlan);
LogicalAggregate globalAggregate = new LogicalAggregate(
globalGroupByExprs,
globalOutputExprs,
true,
AggPhase.GLOBAL
);
return plan(globalAggregate, plan(localAggregate, plan.child(0)));
}).toRule(RuleType.AGGREGATE_DISASSEMBLE);
}
private org.apache.doris.catalog.AggregateFunction findAggFunc(AggregateFunction functionCall) {
FunctionName functionName = new FunctionName(functionCall.getName());
List<Expression> expressionList = functionCall.getArguments();
List<Type> staleTypeList = expressionList.stream().map(Expression::getDataType)
.map(DataType::toCatalogDataType).collect(Collectors.toList());
Function staleFuncDesc = new Function(functionName, staleTypeList,
functionCall.getDataType().toCatalogDataType(),
// I think an aggregate function will never have a variable length parameters
false);
Function staleFunc = Catalog.getCurrentCatalog()
.getFunction(staleFuncDesc, CompareMode.IS_IDENTICAL);
Preconditions.checkArgument(staleFunc instanceof org.apache.doris.catalog.AggregateFunction);
return (org.apache.doris.catalog.AggregateFunction) staleFunc;
}
@SuppressWarnings("InnerClassMayBeStatic")
private static class ExpressionReplacer
extends ExpressionVisitor<Expression, Map<Expression, Expression>> {
private static final ExpressionReplacer INSTANCE = new ExpressionReplacer();
@SuppressWarnings("unchecked")
private <T extends Expression> void replaceSlot(Map<Slot, Slot> staleToNew,
List<T> expressionList, Expression root, int index) {
if (index != -1) {
if (root instanceof Slot) {
Slot v = staleToNew.get(root);
if (v == null) {
return;
@Override
public Expression visit(Expression expr, Map<Expression, Expression> substitutionMap) {
// TODO: we need to do sub tree match and replace. but we do not have semanticEquals now.
// e.g. a + 1 + 2 in output expression should be replaced by
// (slot reference to update phase out (a + 1)) + 2, if we do group by a + 1
// currently, we could only handle output expression same with group by expression
if (substitutionMap.containsKey(expr)) {
return substitutionMap.get(expr);
} else {
List<Expression> newChildren = new ArrayList<>();
boolean hasNewChildren = false;
for (Expression child : expr.children()) {
Expression newChild = visit(child, substitutionMap);
if (newChild != child) {
hasNewChildren = true;
}
newChildren.add(newChild);
}
expressionList.set(index, (T) v);
return;
return hasNewChildren ? expr.withChildren(newChildren) : expr;
}
}
List<Expression> children = root.children();
for (int i = 0; i < children.size(); i++) {
Expression cur = children.get(i);
if (!(cur instanceof Slot)) {
replaceSlot(staleToNew, expressionList, cur, -1);
continue;
}
Expression v = staleToNew.get(cur);
if (v == null) {
continue;
}
children.set(i, v);
}
}
}

View File

@ -90,5 +90,4 @@ public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> {
});
return (T) result.build();
}
}

View File

@ -34,9 +34,9 @@ public class Add<LEFT_CHILD_TYPE extends Expression, RIGHT_CHILD_TYPE extends Ex
}
@Override
public String sql() {
return left().sql() + ' ' + getArithmeticOperator().toString()
+ ' ' + right().sql();
public String toSql() {
return left().toSql() + ' ' + getArithmeticOperator().toString()
+ ' ' + right().toSql();
}
@Override
@ -57,6 +57,6 @@ public class Add<LEFT_CHILD_TYPE extends Expression, RIGHT_CHILD_TYPE extends Ex
public String toString() {
return sql();
return left().toString() + ' ' + getArithmeticOperator().toString() + ' ' + right().toString();
}
}

View File

@ -44,8 +44,12 @@ public class Alias<CHILD_TYPE extends Expression> extends NamedExpression
* @param name alias name
*/
public Alias(CHILD_TYPE child, String name) {
this(NamedExpressionUtil.newExprId(), child, name);
}
private Alias(ExprId exprId, CHILD_TYPE child, String name) {
super(NodeType.ALIAS, child);
this.exprId = NamedExpressionUtil.newExprId();
this.exprId = exprId;
this.name = name;
this.qualifier = ImmutableList.of();
}
@ -76,8 +80,8 @@ public class Alias<CHILD_TYPE extends Expression> extends NamedExpression
}
@Override
public String sql() {
return null;
public String toSql() {
return child().toSql() + " AS `" + name + "`";
}
@Override
@ -87,13 +91,7 @@ public class Alias<CHILD_TYPE extends Expression> extends NamedExpression
@Override
public String toString() {
return child().toString() + " AS " + name;
}
@Override
public Alias<CHILD_TYPE> clone() {
CHILD_TYPE childType = (CHILD_TYPE) children.get(0).clone();
return new Alias<>(childType, name);
return child().toString() + " AS `" + name + "`#" + exprId;
}
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
@ -103,7 +101,7 @@ public class Alias<CHILD_TYPE extends Expression> extends NamedExpression
@Override
public Expression withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new Alias<>(children.get(0), name);
return new Alias<>(exprId, children.get(0), name);
}
}

View File

@ -176,6 +176,9 @@ public abstract class Arithmetic extends Expression {
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
Arithmetic that = (Arithmetic) o;
return op == that.op;
}
@ -187,6 +190,6 @@ public abstract class Arithmetic extends Expression {
@Override
public String toString() {
return sql();
return toSql();
}
}

View File

@ -67,8 +67,8 @@ public class Between<
}
@Override
public String sql() {
return compareExpr.sql() + " BETWEEN " + lowerBound.sql() + " AND " + upperBound.sql();
public String toSql() {
return compareExpr.toSql() + " BETWEEN " + lowerBound.toSql() + " AND " + upperBound.toSql();
}
@Override

View File

@ -53,9 +53,9 @@ public abstract class ComparisonPredicate<LEFT_CHILD_TYPE extends Expression, RI
}
@Override
public String sql() {
public String toSql() {
String nodeType = getType().toString();
return left().sql() + ' ' + nodeType + ' ' + right().sql();
return left().toSql() + ' ' + nodeType + ' ' + right().toSql();
}
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

View File

@ -43,9 +43,9 @@ public class CompoundPredicate<LEFT_CHILD_TYPE extends Expression, RIGHT_CHILD_T
}
@Override
public String sql() {
public String toSql() {
String nodeType = getType().toString();
return "(" + left().sql() + " " + nodeType + " " + right().sql() + ")";
return "(" + left().toSql() + " " + nodeType + " " + right().toSql() + ")";
}
@Override

View File

@ -33,9 +33,9 @@ public class Divide<LEFT_CHILD_TYPE extends Expression, RIGHT_CHILD_TYPE extends
}
@Override
public String sql() {
return left().sql() + ' ' + getArithmeticOperator().toString()
+ ' ' + right().sql();
public String toSql() {
return left().toSql() + ' ' + getArithmeticOperator().toString()
+ ' ' + right().toSql();
}
@Override

View File

@ -44,7 +44,7 @@ public abstract class Expression extends AbstractTreeNode<Expression> {
throw new UnboundException("dataType");
}
public String sql() throws UnboundException {
public String toSql() throws UnboundException {
throw new UnboundException("sql");
}
@ -83,10 +83,6 @@ public abstract class Expression extends AbstractTreeNode<Expression> {
return false;
}
public Expression clone() {
throw new RuntimeException("Unimplemented method");
}
@Override
public boolean equals(Object o) {
if (this == o) {
@ -98,4 +94,9 @@ public abstract class Expression extends AbstractTreeNode<Expression> {
Expression that = (Expression) o;
return Objects.equals(children(), that.children());
}
@Override
public int hashCode() {
return 0;
}
}

View File

@ -85,7 +85,7 @@ public class Literal extends Expression implements LeafExpression {
}
@Override
public String sql() {
public String toSql() {
return value.toString();
}

View File

@ -33,9 +33,9 @@ public class Mod<LEFT_CHILD_TYPE extends Expression, RIGHT_CHILD_TYPE extends Ex
}
@Override
public String sql() {
return left().sql() + ' ' + getArithmeticOperator().toString()
+ ' ' + right().sql();
public String toSql() {
return left().toSql() + ' ' + getArithmeticOperator().toString()
+ ' ' + right().toSql();
}
@Override

View File

@ -34,9 +34,9 @@ public class Multiply<LEFT_CHILD_TYPE extends Expression, RIGHT_CHILD_TYPE exten
}
@Override
public String sql() {
return left().sql() + ' ' + getArithmeticOperator().toString()
+ ' ' + right().sql();
public String toSql() {
return left().toSql() + ' ' + getArithmeticOperator().toString()
+ ' ' + right().toSql();
}
@Override

View File

@ -22,7 +22,6 @@ import org.apache.doris.nereids.trees.NodeType;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import com.clearspring.analytics.util.Lists;
import com.google.common.base.Preconditions;
import org.apache.commons.lang.StringUtils;
@ -92,7 +91,7 @@ public class SlotReference extends Slot {
}
@Override
public String sql() {
public String toSql() {
return name;
}
@ -141,11 +140,6 @@ public class SlotReference extends Slot {
return this;
}
@Override
public SlotReference clone() {
return new SlotReference(name, getDataType(), nullable, Lists.newArrayList(qualifier));
}
public Slot withNullable(boolean newNullable) {
if (this.nullable == newNullable) {
return this;

View File

@ -48,9 +48,9 @@ public abstract class StringRegexPredicate<LEFT_CHILD_TYPE extends Expression, R
}
@Override
public String sql() {
public String toSql() {
String nodeType = getType().toString();
return left().sql() + ' ' + nodeType + ' ' + right().sql();
return left().toSql() + ' ' + nodeType + ' ' + right().toSql();
}
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

View File

@ -33,9 +33,9 @@ public class Subtract<LEFT_CHILD_TYPE extends Expression, RIGHT_CHILD_TYPE exten
}
@Override
public String sql() {
return left().sql() + ' ' + getArithmeticOperator().toString()
+ ' ' + right().sql();
public String toSql() {
return left().toSql() + ' ' + getArithmeticOperator().toString()
+ ' ' + right().toSql();
}
@Override

View File

@ -22,14 +22,13 @@ import org.apache.doris.nereids.trees.NodeType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/** BoundFunction. */
public class BoundFunction extends Expression {
private String name;
public abstract class BoundFunction extends Expression {
private final String name;
public BoundFunction(String name, Expression... arguments) {
super(NodeType.BOUND_FUNCTION, arguments);
@ -67,10 +66,10 @@ public class BoundFunction extends Expression {
}
@Override
public String sql() throws UnboundException {
public String toSql() throws UnboundException {
String args = children()
.stream()
.map(Expression::sql)
.map(Expression::toSql)
.collect(Collectors.joining(", "));
return name + "(" + args + ")";
}
@ -83,13 +82,4 @@ public class BoundFunction extends Expression {
.collect(Collectors.joining(", "));
return name + "(" + args + ")";
}
@Override
public BoundFunction clone() {
List<Expression> paramList = new ArrayList<>();
for (Expression param : getArguments()) {
paramList.add(param.clone());
}
return new BoundFunction(this.name, paramList.toArray(new Expression[0]));
}
}

View File

@ -51,7 +51,7 @@ public class Sum extends AggregateFunction implements UnaryExpression<Expression
@Override
public boolean nullable() {
return false;
return child().nullable();
}
@Override

View File

@ -20,7 +20,7 @@ package org.apache.doris.nereids.trees.plans;
import org.apache.doris.nereids.operators.plans.logical.LogicalFilter;
import org.apache.doris.nereids.operators.plans.logical.LogicalJoin;
import org.apache.doris.nereids.operators.plans.logical.LogicalRelation;
import org.apache.doris.nereids.operators.plans.physical.PhysicalAggregation;
import org.apache.doris.nereids.operators.plans.physical.PhysicalAggregate;
import org.apache.doris.nereids.operators.plans.physical.PhysicalFilter;
import org.apache.doris.nereids.operators.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.operators.plans.physical.PhysicalHeapSort;
@ -67,7 +67,7 @@ public abstract class PlanOperatorVisitor<R, C> {
// Physical plans
// *******************************
public R visitPhysicalAggregation(PhysicalUnaryPlan<PhysicalAggregation, Plan> agg, C context) {
public R visitPhysicalAggregate(PhysicalUnaryPlan<PhysicalAggregate, Plan> agg, C context) {
return visit(agg, context);
}
@ -75,7 +75,7 @@ public abstract class PlanOperatorVisitor<R, C> {
return visit(olapScan, context);
}
public R visitPhysicalSort(PhysicalUnaryPlan<PhysicalHeapSort, Plan> sort, C context) {
public R visitPhysicalHeapSort(PhysicalUnaryPlan<PhysicalHeapSort, Plan> sort, C context) {
return visit(sort, context);
}

View File

@ -22,7 +22,6 @@ package org.apache.doris.planner;
import org.apache.doris.analysis.Analyzer;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.ExprId;
import org.apache.doris.analysis.ExprSubstitutionMap;
import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.SlotId;
@ -46,7 +45,6 @@ import com.google.common.collect.Lists;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
@ -270,21 +268,18 @@ public class SortNode extends PlanNode {
/**
* Supplement the information needed by be for the sort node.
* TODO: currently we only process slotref, so when order key is a + 1, we will failed.
*/
public void finalizeForNereids(TupleDescriptor tupleDescriptor,
List<Expr> outputList, List<Expr> orderingExpr) {
List<Expr> sortTupleSlotExprs = new ArrayList<>();
sortTupleSlotExprs.addAll(outputList);
sortTupleSlotExprs.addAll(orderingExpr);
List<Expr> afterDeduplication = new ArrayList<>();
Set<ExprId> exprIds = new HashSet<>();
for (int i = 0; i < sortTupleSlotExprs.size(); i++) {
Expr expr = sortTupleSlotExprs.get(i);
if (!exprIds.contains(expr.getId())) {
afterDeduplication.add(expr);
resolvedTupleExprs = Lists.newArrayList(orderingExpr);
for (Expr output : outputList) {
if (!resolvedTupleExprs.contains(output)) {
resolvedTupleExprs.add(output);
}
}
info.setSortTupleDesc(tupleDescriptor);
info.setSortTupleSlotExprs(afterDeduplication);
info.setSortTupleSlotExprs(resolvedTupleExprs);
}
}

View File

@ -19,7 +19,7 @@ package org.apache.doris.qe;
import org.apache.doris.analysis.InsertStmt;
import org.apache.doris.analysis.KillStmt;
import org.apache.doris.analysis.QueryStmt;
import org.apache.doris.analysis.Queriable;
import org.apache.doris.analysis.SqlParser;
import org.apache.doris.analysis.SqlScanner;
import org.apache.doris.analysis.StatementBase;
@ -138,7 +138,7 @@ public class ConnectProcessor {
// ok query
MetricRepo.HISTO_QUERY_LATENCY.update(elapseMs);
if (elapseMs > Config.qe_slow_log_ms) {
String sqlDigest = DigestUtils.md5Hex(((QueryStmt) parsedStmt).toDigest());
String sqlDigest = DigestUtils.md5Hex(((Queriable) parsedStmt).toDigest());
ctx.getAuditEventBuilder().setSqlDigest(sqlDigest);
}
}

View File

@ -160,24 +160,17 @@ public class AnalyzeSSBTest extends TestWithFeService {
plannerContext.getJobScheduler().executeJobPool(plannerContext);
}
private boolean checkBound(LogicalPlan root) {
if (!checkPlanBound(root)) {
return false;
}
return true;
}
/**
* PlanNode and its expressions are all bound.
*/
private boolean checkPlanBound(LogicalPlan plan) {
private boolean checkBound(LogicalPlan plan) {
if (plan instanceof Unbound) {
return false;
}
List<Plan> children = plan.children();
for (Plan child : children) {
if (!checkPlanBound((LogicalPlan) child)) {
if (!checkBound((LogicalPlan) child)) {
return false;
}
}

View File

@ -0,0 +1,322 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.rewrite.RewriteTopDownJob;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.operators.plans.AggPhase;
import org.apache.doris.nereids.operators.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.operators.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Literal;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.Sum;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.Plans;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnaryPlan;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import java.util.List;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class AggregateDisassembleTest implements Plans {
private Plan rStudent;
@BeforeAll
public final void beforeAll() {
Table student = new Table(0L, "student", Table.TableType.OLAP,
ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, true, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, true, "", ""),
new Column("age", Type.INT, true, AggregateType.NONE, true, "", "")));
rStudent = plan(new LogicalOlapScan(student, ImmutableList.of("student")));
}
/**
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [age, SUM(id) as sum], groupByExpr: [age])
* +--childPlan(id, name, age)
* we should rewrite to:
* Aggregate(phase: [GLOBAL], outputExpr: [a, SUM(b) as c], groupByExpr: [a])
* +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b], groupByExpr: [age])
* +--childPlan(id, name, age)
*/
@Test
public void slotReferenceGroupBy() {
List<Expression> groupExpressionList = Lists.newArrayList(
rStudent.getOutput().get(2).toSlot());
List<NamedExpression> outputExpressionList = Lists.newArrayList(
rStudent.getOutput().get(2).toSlot(),
new Alias<>(new Sum(rStudent.getOutput().get(0).toSlot()), "sum"));
Plan root = plan(new LogicalAggregate(groupExpressionList, outputExpressionList), rStudent);
Memo memo = new Memo();
memo.initialize(root);
PlannerContext plannerContext = new PlannerContext(memo, new ConnectContext());
JobContext jobContext = new JobContext(plannerContext, new PhysicalProperties(), 0);
RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(),
ImmutableList.of(new AggregateDisassemble().build()), jobContext);
plannerContext.pushJob(rewriteTopDownJob);
plannerContext.getJobScheduler().executeJobPool(plannerContext);
Plan after = memo.copyOut();
Assertions.assertTrue(after instanceof LogicalUnaryPlan);
Assertions.assertTrue(after.getOperator() instanceof LogicalAggregate);
Assertions.assertTrue(after.child(0) instanceof LogicalUnaryPlan);
LogicalAggregate global = (LogicalAggregate) after.getOperator();
LogicalAggregate local = (LogicalAggregate) after.child(0).getOperator();
Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
Expression localOutput0 = rStudent.getOutput().get(2).toSlot();
Expression localOutput1 = new Sum(rStudent.getOutput().get(0).toSlot());
Expression localGroupBy = rStudent.getOutput().get(2).toSlot();
Assertions.assertEquals(2, local.getOutputExpressionList().size());
Assertions.assertTrue(local.getOutputExpressionList().get(0) instanceof SlotReference);
Assertions.assertEquals(localOutput0, local.getOutputExpressionList().get(0));
Assertions.assertTrue(local.getOutputExpressionList().get(1) instanceof Alias);
Assertions.assertEquals(localOutput1, local.getOutputExpressionList().get(1).child(0));
Assertions.assertEquals(1, local.getGroupByExpressionList().size());
Assertions.assertEquals(localGroupBy, local.getGroupByExpressionList().get(0));
Expression globalOutput0 = local.getOutputExpressionList().get(0).toSlot();
Expression globalOutput1 = new Sum(local.getOutputExpressionList().get(1).toSlot());
Expression globalGroupBy = local.getOutputExpressionList().get(0).toSlot();
Assertions.assertEquals(2, global.getOutputExpressionList().size());
Assertions.assertTrue(global.getOutputExpressionList().get(0) instanceof SlotReference);
Assertions.assertEquals(globalOutput0, global.getOutputExpressionList().get(0));
Assertions.assertTrue(global.getOutputExpressionList().get(1) instanceof Alias);
Assertions.assertEquals(globalOutput1, global.getOutputExpressionList().get(1).child(0));
Assertions.assertEquals(1, global.getGroupByExpressionList().size());
Assertions.assertEquals(globalGroupBy, global.getGroupByExpressionList().get(0));
// check id:
Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
global.getOutputExpressionList().get(0).getExprId());
Assertions.assertEquals(outputExpressionList.get(1).getExprId(),
global.getOutputExpressionList().get(1).getExprId());
}
/**
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [(age + 1) as key, SUM(id) as sum], groupByExpr: [age + 1])
* +--childPlan(id, name, age)
* we should rewrite to:
* Aggregate(phase: [GLOBAL], outputExpr: [a, SUM(b) as c], groupByExpr: [a])
* +--Aggregate(phase: [LOCAL], outputExpr: [(age + 1) as a, SUM(id) as b], groupByExpr: [age + 1])
* +--childPlan(id, name, age)
*/
@Test
public void aliasGroupBy() {
List<Expression> groupExpressionList = Lists.newArrayList(
new Add<>(rStudent.getOutput().get(2).toSlot(), new Literal(1)));
List<NamedExpression> outputExpressionList = Lists.newArrayList(
new Alias<>(new Add<>(rStudent.getOutput().get(2).toSlot(), new Literal(1)), "key"),
new Alias<>(new Sum(rStudent.getOutput().get(0).toSlot()), "sum"));
Plan root = plan(new LogicalAggregate(groupExpressionList, outputExpressionList), rStudent);
Memo memo = new Memo();
memo.initialize(root);
PlannerContext plannerContext = new PlannerContext(memo, new ConnectContext());
JobContext jobContext = new JobContext(plannerContext, new PhysicalProperties(), 0);
RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(),
ImmutableList.of(new AggregateDisassemble().build()), jobContext);
plannerContext.pushJob(rewriteTopDownJob);
plannerContext.getJobScheduler().executeJobPool(plannerContext);
Plan after = memo.copyOut();
Assertions.assertTrue(after instanceof LogicalUnaryPlan);
Assertions.assertTrue(after.getOperator() instanceof LogicalAggregate);
Assertions.assertTrue(after.child(0) instanceof LogicalUnaryPlan);
LogicalAggregate global = (LogicalAggregate) after.getOperator();
LogicalAggregate local = (LogicalAggregate) after.child(0).getOperator();
Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
Expression localOutput0 = new Add<>(rStudent.getOutput().get(2).toSlot(), new Literal(1));
Expression localOutput1 = new Sum(rStudent.getOutput().get(0).toSlot());
Expression localGroupBy = new Add<>(rStudent.getOutput().get(2).toSlot(), new Literal(1));
Assertions.assertEquals(2, local.getOutputExpressionList().size());
Assertions.assertTrue(local.getOutputExpressionList().get(0) instanceof Alias);
Assertions.assertEquals(localOutput0, local.getOutputExpressionList().get(0).child(0));
Assertions.assertTrue(local.getOutputExpressionList().get(1) instanceof Alias);
Assertions.assertEquals(localOutput1, local.getOutputExpressionList().get(1).child(0));
Assertions.assertEquals(1, local.getGroupByExpressionList().size());
Assertions.assertEquals(localGroupBy, local.getGroupByExpressionList().get(0));
Expression globalOutput0 = local.getOutputExpressionList().get(0).toSlot();
Expression globalOutput1 = new Sum(local.getOutputExpressionList().get(1).toSlot());
Expression globalGroupBy = local.getOutputExpressionList().get(0).toSlot();
Assertions.assertEquals(2, global.getOutputExpressionList().size());
Assertions.assertTrue(global.getOutputExpressionList().get(0) instanceof Alias);
Assertions.assertEquals(globalOutput0, global.getOutputExpressionList().get(0).child(0));
Assertions.assertTrue(global.getOutputExpressionList().get(1) instanceof Alias);
Assertions.assertEquals(globalOutput1, global.getOutputExpressionList().get(1).child(0));
Assertions.assertEquals(1, global.getGroupByExpressionList().size());
Assertions.assertEquals(globalGroupBy, global.getGroupByExpressionList().get(0));
// check id:
Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
global.getOutputExpressionList().get(0).getExprId());
Assertions.assertEquals(outputExpressionList.get(1).getExprId(),
global.getOutputExpressionList().get(1).getExprId());
}
/**
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: [])
* +--childPlan(id, name, age)
* we should rewrite to:
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as b], groupByExpr: [])
* +--Aggregate(phase: [LOCAL], outputExpr: [SUM(id) as a], groupByExpr: [])
* +--childPlan(id, name, age)
*/
@Test
public void globalAggregate() {
List<Expression> groupExpressionList = Lists.newArrayList();
List<NamedExpression> outputExpressionList = Lists.newArrayList(
new Alias<>(new Sum(rStudent.getOutput().get(0).toSlot()), "sum"));
Plan root = plan(new LogicalAggregate(groupExpressionList, outputExpressionList), rStudent);
Memo memo = new Memo();
memo.initialize(root);
PlannerContext plannerContext = new PlannerContext(memo, new ConnectContext());
JobContext jobContext = new JobContext(plannerContext, new PhysicalProperties(), 0);
RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(),
ImmutableList.of(new AggregateDisassemble().build()), jobContext);
plannerContext.pushJob(rewriteTopDownJob);
plannerContext.getJobScheduler().executeJobPool(plannerContext);
Plan after = memo.copyOut();
Assertions.assertTrue(after instanceof LogicalUnaryPlan);
Assertions.assertTrue(after.getOperator() instanceof LogicalAggregate);
Assertions.assertTrue(after.child(0) instanceof LogicalUnaryPlan);
LogicalAggregate global = (LogicalAggregate) after.getOperator();
LogicalAggregate local = (LogicalAggregate) after.child(0).getOperator();
Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
Expression localOutput0 = new Sum(rStudent.getOutput().get(0).toSlot());
Assertions.assertEquals(1, local.getOutputExpressionList().size());
Assertions.assertTrue(local.getOutputExpressionList().get(0) instanceof Alias);
Assertions.assertEquals(localOutput0, local.getOutputExpressionList().get(0).child(0));
Assertions.assertEquals(0, local.getGroupByExpressionList().size());
Expression globalOutput0 = new Sum(local.getOutputExpressionList().get(0).toSlot());
Assertions.assertEquals(1, global.getOutputExpressionList().size());
Assertions.assertTrue(global.getOutputExpressionList().get(0) instanceof Alias);
Assertions.assertEquals(globalOutput0, global.getOutputExpressionList().get(0).child(0));
Assertions.assertEquals(0, global.getGroupByExpressionList().size());
// check id:
Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
global.getOutputExpressionList().get(0).getExprId());
}
/**
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: [age])
* +--childPlan(id, name, age)
* we should rewrite to:
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as c], groupByExpr: [a])
* +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b], groupByExpr: [age])
* +--childPlan(id, name, age)
*/
@Test
public void groupExpressionNotInOutput() {
List<Expression> groupExpressionList = Lists.newArrayList(
rStudent.getOutput().get(2).toSlot());
List<NamedExpression> outputExpressionList = Lists.newArrayList(
new Alias<>(new Sum(rStudent.getOutput().get(0).toSlot()), "sum"));
Plan root = plan(new LogicalAggregate(groupExpressionList, outputExpressionList), rStudent);
Memo memo = new Memo();
memo.initialize(root);
PlannerContext plannerContext = new PlannerContext(memo, new ConnectContext());
JobContext jobContext = new JobContext(plannerContext, new PhysicalProperties(), 0);
RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(),
ImmutableList.of(new AggregateDisassemble().build()), jobContext);
plannerContext.pushJob(rewriteTopDownJob);
plannerContext.getJobScheduler().executeJobPool(plannerContext);
Plan after = memo.copyOut();
Assertions.assertTrue(after instanceof LogicalUnaryPlan);
Assertions.assertTrue(after.getOperator() instanceof LogicalAggregate);
Assertions.assertTrue(after.child(0) instanceof LogicalUnaryPlan);
LogicalAggregate global = (LogicalAggregate) after.getOperator();
LogicalAggregate local = (LogicalAggregate) after.child(0).getOperator();
Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
Expression localOutput0 = rStudent.getOutput().get(2).toSlot();
Expression localOutput1 = new Sum(rStudent.getOutput().get(0).toSlot());
Expression localGroupBy = rStudent.getOutput().get(2).toSlot();
Assertions.assertEquals(2, local.getOutputExpressionList().size());
Assertions.assertTrue(local.getOutputExpressionList().get(0) instanceof SlotReference);
Assertions.assertEquals(localOutput0, local.getOutputExpressionList().get(0));
Assertions.assertTrue(local.getOutputExpressionList().get(1) instanceof Alias);
Assertions.assertEquals(localOutput1, local.getOutputExpressionList().get(1).child(0));
Assertions.assertEquals(1, local.getGroupByExpressionList().size());
Assertions.assertEquals(localGroupBy, local.getGroupByExpressionList().get(0));
Expression globalOutput0 = new Sum(local.getOutputExpressionList().get(1).toSlot());
Expression globalGroupBy = local.getOutputExpressionList().get(0).toSlot();
Assertions.assertEquals(1, global.getOutputExpressionList().size());
Assertions.assertTrue(global.getOutputExpressionList().get(0) instanceof Alias);
Assertions.assertEquals(globalOutput0, global.getOutputExpressionList().get(0).child(0));
Assertions.assertEquals(1, global.getGroupByExpressionList().size());
Assertions.assertEquals(globalGroupBy, global.getGroupByExpressionList().get(0));
// check id:
Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
global.getOutputExpressionList().get(0).getExprId());
}
}

View File

@ -56,7 +56,7 @@ public class AnalyzeUtils {
new RewriteBottomUpJob(memo.getRoot(), new BindSlotReference().buildRules(), jobContext));
plannerContext.pushJob(
new RewriteBottomUpJob(memo.getRoot(), new BindRelation().buildRules(), jobContext));
jobContext.getPlannerContext().getJobScheduler().executeJobPool(plannerContext);
plannerContext.getJobScheduler().executeJobPool(plannerContext);
return (LogicalPlan) memo.copyOut();
}
}

View File

@ -215,10 +215,10 @@ public class ColumnPruningTest extends TestWithFeService {
private Plan process(Memo memo) {
PlannerContext plannerContext = new PlannerContext(memo, new ConnectContext());
JobContext jobContext = new JobContext(plannerContext, new PhysicalProperties(), 0);
RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(), new ColumnPruning().buildRules(),
jobContext);
jobContext.getPlannerContext().pushJob(rewriteTopDownJob);
jobContext.getPlannerContext().getJobScheduler().executeJobPool(plannerContext);
RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(),
new ColumnPruning().buildRules(), jobContext);
plannerContext.pushJob(rewriteTopDownJob);
plannerContext.getJobScheduler().executeJobPool(plannerContext);
return memo.copyOut();
}

View File

@ -264,7 +264,7 @@ public class PushDownPredicateTest implements Plans {
Assertions.assertEquals(((LogicalJoin) join2).getCondition().get(), whereCondition2);
Assertions.assertEquals(((LogicalJoin) join3).getCondition().get(), whereCondition1);
Assertions.assertEquals(((LogicalFilter) op1).getPredicates().sql(), whereCondition3result.sql());
Assertions.assertEquals(((LogicalFilter) op1).getPredicates().toSql(), whereCondition3result.toSql());
Assertions.assertEquals(((LogicalFilter) op2).getPredicates(), whereCondition4);
}
}

View File

@ -32,7 +32,7 @@ public class ExpressionParserTest {
private void assertExpr(String expr) {
Expression expression = PARSER.parseExpression(expr);
System.out.println(expression.sql());
System.out.println(expression.toSql());
}
@Test

View File

@ -23,4 +23,4 @@ WHERE
AND p_category = 'MFGR#12'
AND s_region = 'AMERICA'
GROUP BY d_year, p_brand
ORDER BY d_year, p_brand;
ORDER BY p_brand;