[feature](Nereids) use one stage aggregation if available (#12849)
Currently, we always disassemble an aggregation into two stages: local and global. However, in some cases a one-stage aggregation is enough, and it has two advantages: 1. it avoids an unnecessary exchange; 2. it gives us a chance to do a colocate join on top of the aggregation. This PR moves the AggregateDisassemble rule from the rewrite stage to the optimization stage and chooses between one-stage and two-stage aggregation according to cost.
This commit is contained in:
@ -193,9 +193,8 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
|
||||
// 3. generate output tuple
|
||||
List<Slot> slotList = Lists.newArrayList();
|
||||
TupleDescriptor outputTupleDesc;
|
||||
if (aggregate.getAggPhase() == AggPhase.LOCAL) {
|
||||
outputTupleDesc = generateTupleDesc(aggregate.getOutput(), null, context);
|
||||
} else if ((aggregate.getAggPhase() == AggPhase.GLOBAL && aggregate.isFinalPhase())
|
||||
if (aggregate.getAggPhase() == AggPhase.LOCAL
|
||||
|| (aggregate.getAggPhase() == AggPhase.GLOBAL && aggregate.isFinalPhase())
|
||||
|| aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
|
||||
slotList.addAll(groupSlotList);
|
||||
slotList.addAll(aggFunctionOutput);
|
||||
@ -222,10 +221,12 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
|
||||
outputTupleDesc, aggregate.getAggPhase().toExec());
|
||||
AggregationNode aggregationNode = new AggregationNode(context.nextPlanNodeId(),
|
||||
inputPlanFragment.getPlanRoot(), aggInfo);
|
||||
if (!aggregate.isFinalPhase()) {
|
||||
aggregationNode.unsetNeedsFinalize();
|
||||
}
|
||||
PlanFragment currentFragment = inputPlanFragment;
|
||||
switch (aggregate.getAggPhase()) {
|
||||
case LOCAL:
|
||||
aggregationNode.unsetNeedsFinalize();
|
||||
aggregationNode.setUseStreamingPreagg(aggregate.isUsingStream());
|
||||
aggregationNode.setIntermediateTuple();
|
||||
break;
|
||||
|
||||
@ -24,7 +24,6 @@ import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionOptimization;
|
||||
import org.apache.doris.nereids.rules.mv.SelectRollupWithAggregate;
|
||||
import org.apache.doris.nereids.rules.mv.SelectRollupWithoutAggregate;
|
||||
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.ColumnPruning;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.EliminateFilter;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit;
|
||||
@ -68,7 +67,6 @@ public class NereidsRewriteJobExecutor extends BatchRulesJob {
|
||||
.add(topDownBatch(ImmutableList.of(new ColumnPruning())))
|
||||
.add(topDownBatch(RuleSet.PUSH_DOWN_JOIN_CONDITION_RULES, false))
|
||||
.add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin())))
|
||||
.add(topDownBatch(ImmutableList.of(new AggregateDisassemble())))
|
||||
.add(topDownBatch(ImmutableList.of(new LimitPushDown())))
|
||||
.add(topDownBatch(ImmutableList.of(new EliminateLimit())))
|
||||
.add(topDownBatch(ImmutableList.of(new EliminateFilter())))
|
||||
|
||||
@ -172,6 +172,10 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
|
||||
ChildrenPropertiesRegulator regulator = new ChildrenPropertiesRegulator(groupExpression,
|
||||
lowestCostChildren, requestChildrenProperties, requestChildrenProperties, context);
|
||||
double enforceCost = regulator.adjustChildrenProperties();
|
||||
if (enforceCost < 0) {
|
||||
// invalid enforce, return.
|
||||
return;
|
||||
}
|
||||
curTotalCost += enforceCost;
|
||||
|
||||
// Not need to do pruning here because it has been done when we get the
|
||||
|
||||
@ -22,7 +22,10 @@ import org.apache.doris.nereids.cost.CostCalculator;
|
||||
import org.apache.doris.nereids.jobs.JobContext;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
|
||||
import org.apache.doris.nereids.trees.plans.AggPhase;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
|
||||
import org.apache.doris.nereids.util.JoinUtils;
|
||||
@ -64,6 +67,16 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Double, Void> {
|
||||
return enforceCost;
|
||||
}
|
||||
|
||||
/**
 * Regulates the child of a physical aggregate.
 *
 * <p>Rejects the case where a LOCAL aggregate that is also the final phase
 * (i.e. a one-stage aggregate) has a {@code PhysicalDistribute} as its first child.
 * NOTE(review): a negative result is presumably what CostAndEnforcerJob treats as an
 * "invalid enforce" -- confirm against adjustChildrenProperties().
 *
 * @param agg the physical aggregate being visited
 * @param context unused
 * @return {@code -1.0} for the rejected case described above, {@code 0.0} otherwise
 */
@Override
|
||||
public Double visitPhysicalAggregate(PhysicalAggregate<? extends Plan> agg, Void context) {
|
||||
if (agg.isFinalPhase()
|
||||
&& agg.getAggPhase() == AggPhase.LOCAL
|
||||
&& children.get(0).getPlan() instanceof PhysicalDistribute) {
|
||||
// final-phase LOCAL aggregate sitting directly on a distribute: signal with a negative value
|
return -1.0;
|
||||
}
|
||||
// every other combination adds no cost here
return 0.0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double visitPhysicalHashJoin(PhysicalHashJoin<? extends Plan, ? extends Plan> hashJoin,
|
||||
Void context) {
|
||||
|
||||
@ -80,7 +80,7 @@ public class RequestPropertyDeriver extends PlanVisitor<Void, PlanContext> {
|
||||
@Override
|
||||
public Void visitPhysicalAggregate(PhysicalAggregate<? extends Plan> agg, PlanContext context) {
|
||||
// 1. first phase agg just return any
|
||||
if (agg.getAggPhase().isLocal()) {
|
||||
if (agg.getAggPhase().isLocal() && !agg.isFinalPhase()) {
|
||||
addToRequestPropertyToChildren(PhysicalProperties.ANY);
|
||||
return null;
|
||||
}
|
||||
|
||||
@ -37,6 +37,7 @@ import org.apache.doris.nereids.rules.implementation.LogicalOneRowRelationToPhys
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalProjectToPhysicalProject;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickSort;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN;
|
||||
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.EliminateOuter;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits;
|
||||
@ -65,6 +66,7 @@ public class RuleSet {
|
||||
.add(SemiJoinLogicalJoinTranspose.LEFT_DEEP)
|
||||
.add(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP)
|
||||
.add(SemiJoinSemiJoinTranspose.INSTANCE)
|
||||
.add(new AggregateDisassemble())
|
||||
.add(new PushdownFilterThroughProject())
|
||||
.add(new MergeConsecutiveProjects())
|
||||
.build();
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
@ -66,36 +67,39 @@ import java.util.stream.Collectors;
|
||||
* 2. if instance count is 1, shouldn't disassemble the agg plan
|
||||
*/
|
||||
public class AggregateDisassemble extends OneRewriteRuleFactory {
|
||||
// used in secondDisassemble to transform local expressions into global
|
||||
private final Map<Expression, Expression> globalOutputSubstitutionMap = Maps.newHashMap();
|
||||
// used in secondDisassemble to transform local expressions into global
|
||||
private final Map<Expression, Expression> globalGroupBySubstitutionMap = Maps.newHashMap();
|
||||
// used to indicate the existence of a distinct function for the entire phase
|
||||
private boolean hasDistinctAgg = false;
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalAggregate().when(agg -> !agg.isDisassembled()).thenApply(ctx -> {
|
||||
LogicalAggregate<GroupPlan> aggregate = ctx.root;
|
||||
LogicalAggregate firstAggregate = firstDisassemble(aggregate);
|
||||
if (!hasDistinctAgg) {
|
||||
return firstAggregate;
|
||||
}
|
||||
return secondDisassemble(firstAggregate);
|
||||
}).toRule(RuleType.AGGREGATE_DISASSEMBLE);
|
||||
return logicalAggregate()
|
||||
.whenNot(LogicalAggregate::isDisassembled)
|
||||
.then(aggregate -> {
|
||||
// used in secondDisassemble to transform local expressions into global
|
||||
final Map<Expression, Expression> globalOutputSMap = Maps.newHashMap();
|
||||
// used in secondDisassemble to transform local expressions into global
|
||||
final Map<Expression, Expression> globalGroupBySMap = Maps.newHashMap();
|
||||
Pair<LogicalAggregate, Boolean> ret = firstDisassemble(aggregate, globalOutputSMap,
|
||||
globalGroupBySMap);
|
||||
if (!ret.second) {
|
||||
return ret.first;
|
||||
}
|
||||
return secondDisassemble(ret.first, globalOutputSMap, globalGroupBySMap);
|
||||
}).toRule(RuleType.AGGREGATE_DISASSEMBLE);
|
||||
}
|
||||
|
||||
// only support distinct function with group by
|
||||
// TODO: support distinct function without group by. (add second global phase)
|
||||
private LogicalAggregate secondDisassemble(LogicalAggregate<LogicalAggregate> aggregate) {
|
||||
private LogicalAggregate secondDisassemble(
|
||||
LogicalAggregate<LogicalAggregate> aggregate,
|
||||
Map<Expression, Expression> globalOutputSMap,
|
||||
Map<Expression, Expression> globalGroupBySMap) {
|
||||
LogicalAggregate<GroupPlan> local = aggregate.child();
|
||||
// replace expression in globalOutputExprs and globalGroupByExprs
|
||||
List<NamedExpression> globalOutputExprs = local.getOutputExpressions().stream()
|
||||
.map(e -> ExpressionUtils.replace(e, globalOutputSubstitutionMap))
|
||||
.map(e -> ExpressionUtils.replace(e, globalOutputSMap))
|
||||
.map(NamedExpression.class::cast)
|
||||
.collect(Collectors.toList());
|
||||
List<Expression> globalGroupByExprs = local.getGroupByExpressions().stream()
|
||||
.map(e -> ExpressionUtils.replace(e, globalGroupBySubstitutionMap))
|
||||
.map(e -> ExpressionUtils.replace(e, globalGroupBySMap))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// generate new plan
|
||||
@ -119,7 +123,11 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
|
||||
);
|
||||
}
|
||||
|
||||
private LogicalAggregate firstDisassemble(LogicalAggregate<GroupPlan> aggregate) {
|
||||
private Pair<LogicalAggregate, Boolean> firstDisassemble(
|
||||
LogicalAggregate<GroupPlan> aggregate,
|
||||
Map<Expression, Expression> globalOutputSMap,
|
||||
Map<Expression, Expression> globalGroupBySMap) {
|
||||
Boolean hasDistinct = Boolean.FALSE;
|
||||
List<NamedExpression> originOutputExprs = aggregate.getOutputExpressions();
|
||||
List<Expression> originGroupByExprs = aggregate.getGroupByExpressions();
|
||||
Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap();
|
||||
@ -149,14 +157,14 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
|
||||
}
|
||||
if (originGroupByExpr instanceof SlotReference) {
|
||||
inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
|
||||
globalOutputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
|
||||
globalGroupBySubstitutionMap.put(originGroupByExpr, originGroupByExpr);
|
||||
globalOutputSMap.put(originGroupByExpr, originGroupByExpr);
|
||||
globalGroupBySMap.put(originGroupByExpr, originGroupByExpr);
|
||||
localOutputExprs.add((SlotReference) originGroupByExpr);
|
||||
} else {
|
||||
NamedExpression localOutputExpr = new Alias(originGroupByExpr, originGroupByExpr.toSql());
|
||||
inputSubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot());
|
||||
globalOutputSubstitutionMap.put(localOutputExpr, localOutputExpr.toSlot());
|
||||
globalGroupBySubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot());
|
||||
globalOutputSMap.put(localOutputExpr, localOutputExpr.toSlot());
|
||||
globalGroupBySMap.put(originGroupByExpr, localOutputExpr.toSlot());
|
||||
localOutputExprs.add(localOutputExpr);
|
||||
}
|
||||
}
|
||||
@ -170,22 +178,22 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
|
||||
continue;
|
||||
}
|
||||
if (aggregateFunction.isDistinct()) {
|
||||
hasDistinctAgg = true;
|
||||
hasDistinct = Boolean.TRUE;
|
||||
for (Expression expr : aggregateFunction.children()) {
|
||||
if (expr instanceof SlotReference) {
|
||||
distinctExprsForLocalOutput.add((SlotReference) expr);
|
||||
if (!inputSubstitutionMap.containsKey(expr)) {
|
||||
inputSubstitutionMap.put(expr, expr);
|
||||
globalOutputSubstitutionMap.put(expr, expr);
|
||||
globalGroupBySubstitutionMap.put(expr, expr);
|
||||
globalOutputSMap.put(expr, expr);
|
||||
globalGroupBySMap.put(expr, expr);
|
||||
}
|
||||
} else {
|
||||
NamedExpression globalOutputExpr = new Alias(expr, expr.toSql());
|
||||
distinctExprsForLocalOutput.add(globalOutputExpr);
|
||||
if (!inputSubstitutionMap.containsKey(expr)) {
|
||||
inputSubstitutionMap.put(expr, globalOutputExpr.toSlot());
|
||||
globalOutputSubstitutionMap.put(globalOutputExpr, globalOutputExpr.toSlot());
|
||||
globalGroupBySubstitutionMap.put(expr, globalOutputExpr.toSlot());
|
||||
globalOutputSMap.put(globalOutputExpr, globalOutputExpr.toSlot());
|
||||
globalGroupBySMap.put(expr, globalOutputExpr.toSlot());
|
||||
}
|
||||
}
|
||||
distinctExprsForLocalGroupBy.add(expr);
|
||||
@ -196,7 +204,7 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
|
||||
Expression substitutionValue = aggregateFunction.withChildren(
|
||||
Lists.newArrayList(localOutputExpr.toSlot()));
|
||||
inputSubstitutionMap.put(aggregateFunction, substitutionValue);
|
||||
globalOutputSubstitutionMap.put(aggregateFunction, substitutionValue);
|
||||
globalOutputSMap.put(aggregateFunction, substitutionValue);
|
||||
localOutputExprs.add(localOutputExpr);
|
||||
}
|
||||
}
|
||||
@ -222,7 +230,7 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
|
||||
AggPhase.LOCAL,
|
||||
aggregate.child()
|
||||
);
|
||||
return new LogicalAggregate<>(
|
||||
return Pair.of(new LogicalAggregate<>(
|
||||
globalGroupByExprs,
|
||||
globalOutputExprs,
|
||||
true,
|
||||
@ -230,6 +238,6 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
|
||||
true,
|
||||
AggPhase.GLOBAL,
|
||||
localAggregate
|
||||
);
|
||||
), hasDistinct);
|
||||
}
|
||||
}
|
||||
|
||||
@ -56,6 +56,14 @@ public class MergeConsecutiveProjects extends OneRewriteRuleFactory {
|
||||
private static class ExpressionReplacer extends DefaultExpressionRewriter<Map<Expression, Expression>> {
|
||||
public static final ExpressionReplacer INSTANCE = new ExpressionReplacer();
|
||||
|
||||
/**
 * Entry point for rewriting a single expression with the given substitution map.
 *
 * <p>If {@code expr} itself is a {@code SlotReference}, it is looked up in the map with its
 * qualifier cleared and the mapped value is returned unchanged; when there is no mapping the
 * original expression is returned. Any other expression is handed to {@code visit}.
 *
 * @param expr the expression to rewrite
 * @param substitutionMap mapping consulted for the lookup and passed on to {@code visit}
 * @return the mapped expression, the original slot reference, or the result of {@code visit}
 */
public Expression replace(Expression expr, Map<Expression, Expression> substitutionMap) {
|
||||
if (expr instanceof SlotReference) {
|
||||
// look up by the qualifier-less form of the slot
Slot ref = ((SlotReference) expr).withQualifier(Collections.emptyList());
|
||||
return substitutionMap.getOrDefault(ref, expr);
|
||||
}
|
||||
return visit(expr, substitutionMap);
|
||||
}
|
||||
|
||||
/**
|
||||
* case 1:
|
||||
* project(alias(c) as d, alias(x) as y)
|
||||
@ -84,7 +92,7 @@ public class MergeConsecutiveProjects extends OneRewriteRuleFactory {
|
||||
Slot ref = ((SlotReference) expr).withQualifier(Collections.emptyList());
|
||||
if (substitutionMap.containsKey(ref)) {
|
||||
Alias res = (Alias) substitutionMap.get(ref);
|
||||
return (res.child() instanceof SlotReference) ? res : res.child();
|
||||
return res.child();
|
||||
}
|
||||
} else if (substitutionMap.containsKey(expr)) {
|
||||
return substitutionMap.get(expr).child(0);
|
||||
@ -106,7 +114,7 @@ public class MergeConsecutiveProjects extends OneRewriteRuleFactory {
|
||||
);
|
||||
|
||||
projectExpressions = projectExpressions.stream()
|
||||
.map(e -> MergeConsecutiveProjects.ExpressionReplacer.INSTANCE.visit(e, childAliasMap))
|
||||
.map(e -> MergeConsecutiveProjects.ExpressionReplacer.INSTANCE.replace(e, childAliasMap))
|
||||
.map(NamedExpression.class::cast)
|
||||
.collect(Collectors.toList());
|
||||
return new LogicalProject<>(projectExpressions, childProject.children().get(0));
|
||||
|
||||
@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
@ -61,7 +62,7 @@ import java.util.stream.Collectors;
|
||||
public class NormalizeAggregate extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalAggregate().when(aggregate -> !aggregate.isNormalized()).then(aggregate -> {
|
||||
return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
|
||||
// substitution map used to substitute expression in aggregate's output to use it as top projections
|
||||
Map<Expression, Expression> substitutionMap = Maps.newHashMap();
|
||||
List<Expression> keys = aggregate.getGroupByExpressions();
|
||||
@ -102,25 +103,40 @@ public class NormalizeAggregate extends OneRewriteRuleFactory {
|
||||
List<NamedExpression> outputs = aggregate.getOutputExpressions();
|
||||
Map<Boolean, List<NamedExpression>> partitionedOutputs = outputs.stream()
|
||||
.collect(Collectors.groupingBy(e -> e.anyMatch(AggregateFunction.class::isInstance)));
|
||||
|
||||
boolean needBottomProjects = partitionedKeys.containsKey(false);
|
||||
if (partitionedOutputs.containsKey(true)) {
|
||||
// process expressions that contain aggregate function
|
||||
Set<AggregateFunction> aggregateFunctions = partitionedOutputs.get(true).stream()
|
||||
.flatMap(e -> e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
|
||||
.collect(Collectors.toSet());
|
||||
newOutputs.addAll(aggregateFunctions.stream()
|
||||
.map(f -> new Alias(f, f.toSql()))
|
||||
.peek(a -> substitutionMap.put(a.child(), a.toSlot()))
|
||||
.collect(Collectors.toList()));
|
||||
// add slot references in aggregate function to bottom projections
|
||||
bottomProjections.addAll(aggregateFunctions.stream()
|
||||
.flatMap(f -> f.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
|
||||
.map(SlotReference.class::cast)
|
||||
.collect(Collectors.toSet()));
|
||||
|
||||
// replace all non-slot expression in aggregate functions children.
|
||||
for (AggregateFunction aggregateFunction : aggregateFunctions) {
|
||||
List<Expression> newChildren = Lists.newArrayList();
|
||||
for (Expression child : aggregateFunction.getArguments()) {
|
||||
if (child instanceof SlotReference || child instanceof Literal) {
|
||||
newChildren.add(child);
|
||||
if (child instanceof SlotReference) {
|
||||
bottomProjections.add((SlotReference) child);
|
||||
}
|
||||
} else {
|
||||
needBottomProjects = true;
|
||||
Alias alias = new Alias(child, child.toSql());
|
||||
bottomProjections.add(alias);
|
||||
newChildren.add(alias.toSlot());
|
||||
}
|
||||
}
|
||||
AggregateFunction newFunction = (AggregateFunction) aggregateFunction.withChildren(newChildren);
|
||||
Alias alias = new Alias(newFunction, newFunction.toSql());
|
||||
newOutputs.add(alias);
|
||||
substitutionMap.put(aggregateFunction, alias.toSlot());
|
||||
}
|
||||
}
|
||||
|
||||
// assemble
|
||||
LogicalPlan root = aggregate.child();
|
||||
if (partitionedKeys.containsKey(false)) {
|
||||
if (needBottomProjects) {
|
||||
root = new LogicalProject<>(bottomProjections, root);
|
||||
}
|
||||
root = new LogicalAggregate<>(newKeys, newOutputs, aggregate.isDisassembled(),
|
||||
|
||||
@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions;
|
||||
|
||||
import org.apache.doris.analysis.ArithmeticExpr.Operator;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
import org.apache.doris.nereids.types.coercion.NumericType;
|
||||
|
||||
@ -41,6 +42,11 @@ public class Add extends BinaryArithmetic {
|
||||
return new Add(children.get(0), children.get(1));
|
||||
}
|
||||
|
||||
/**
 * Returns the result type of this addition: the promotion of the left operand's data type.
 * The right operand's type is not consulted here.
 * NOTE(review): presumably both operands already share a type after type coercion -- confirm.
 */
@Override
|
||||
public DataType getDataType() {
|
||||
return left().getDataType().promotion();
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitAdd(this, context);
|
||||
|
||||
@ -44,24 +44,24 @@ public class ExecutableFunctions {
|
||||
* Executable arithmetic functions
|
||||
*/
|
||||
|
||||
@ExecFunction(name = "add", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT")
|
||||
public static TinyIntLiteral addTinyint(TinyIntLiteral first, TinyIntLiteral second) {
|
||||
byte result = (byte) Math.addExact(first.getValue(), second.getValue());
|
||||
return new TinyIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "add", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT")
|
||||
public static SmallIntLiteral addSmallint(SmallIntLiteral first, SmallIntLiteral second) {
|
||||
@ExecFunction(name = "add", argTypes = {"TINYINT", "TINYINT"}, returnType = "SMALLINT")
|
||||
public static SmallIntLiteral addTinyint(TinyIntLiteral first, TinyIntLiteral second) {
|
||||
short result = (short) Math.addExact(first.getValue(), second.getValue());
|
||||
return new SmallIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "add", argTypes = {"INT", "INT"}, returnType = "INT")
|
||||
public static IntegerLiteral addInt(IntegerLiteral first, IntegerLiteral second) {
|
||||
@ExecFunction(name = "add", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "INT")
|
||||
public static IntegerLiteral addSmallint(SmallIntLiteral first, SmallIntLiteral second) {
|
||||
int result = Math.addExact(first.getValue(), second.getValue());
|
||||
return new IntegerLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "add", argTypes = {"INT", "INT"}, returnType = "BIGINT")
|
||||
public static BigIntLiteral addInt(IntegerLiteral first, IntegerLiteral second) {
|
||||
long result = Math.addExact(first.getValue(), second.getValue());
|
||||
return new BigIntLiteral(result);
|
||||
}
|
||||
|
||||
@ExecFunction(name = "add", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT")
|
||||
public static BigIntLiteral addBigint(BigIntLiteral first, BigIntLiteral second) {
|
||||
long result = Math.addExact(first.getValue(), second.getValue());
|
||||
|
||||
@ -60,9 +60,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
|
||||
private final AggPhase aggPhase;
|
||||
|
||||
// use for scenes containing distinct agg
|
||||
// 1. If there are LOCAL and GLOBAL phases, global is the final phase
|
||||
// 2. If there are LOCAL, GLOBAL and DISTINCT_LOCAL phases, DISTINCT_LOCAL is the final phase
|
||||
// 3. If there are LOCAL, GLOBAL, DISTINCT_LOCAL, DISTINCT_GLOBAL phases,
|
||||
// 1. If there is LOCAL only, LOCAL is the final phase
|
||||
// 2. If there are LOCAL and GLOBAL phases, global is the final phase
|
||||
// 3. If there are LOCAL, GLOBAL and DISTINCT_LOCAL phases, DISTINCT_LOCAL is the final phase
|
||||
// 4. If there are LOCAL, GLOBAL, DISTINCT_LOCAL, DISTINCT_GLOBAL phases,
|
||||
// DISTINCT_GLOBAL is the final phase
|
||||
private final boolean isFinalPhase;
|
||||
|
||||
@ -73,7 +74,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
|
||||
List<Expression> groupByExpressions,
|
||||
List<NamedExpression> outputExpressions,
|
||||
CHILD_TYPE child) {
|
||||
this(groupByExpressions, outputExpressions, false, false, true, AggPhase.GLOBAL, child);
|
||||
this(groupByExpressions, outputExpressions, false, false, true, AggPhase.LOCAL, child);
|
||||
}
|
||||
|
||||
public LogicalAggregate(
|
||||
|
||||
@ -34,6 +34,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.TinyIntType;
|
||||
@ -256,7 +257,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
|
||||
NamedExpressionUtil.clear();
|
||||
|
||||
sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + 3) > 0";
|
||||
Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new TinyIntLiteral((byte) 3))),
|
||||
Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new SmallIntLiteral((byte) 3))),
|
||||
"sum(((a1 + a2) + 3))");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matchesFromRoot(
|
||||
@ -360,7 +361,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
|
||||
ImmutableList.of("default_cluster:test_having", "t1")
|
||||
);
|
||||
Alias pk1 = new Alias(new ExprId(7), new Add(pk, Literal.of((byte) 1)), "(pk + 1)");
|
||||
Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)");
|
||||
Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((short) 1)), "((pk + 1) + 1)");
|
||||
Alias pk2 = new Alias(new ExprId(9), new Add(pk, Literal.of((byte) 2)), "(pk + 2)");
|
||||
Alias sumA1 = new Alias(new ExprId(10), new Sum(a1), "SUM(a1)");
|
||||
Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1, false), Literal.of(1L)), "(COUNT(a1) + 1)");
|
||||
|
||||
@ -150,7 +150,7 @@ public class RequestPropertyDeriverTest {
|
||||
Lists.newArrayList(key),
|
||||
AggPhase.LOCAL,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
logicalProperties,
|
||||
groupPlan
|
||||
);
|
||||
|
||||
@ -171,9 +171,9 @@ public class FoldConstantTest {
|
||||
@Test
|
||||
public void testArithmeticFold() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
|
||||
assertRewrite("1 + 1", Literal.of((byte) 2));
|
||||
assertRewrite("1 + 1", Literal.of((short) 2));
|
||||
assertRewrite("1 - 1", Literal.of((byte) 0));
|
||||
assertRewrite("100 + 100", Literal.of((byte) 200));
|
||||
assertRewrite("100 + 100", Literal.of((short) 200));
|
||||
assertRewrite("1 - 2", Literal.of((byte) -1));
|
||||
|
||||
assertRewrite("1 - 2 > 1", "false");
|
||||
@ -284,9 +284,9 @@ public class FoldConstantTest {
|
||||
private void assertRewrite(String expression, String expected) {
|
||||
Map<String, Slot> mem = Maps.newHashMap();
|
||||
Expression needRewriteExpression = PARSER.parseExpression(expression);
|
||||
needRewriteExpression = replaceUnboundSlot(needRewriteExpression, mem);
|
||||
needRewriteExpression = typeCoercion(replaceUnboundSlot(needRewriteExpression, mem));
|
||||
Expression expectedExpression = PARSER.parseExpression(expected);
|
||||
expectedExpression = replaceUnboundSlot(expectedExpression, mem);
|
||||
expectedExpression = typeCoercion(replaceUnboundSlot(expectedExpression, mem));
|
||||
Expression rewrittenExpression = executor.rewrite(needRewriteExpression);
|
||||
Assertions.assertEquals(expectedExpression, rewrittenExpression);
|
||||
}
|
||||
@ -320,6 +320,10 @@ public class FoldConstantTest {
|
||||
return hasNewChildren ? expression.withChildren(children) : expression;
|
||||
}
|
||||
|
||||
/**
 * Test helper: runs {@code TypeCoercion.INSTANCE} over the given expression with a null context
 * and returns the result.
 */
private Expression typeCoercion(Expression expression) {
|
||||
return TypeCoercion.INSTANCE.visit(expression, null);
|
||||
}
|
||||
|
||||
private DataType getType(char t) {
|
||||
switch (t) {
|
||||
case 'T':
|
||||
|
||||
@ -89,7 +89,8 @@ public class MergeConsecutiveProjectsTest {
|
||||
relation);
|
||||
LogicalProject project2 = new LogicalProject<>(
|
||||
Lists.newArrayList(
|
||||
new Alias(new Add(aliasRef, new IntegerLiteral(2)), "Y")
|
||||
new Alias(new Add(aliasRef, new IntegerLiteral(2)), "Y"),
|
||||
aliasRef
|
||||
),
|
||||
project1);
|
||||
|
||||
@ -100,11 +101,12 @@ public class MergeConsecutiveProjectsTest {
|
||||
System.out.println(plan.treeString());
|
||||
Assertions.assertTrue(plan instanceof LogicalProject);
|
||||
LogicalProject finalProject = (LogicalProject) plan;
|
||||
Add finalExpression = new Add(
|
||||
Add aPlus1Plus2 = new Add(
|
||||
new Add(colA, new IntegerLiteral(1)),
|
||||
new IntegerLiteral(2)
|
||||
);
|
||||
Assertions.assertEquals(1, finalProject.getProjects().size());
|
||||
Assertions.assertEquals(((Alias) finalProject.getProjects().get(0)).child(), finalExpression);
|
||||
Assertions.assertEquals(2, finalProject.getProjects().size());
|
||||
Assertions.assertEquals(aPlus1Plus2, ((Alias) finalProject.getProjects().get(0)).child());
|
||||
Assertions.assertEquals(alias, finalProject.getProjects().get(1));
|
||||
}
|
||||
}
|
||||
|
||||
@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Multiply;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
@ -93,15 +94,17 @@ public class NormalizeAggregateTest implements PatternMatchSupported {
|
||||
* +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3])
|
||||
*
|
||||
* after rewrite:
|
||||
* LogicalProject ((sum((id * 1))#5 + 2) AS `(sum((id * 1)) + 2)`#4)
|
||||
* +--LogicalAggregate (phase: [GLOBAL], output: [sum((id#0 * 1)) AS `sum((id * 1))`#5], groupBy: [name#2])
|
||||
* +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3])
|
||||
* LogicalProject ( projects=[(sum((id * 1))#6 + 2) AS `(sum((id * 1)) + 2)`#4] )
|
||||
* +--LogicalAggregate ( phase=LOCAL, outputExpr=[sum((id * 1)#5) AS `sum((id * 1))`#6], groupByExpr=[name#2] )
|
||||
* +--LogicalProject ( projects=[name#2, (id#0 * 1) AS `(id * 1)`#5] )
|
||||
* +--GroupPlan( GroupId#0 )
|
||||
*/
|
||||
@Test
|
||||
public void testComplexFuncWithComplexOutputOfFunc() {
|
||||
NamedExpression key = rStudent.getOutput().get(2).toSlot();
|
||||
List<Expression> groupExpressionList = Lists.newArrayList(key);
|
||||
Expression aggregateFunction = new Sum(new Multiply(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(1)));
|
||||
Expression multiply = new Multiply(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(1));
|
||||
Expression aggregateFunction = new Sum(multiply);
|
||||
Expression complexOutput = new Add(aggregateFunction, new IntegerLiteral(2));
|
||||
Alias output = new Alias(complexOutput, complexOutput.toSql());
|
||||
List<NamedExpression> outputExpressionList = Lists.newArrayList(output);
|
||||
@ -112,10 +115,14 @@ public class NormalizeAggregateTest implements PatternMatchSupported {
|
||||
.matchesFromRoot(
|
||||
logicalProject(
|
||||
logicalAggregate(
|
||||
logicalOlapScan()
|
||||
logicalProject(
|
||||
logicalOlapScan()
|
||||
).when(project -> project.getProjects().size() == 2)
|
||||
.when(project -> project.getProjects().get(0) instanceof SlotReference)
|
||||
.when(project -> project.getProjects().get(1).child(0).equals(multiply))
|
||||
).when(FieldChecker.check("groupByExpressions", ImmutableList.of(key)))
|
||||
.when(aggregate -> aggregate.getOutputExpressions().size() == 1)
|
||||
.when(aggregate -> aggregate.getOutputExpressions().get(0).child(0).equals(aggregateFunction))
|
||||
.when(aggregate -> aggregate.getOutputExpressions().get(0).child(0) instanceof AggregateFunction)
|
||||
).when(project -> project.getProjects().size() == 1)
|
||||
.when(project -> project.getProjects().get(0) instanceof Alias)
|
||||
.when(project -> project.getProjects().get(0).getExprId().equals(output.getExprId()))
|
||||
|
||||
@ -57,7 +57,7 @@ public class PlanToStringTest {
|
||||
new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), child);
|
||||
|
||||
Assertions.assertTrue(plan.toString()
|
||||
.matches("LogicalAggregate \\( phase=GLOBAL, outputExpr=\\[a#\\d+], groupByExpr=\\[] \\)"));
|
||||
.matches("LogicalAggregate \\( phase=LOCAL, outputExpr=\\[a#\\d+], groupByExpr=\\[] \\)"));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@ -25,6 +25,7 @@ import org.apache.doris.nereids.parser.NereidsParser;
|
||||
import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleSet;
|
||||
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnAgg;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnProjectUnderAgg;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.ExistsApplyToJoin;
|
||||
@ -120,7 +121,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte
|
||||
new MockUp<RuleSet>() {
|
||||
@Mock
|
||||
public List<Rule> getExplorationRules() {
|
||||
return Lists.newArrayList();
|
||||
return Lists.newArrayList(new AggregateDisassemble().build());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user