[feature](Nereids) use one stage aggregation if available (#12849)

Currently, we always disassemble aggregation into two stage: local and global. However, in some case, one stage aggregation is enough, there are two advantage of one stage aggregation.
1. avoid unnecessary exchange.
2. have a chance to do colocate join on the top of aggregation.

This PR move AggregateDisassemble rule from rewrite stage to optimization stage. And choose one stage or two stage aggregation according to cost.
This commit is contained in:
morrySnow
2022-09-28 10:38:03 +08:00
committed by GitHub
parent 1ba9e4b568
commit eef9367705
19 changed files with 155 additions and 83 deletions

View File

@ -193,9 +193,8 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
// 3. generate output tuple
List<Slot> slotList = Lists.newArrayList();
TupleDescriptor outputTupleDesc;
if (aggregate.getAggPhase() == AggPhase.LOCAL) {
outputTupleDesc = generateTupleDesc(aggregate.getOutput(), null, context);
} else if ((aggregate.getAggPhase() == AggPhase.GLOBAL && aggregate.isFinalPhase())
if (aggregate.getAggPhase() == AggPhase.LOCAL
|| (aggregate.getAggPhase() == AggPhase.GLOBAL && aggregate.isFinalPhase())
|| aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
slotList.addAll(groupSlotList);
slotList.addAll(aggFunctionOutput);
@ -222,10 +221,12 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
outputTupleDesc, aggregate.getAggPhase().toExec());
AggregationNode aggregationNode = new AggregationNode(context.nextPlanNodeId(),
inputPlanFragment.getPlanRoot(), aggInfo);
if (!aggregate.isFinalPhase()) {
aggregationNode.unsetNeedsFinalize();
}
PlanFragment currentFragment = inputPlanFragment;
switch (aggregate.getAggPhase()) {
case LOCAL:
aggregationNode.unsetNeedsFinalize();
aggregationNode.setUseStreamingPreagg(aggregate.isUsingStream());
aggregationNode.setIntermediateTuple();
break;

View File

@ -24,7 +24,6 @@ import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionOptimization;
import org.apache.doris.nereids.rules.mv.SelectRollupWithAggregate;
import org.apache.doris.nereids.rules.mv.SelectRollupWithoutAggregate;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
import org.apache.doris.nereids.rules.rewrite.logical.ColumnPruning;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateFilter;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit;
@ -68,7 +67,6 @@ public class NereidsRewriteJobExecutor extends BatchRulesJob {
.add(topDownBatch(ImmutableList.of(new ColumnPruning())))
.add(topDownBatch(RuleSet.PUSH_DOWN_JOIN_CONDITION_RULES, false))
.add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin())))
.add(topDownBatch(ImmutableList.of(new AggregateDisassemble())))
.add(topDownBatch(ImmutableList.of(new LimitPushDown())))
.add(topDownBatch(ImmutableList.of(new EliminateLimit())))
.add(topDownBatch(ImmutableList.of(new EliminateFilter())))

View File

@ -172,6 +172,10 @@ public class CostAndEnforcerJob extends Job implements Cloneable {
ChildrenPropertiesRegulator regulator = new ChildrenPropertiesRegulator(groupExpression,
lowestCostChildren, requestChildrenProperties, requestChildrenProperties, context);
double enforceCost = regulator.adjustChildrenProperties();
if (enforceCost < 0) {
// invalid enforce, return.
return;
}
curTotalCost += enforceCost;
// Not need to do pruning here because it has been done when we get the

View File

@ -22,7 +22,10 @@ import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
import org.apache.doris.nereids.trees.plans.AggPhase;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.JoinUtils;
@ -64,6 +67,16 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Double, Void> {
return enforceCost;
}
@Override
public Double visitPhysicalAggregate(PhysicalAggregate<? extends Plan> agg, Void context) {
if (agg.isFinalPhase()
&& agg.getAggPhase() == AggPhase.LOCAL
&& children.get(0).getPlan() instanceof PhysicalDistribute) {
return -1.0;
}
return 0.0;
}
@Override
public Double visitPhysicalHashJoin(PhysicalHashJoin<? extends Plan, ? extends Plan> hashJoin,
Void context) {

View File

@ -80,7 +80,7 @@ public class RequestPropertyDeriver extends PlanVisitor<Void, PlanContext> {
@Override
public Void visitPhysicalAggregate(PhysicalAggregate<? extends Plan> agg, PlanContext context) {
// 1. first phase agg just return any
if (agg.getAggPhase().isLocal()) {
if (agg.getAggPhase().isLocal() && !agg.isFinalPhase()) {
addToRequestPropertyToChildren(PhysicalProperties.ANY);
return null;
}

View File

@ -37,6 +37,7 @@ import org.apache.doris.nereids.rules.implementation.LogicalOneRowRelationToPhys
import org.apache.doris.nereids.rules.implementation.LogicalProjectToPhysicalProject;
import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickSort;
import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateOuter;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits;
@ -65,6 +66,7 @@ public class RuleSet {
.add(SemiJoinLogicalJoinTranspose.LEFT_DEEP)
.add(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP)
.add(SemiJoinSemiJoinTranspose.INSTANCE)
.add(new AggregateDisassemble())
.add(new PushdownFilterThroughProject())
.add(new MergeConsecutiveProjects())
.build();

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
@ -66,36 +67,39 @@ import java.util.stream.Collectors;
* 2. if instance count is 1, shouldn't disassemble the agg plan
*/
public class AggregateDisassemble extends OneRewriteRuleFactory {
// used in secondDisassemble to transform local expressions into global
private final Map<Expression, Expression> globalOutputSubstitutionMap = Maps.newHashMap();
// used in secondDisassemble to transform local expressions into global
private final Map<Expression, Expression> globalGroupBySubstitutionMap = Maps.newHashMap();
// used to indicate the existence of a distinct function for the entire phase
private boolean hasDistinctAgg = false;
@Override
public Rule build() {
return logicalAggregate().when(agg -> !agg.isDisassembled()).thenApply(ctx -> {
LogicalAggregate<GroupPlan> aggregate = ctx.root;
LogicalAggregate firstAggregate = firstDisassemble(aggregate);
if (!hasDistinctAgg) {
return firstAggregate;
}
return secondDisassemble(firstAggregate);
}).toRule(RuleType.AGGREGATE_DISASSEMBLE);
return logicalAggregate()
.whenNot(LogicalAggregate::isDisassembled)
.then(aggregate -> {
// used in secondDisassemble to transform local expressions into global
final Map<Expression, Expression> globalOutputSMap = Maps.newHashMap();
// used in secondDisassemble to transform local expressions into global
final Map<Expression, Expression> globalGroupBySMap = Maps.newHashMap();
Pair<LogicalAggregate, Boolean> ret = firstDisassemble(aggregate, globalOutputSMap,
globalGroupBySMap);
if (!ret.second) {
return ret.first;
}
return secondDisassemble(ret.first, globalOutputSMap, globalGroupBySMap);
}).toRule(RuleType.AGGREGATE_DISASSEMBLE);
}
// only support distinct function with group by
// TODO: support distinct function without group by. (add second global phase)
private LogicalAggregate secondDisassemble(LogicalAggregate<LogicalAggregate> aggregate) {
private LogicalAggregate secondDisassemble(
LogicalAggregate<LogicalAggregate> aggregate,
Map<Expression, Expression> globalOutputSMap,
Map<Expression, Expression> globalGroupBySMap) {
LogicalAggregate<GroupPlan> local = aggregate.child();
// replace expression in globalOutputExprs and globalGroupByExprs
List<NamedExpression> globalOutputExprs = local.getOutputExpressions().stream()
.map(e -> ExpressionUtils.replace(e, globalOutputSubstitutionMap))
.map(e -> ExpressionUtils.replace(e, globalOutputSMap))
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
List<Expression> globalGroupByExprs = local.getGroupByExpressions().stream()
.map(e -> ExpressionUtils.replace(e, globalGroupBySubstitutionMap))
.map(e -> ExpressionUtils.replace(e, globalGroupBySMap))
.collect(Collectors.toList());
// generate new plan
@ -119,7 +123,11 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
);
}
private LogicalAggregate firstDisassemble(LogicalAggregate<GroupPlan> aggregate) {
private Pair<LogicalAggregate, Boolean> firstDisassemble(
LogicalAggregate<GroupPlan> aggregate,
Map<Expression, Expression> globalOutputSMap,
Map<Expression, Expression> globalGroupBySMap) {
Boolean hasDistinct = Boolean.FALSE;
List<NamedExpression> originOutputExprs = aggregate.getOutputExpressions();
List<Expression> originGroupByExprs = aggregate.getGroupByExpressions();
Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap();
@ -149,14 +157,14 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
}
if (originGroupByExpr instanceof SlotReference) {
inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
globalOutputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
globalGroupBySubstitutionMap.put(originGroupByExpr, originGroupByExpr);
globalOutputSMap.put(originGroupByExpr, originGroupByExpr);
globalGroupBySMap.put(originGroupByExpr, originGroupByExpr);
localOutputExprs.add((SlotReference) originGroupByExpr);
} else {
NamedExpression localOutputExpr = new Alias(originGroupByExpr, originGroupByExpr.toSql());
inputSubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot());
globalOutputSubstitutionMap.put(localOutputExpr, localOutputExpr.toSlot());
globalGroupBySubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot());
globalOutputSMap.put(localOutputExpr, localOutputExpr.toSlot());
globalGroupBySMap.put(originGroupByExpr, localOutputExpr.toSlot());
localOutputExprs.add(localOutputExpr);
}
}
@ -170,22 +178,22 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
continue;
}
if (aggregateFunction.isDistinct()) {
hasDistinctAgg = true;
hasDistinct = Boolean.TRUE;
for (Expression expr : aggregateFunction.children()) {
if (expr instanceof SlotReference) {
distinctExprsForLocalOutput.add((SlotReference) expr);
if (!inputSubstitutionMap.containsKey(expr)) {
inputSubstitutionMap.put(expr, expr);
globalOutputSubstitutionMap.put(expr, expr);
globalGroupBySubstitutionMap.put(expr, expr);
globalOutputSMap.put(expr, expr);
globalGroupBySMap.put(expr, expr);
}
} else {
NamedExpression globalOutputExpr = new Alias(expr, expr.toSql());
distinctExprsForLocalOutput.add(globalOutputExpr);
if (!inputSubstitutionMap.containsKey(expr)) {
inputSubstitutionMap.put(expr, globalOutputExpr.toSlot());
globalOutputSubstitutionMap.put(globalOutputExpr, globalOutputExpr.toSlot());
globalGroupBySubstitutionMap.put(expr, globalOutputExpr.toSlot());
globalOutputSMap.put(globalOutputExpr, globalOutputExpr.toSlot());
globalGroupBySMap.put(expr, globalOutputExpr.toSlot());
}
}
distinctExprsForLocalGroupBy.add(expr);
@ -196,7 +204,7 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
Expression substitutionValue = aggregateFunction.withChildren(
Lists.newArrayList(localOutputExpr.toSlot()));
inputSubstitutionMap.put(aggregateFunction, substitutionValue);
globalOutputSubstitutionMap.put(aggregateFunction, substitutionValue);
globalOutputSMap.put(aggregateFunction, substitutionValue);
localOutputExprs.add(localOutputExpr);
}
}
@ -222,7 +230,7 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
AggPhase.LOCAL,
aggregate.child()
);
return new LogicalAggregate<>(
return Pair.of(new LogicalAggregate<>(
globalGroupByExprs,
globalOutputExprs,
true,
@ -230,6 +238,6 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
true,
AggPhase.GLOBAL,
localAggregate
);
), hasDistinct);
}
}

View File

@ -56,6 +56,14 @@ public class MergeConsecutiveProjects extends OneRewriteRuleFactory {
private static class ExpressionReplacer extends DefaultExpressionRewriter<Map<Expression, Expression>> {
public static final ExpressionReplacer INSTANCE = new ExpressionReplacer();
public Expression replace(Expression expr, Map<Expression, Expression> substitutionMap) {
if (expr instanceof SlotReference) {
Slot ref = ((SlotReference) expr).withQualifier(Collections.emptyList());
return substitutionMap.getOrDefault(ref, expr);
}
return visit(expr, substitutionMap);
}
/**
* case 1:
* project(alias(c) as d, alias(x) as y)
@ -84,7 +92,7 @@ public class MergeConsecutiveProjects extends OneRewriteRuleFactory {
Slot ref = ((SlotReference) expr).withQualifier(Collections.emptyList());
if (substitutionMap.containsKey(ref)) {
Alias res = (Alias) substitutionMap.get(ref);
return (res.child() instanceof SlotReference) ? res : res.child();
return res.child();
}
} else if (substitutionMap.containsKey(expr)) {
return substitutionMap.get(expr).child(0);
@ -106,7 +114,7 @@ public class MergeConsecutiveProjects extends OneRewriteRuleFactory {
);
projectExpressions = projectExpressions.stream()
.map(e -> MergeConsecutiveProjects.ExpressionReplacer.INSTANCE.visit(e, childAliasMap))
.map(e -> MergeConsecutiveProjects.ExpressionReplacer.INSTANCE.replace(e, childAliasMap))
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
return new LogicalProject<>(projectExpressions, childProject.children().get(0));

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
@ -61,7 +62,7 @@ import java.util.stream.Collectors;
public class NormalizeAggregate extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalAggregate().when(aggregate -> !aggregate.isNormalized()).then(aggregate -> {
return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
// substitution map used to substitute expression in aggregate's output to use it as top projections
Map<Expression, Expression> substitutionMap = Maps.newHashMap();
List<Expression> keys = aggregate.getGroupByExpressions();
@ -102,25 +103,40 @@ public class NormalizeAggregate extends OneRewriteRuleFactory {
List<NamedExpression> outputs = aggregate.getOutputExpressions();
Map<Boolean, List<NamedExpression>> partitionedOutputs = outputs.stream()
.collect(Collectors.groupingBy(e -> e.anyMatch(AggregateFunction.class::isInstance)));
boolean needBottomProjects = partitionedKeys.containsKey(false);
if (partitionedOutputs.containsKey(true)) {
// process expressions that contain aggregate function
Set<AggregateFunction> aggregateFunctions = partitionedOutputs.get(true).stream()
.flatMap(e -> e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
.collect(Collectors.toSet());
newOutputs.addAll(aggregateFunctions.stream()
.map(f -> new Alias(f, f.toSql()))
.peek(a -> substitutionMap.put(a.child(), a.toSlot()))
.collect(Collectors.toList()));
// add slot references in aggregate function to bottom projections
bottomProjections.addAll(aggregateFunctions.stream()
.flatMap(f -> f.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
.map(SlotReference.class::cast)
.collect(Collectors.toSet()));
// replace all non-slot expression in aggregate functions children.
for (AggregateFunction aggregateFunction : aggregateFunctions) {
List<Expression> newChildren = Lists.newArrayList();
for (Expression child : aggregateFunction.getArguments()) {
if (child instanceof SlotReference || child instanceof Literal) {
newChildren.add(child);
if (child instanceof SlotReference) {
bottomProjections.add((SlotReference) child);
}
} else {
needBottomProjects = true;
Alias alias = new Alias(child, child.toSql());
bottomProjections.add(alias);
newChildren.add(alias.toSlot());
}
}
AggregateFunction newFunction = (AggregateFunction) aggregateFunction.withChildren(newChildren);
Alias alias = new Alias(newFunction, newFunction.toSql());
newOutputs.add(alias);
substitutionMap.put(aggregateFunction, alias.toSlot());
}
}
// assemble
LogicalPlan root = aggregate.child();
if (partitionedKeys.containsKey(false)) {
if (needBottomProjects) {
root = new LogicalProject<>(bottomProjections, root);
}
root = new LogicalAggregate<>(newKeys, newOutputs, aggregate.isDisassembled(),

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.types.coercion.NumericType;
@ -41,6 +42,11 @@ public class Add extends BinaryArithmetic {
return new Add(children.get(0), children.get(1));
}
@Override
public DataType getDataType() {
return left().getDataType().promotion();
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitAdd(this, context);

View File

@ -44,24 +44,24 @@ public class ExecutableFunctions {
* Executable arithmetic functions
*/
@ExecFunction(name = "add", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT")
public static TinyIntLiteral addTinyint(TinyIntLiteral first, TinyIntLiteral second) {
byte result = (byte) Math.addExact(first.getValue(), second.getValue());
return new TinyIntLiteral(result);
}
@ExecFunction(name = "add", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT")
public static SmallIntLiteral addSmallint(SmallIntLiteral first, SmallIntLiteral second) {
@ExecFunction(name = "add", argTypes = {"TINYINT", "TINYINT"}, returnType = "SMALLINT")
public static SmallIntLiteral addTinyint(TinyIntLiteral first, TinyIntLiteral second) {
short result = (short) Math.addExact(first.getValue(), second.getValue());
return new SmallIntLiteral(result);
}
@ExecFunction(name = "add", argTypes = {"INT", "INT"}, returnType = "INT")
public static IntegerLiteral addInt(IntegerLiteral first, IntegerLiteral second) {
@ExecFunction(name = "add", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "INT")
public static IntegerLiteral addSmallint(SmallIntLiteral first, SmallIntLiteral second) {
int result = Math.addExact(first.getValue(), second.getValue());
return new IntegerLiteral(result);
}
@ExecFunction(name = "add", argTypes = {"INT", "INT"}, returnType = "BIGINT")
public static BigIntLiteral addInt(IntegerLiteral first, IntegerLiteral second) {
long result = Math.addExact(first.getValue(), second.getValue());
return new BigIntLiteral(result);
}
@ExecFunction(name = "add", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT")
public static BigIntLiteral addBigint(BigIntLiteral first, BigIntLiteral second) {
long result = Math.addExact(first.getValue(), second.getValue());

View File

@ -60,9 +60,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
private final AggPhase aggPhase;
// use for scenes containing distinct agg
// 1. If there are LOCAL and GLOBAL phases, global is the final phase
// 2. If there are LOCAL, GLOBAL and DISTINCT_LOCAL phases, DISTINCT_LOCAL is the final phase
// 3. If there are LOCAL, GLOBAL, DISTINCT_LOCAL, DISTINCT_GLOBAL phases,
// 1. If there is LOCAL only, LOCAL is the final phase
// 2. If there are LOCAL and GLOBAL phases, global is the final phase
// 3. If there are LOCAL, GLOBAL and DISTINCT_LOCAL phases, DISTINCT_LOCAL is the final phase
// 4. If there are LOCAL, GLOBAL, DISTINCT_LOCAL, DISTINCT_GLOBAL phases,
// DISTINCT_GLOBAL is the final phase
private final boolean isFinalPhase;
@ -73,7 +74,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
List<Expression> groupByExpressions,
List<NamedExpression> outputExpressions,
CHILD_TYPE child) {
this(groupByExpressions, outputExpressions, false, false, true, AggPhase.GLOBAL, child);
this(groupByExpressions, outputExpressions, false, false, true, AggPhase.LOCAL, child);
}
public LogicalAggregate(

View File

@ -34,6 +34,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.TinyIntType;
@ -256,7 +257,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
NamedExpressionUtil.clear();
sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + 3) > 0";
Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new TinyIntLiteral((byte) 3))),
Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new SmallIntLiteral((byte) 3))),
"sum(((a1 + a2) + 3))");
PlanChecker.from(connectContext).analyze(sql)
.matchesFromRoot(
@ -360,7 +361,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
ImmutableList.of("default_cluster:test_having", "t1")
);
Alias pk1 = new Alias(new ExprId(7), new Add(pk, Literal.of((byte) 1)), "(pk + 1)");
Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)");
Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((short) 1)), "((pk + 1) + 1)");
Alias pk2 = new Alias(new ExprId(9), new Add(pk, Literal.of((byte) 2)), "(pk + 2)");
Alias sumA1 = new Alias(new ExprId(10), new Sum(a1), "SUM(a1)");
Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1, false), Literal.of(1L)), "(COUNT(a1) + 1)");

View File

@ -150,7 +150,7 @@ public class RequestPropertyDeriverTest {
Lists.newArrayList(key),
AggPhase.LOCAL,
true,
true,
false,
logicalProperties,
groupPlan
);

View File

@ -171,9 +171,9 @@ public class FoldConstantTest {
@Test
public void testArithmeticFold() {
executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE));
assertRewrite("1 + 1", Literal.of((byte) 2));
assertRewrite("1 + 1", Literal.of((short) 2));
assertRewrite("1 - 1", Literal.of((byte) 0));
assertRewrite("100 + 100", Literal.of((byte) 200));
assertRewrite("100 + 100", Literal.of((short) 200));
assertRewrite("1 - 2", Literal.of((byte) -1));
assertRewrite("1 - 2 > 1", "false");
@ -284,9 +284,9 @@ public class FoldConstantTest {
private void assertRewrite(String expression, String expected) {
Map<String, Slot> mem = Maps.newHashMap();
Expression needRewriteExpression = PARSER.parseExpression(expression);
needRewriteExpression = replaceUnboundSlot(needRewriteExpression, mem);
needRewriteExpression = typeCoercion(replaceUnboundSlot(needRewriteExpression, mem));
Expression expectedExpression = PARSER.parseExpression(expected);
expectedExpression = replaceUnboundSlot(expectedExpression, mem);
expectedExpression = typeCoercion(replaceUnboundSlot(expectedExpression, mem));
Expression rewrittenExpression = executor.rewrite(needRewriteExpression);
Assertions.assertEquals(expectedExpression, rewrittenExpression);
}
@ -320,6 +320,10 @@ public class FoldConstantTest {
return hasNewChildren ? expression.withChildren(children) : expression;
}
private Expression typeCoercion(Expression expression) {
return TypeCoercion.INSTANCE.visit(expression, null);
}
private DataType getType(char t) {
switch (t) {
case 'T':

View File

@ -89,7 +89,8 @@ public class MergeConsecutiveProjectsTest {
relation);
LogicalProject project2 = new LogicalProject<>(
Lists.newArrayList(
new Alias(new Add(aliasRef, new IntegerLiteral(2)), "Y")
new Alias(new Add(aliasRef, new IntegerLiteral(2)), "Y"),
aliasRef
),
project1);
@ -100,11 +101,12 @@ public class MergeConsecutiveProjectsTest {
System.out.println(plan.treeString());
Assertions.assertTrue(plan instanceof LogicalProject);
LogicalProject finalProject = (LogicalProject) plan;
Add finalExpression = new Add(
Add aPlus1Plus2 = new Add(
new Add(colA, new IntegerLiteral(1)),
new IntegerLiteral(2)
);
Assertions.assertEquals(1, finalProject.getProjects().size());
Assertions.assertEquals(((Alias) finalProject.getProjects().get(0)).child(), finalExpression);
Assertions.assertEquals(2, finalProject.getProjects().size());
Assertions.assertEquals(aPlus1Plus2, ((Alias) finalProject.getProjects().get(0)).child());
Assertions.assertEquals(alias, finalProject.getProjects().get(1));
}
}

View File

@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
@ -93,15 +94,17 @@ public class NormalizeAggregateTest implements PatternMatchSupported {
* +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3])
*
* after rewrite:
* LogicalProject ((sum((id * 1))#5 + 2) AS `(sum((id * 1)) + 2)`#4)
* +--LogicalAggregate (phase: [GLOBAL], output: [sum((id#0 * 1)) AS `sum((id * 1))`#5], groupBy: [name#2])
* +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3])
* LogicalProject ( projects=[(sum((id * 1))#6 + 2) AS `(sum((id * 1)) + 2)`#4] )
* +--LogicalAggregate ( phase=LOCAL, outputExpr=[sum((id * 1)#5) AS `sum((id * 1))`#6], groupByExpr=[name#2] )
* +--LogicalProject ( projects=[name#2, (id#0 * 1) AS `(id * 1)`#5] )
* +--GroupPlan( GroupId#0 )
*/
@Test
public void testComplexFuncWithComplexOutputOfFunc() {
NamedExpression key = rStudent.getOutput().get(2).toSlot();
List<Expression> groupExpressionList = Lists.newArrayList(key);
Expression aggregateFunction = new Sum(new Multiply(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(1)));
Expression multiply = new Multiply(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(1));
Expression aggregateFunction = new Sum(multiply);
Expression complexOutput = new Add(aggregateFunction, new IntegerLiteral(2));
Alias output = new Alias(complexOutput, complexOutput.toSql());
List<NamedExpression> outputExpressionList = Lists.newArrayList(output);
@ -112,10 +115,14 @@ public class NormalizeAggregateTest implements PatternMatchSupported {
.matchesFromRoot(
logicalProject(
logicalAggregate(
logicalOlapScan()
logicalProject(
logicalOlapScan()
).when(project -> project.getProjects().size() == 2)
.when(project -> project.getProjects().get(0) instanceof SlotReference)
.when(project -> project.getProjects().get(1).child(0).equals(multiply))
).when(FieldChecker.check("groupByExpressions", ImmutableList.of(key)))
.when(aggregate -> aggregate.getOutputExpressions().size() == 1)
.when(aggregate -> aggregate.getOutputExpressions().get(0).child(0).equals(aggregateFunction))
.when(aggregate -> aggregate.getOutputExpressions().get(0).child(0) instanceof AggregateFunction)
).when(project -> project.getProjects().size() == 1)
.when(project -> project.getProjects().get(0) instanceof Alias)
.when(project -> project.getProjects().get(0).getExprId().equals(output.getExprId()))

View File

@ -57,7 +57,7 @@ public class PlanToStringTest {
new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), child);
Assertions.assertTrue(plan.toString()
.matches("LogicalAggregate \\( phase=GLOBAL, outputExpr=\\[a#\\d+], groupByExpr=\\[] \\)"));
.matches("LogicalAggregate \\( phase=LOCAL, outputExpr=\\[a#\\d+], groupByExpr=\\[] \\)"));
}
@Test

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnAgg;
import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnProjectUnderAgg;
import org.apache.doris.nereids.rules.rewrite.logical.ExistsApplyToJoin;
@ -120,7 +121,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte
new MockUp<RuleSet>() {
@Mock
public List<Rule> getExplorationRules() {
return Lists.newArrayList();
return Lists.newArrayList(new AggregateDisassemble().build());
}
};