From eef9367705f0fd7ea6e412c64581f5c5c64a52dc Mon Sep 17 00:00:00 2001 From: morrySnow <101034200+morrySnow@users.noreply.github.com> Date: Wed, 28 Sep 2022 10:38:03 +0800 Subject: [PATCH] [feature](Nereids) use one stage aggregation if available (#12849) Currently, we always disassemble aggregation into two stage: local and global. However, in some case, one stage aggregation is enough, there are two advantage of one stage aggregation. 1. avoid unnecessary exchange. 2. have a chance to do colocate join on the top of aggregation. This PR move AggregateDisassemble rule from rewrite stage to optimization stage. And choose one stage or two stage aggregation according to cost. --- .../translator/PhysicalPlanTranslator.java | 9 +-- .../jobs/batch/NereidsRewriteJobExecutor.java | 2 - .../jobs/cascades/CostAndEnforcerJob.java | 4 ++ .../ChildrenPropertiesRegulator.java | 13 ++++ .../properties/RequestPropertyDeriver.java | 2 +- .../apache/doris/nereids/rules/RuleSet.java | 2 + .../rules/rewrite/AggregateDisassemble.java | 68 +++++++++++-------- .../logical/MergeConsecutiveProjects.java | 12 +++- .../rewrite/logical/NormalizeAggregate.java | 38 ++++++++--- .../doris/nereids/trees/expressions/Add.java | 6 ++ .../functions/ExecutableFunctions.java | 20 +++--- .../trees/plans/logical/LogicalAggregate.java | 9 +-- .../nereids/parser/HavingClauseTest.java | 5 +- .../RequestPropertyDeriverTest.java | 2 +- .../expression/rewrite/FoldConstantTest.java | 12 ++-- .../logical/MergeConsecutiveProjectsTest.java | 10 +-- .../logical/NormalizeAggregateTest.java | 19 ++++-- .../nereids/trees/plans/PlanToStringTest.java | 2 +- .../util/AnalyzeWhereSubqueryTest.java | 3 +- 19 files changed, 155 insertions(+), 83 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 7d924a0838..60e31fc904 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -193,9 +193,8 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor slotList = Lists.newArrayList(); TupleDescriptor outputTupleDesc; - if (aggregate.getAggPhase() == AggPhase.LOCAL) { - outputTupleDesc = generateTupleDesc(aggregate.getOutput(), null, context); - } else if ((aggregate.getAggPhase() == AggPhase.GLOBAL && aggregate.isFinalPhase()) + if (aggregate.getAggPhase() == AggPhase.LOCAL + || (aggregate.getAggPhase() == AggPhase.GLOBAL && aggregate.isFinalPhase()) || aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) { slotList.addAll(groupSlotList); slotList.addAll(aggFunctionOutput); @@ -222,10 +221,12 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor { return enforceCost; } + @Override + public Double visitPhysicalAggregate(PhysicalAggregate agg, Void context) { + if (agg.isFinalPhase() + && agg.getAggPhase() == AggPhase.LOCAL + && children.get(0).getPlan() instanceof PhysicalDistribute) { + return -1.0; + } + return 0.0; + } + @Override public Double visitPhysicalHashJoin(PhysicalHashJoin hashJoin, Void context) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java index 5063909d21..48024ee091 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java @@ -80,7 +80,7 @@ public class RequestPropertyDeriver extends PlanVisitor { @Override public Void visitPhysicalAggregate(PhysicalAggregate agg, PlanContext context) { // 1. first phase agg just return any - if (agg.getAggPhase().isLocal()) { + if (agg.getAggPhase().isLocal() && !agg.isFinalPhase()) { addToRequestPropertyToChildren(PhysicalProperties.ANY); return null; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index 69a56cd30c..8456b73d64 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -37,6 +37,7 @@ import org.apache.doris.nereids.rules.implementation.LogicalOneRowRelationToPhys import org.apache.doris.nereids.rules.implementation.LogicalProjectToPhysicalProject; import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickSort; import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN; +import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble; import org.apache.doris.nereids.rules.rewrite.logical.EliminateOuter; import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters; import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits; @@ -65,6 +66,7 @@ public class RuleSet { .add(SemiJoinLogicalJoinTranspose.LEFT_DEEP) .add(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP) .add(SemiJoinSemiJoinTranspose.INSTANCE) + .add(new AggregateDisassemble()) .add(new PushdownFilterThroughProject()) .add(new MergeConsecutiveProjects()) .build(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java index 13efb1e877..77135074d4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; @@ -66,36 +67,39 @@ import java.util.stream.Collectors; * 2. if instance count is 1, shouldn't disassemble the agg plan */ public class AggregateDisassemble extends OneRewriteRuleFactory { - // used in secondDisassemble to transform local expressions into global - private final Map globalOutputSubstitutionMap = Maps.newHashMap(); - // used in secondDisassemble to transform local expressions into global - private final Map globalGroupBySubstitutionMap = Maps.newHashMap(); - // used to indicate the existence of a distinct function for the entire phase - private boolean hasDistinctAgg = false; @Override public Rule build() { - return logicalAggregate().when(agg -> !agg.isDisassembled()).thenApply(ctx -> { - LogicalAggregate aggregate = ctx.root; - LogicalAggregate firstAggregate = firstDisassemble(aggregate); - if (!hasDistinctAgg) { - return firstAggregate; - } - return secondDisassemble(firstAggregate); - }).toRule(RuleType.AGGREGATE_DISASSEMBLE); + return logicalAggregate() + .whenNot(LogicalAggregate::isDisassembled) + .then(aggregate -> { + // used in secondDisassemble to transform local expressions into global + final Map globalOutputSMap = Maps.newHashMap(); + // used in secondDisassemble to transform local expressions into global + final Map globalGroupBySMap = Maps.newHashMap(); + Pair ret = firstDisassemble(aggregate, globalOutputSMap, + globalGroupBySMap); + if (!ret.second) { + return ret.first; + } + return secondDisassemble(ret.first, globalOutputSMap, globalGroupBySMap); + }).toRule(RuleType.AGGREGATE_DISASSEMBLE); } // only support distinct function with group by // TODO: support distinct function without group by. (add second global phase) - private LogicalAggregate secondDisassemble(LogicalAggregate aggregate) { + private LogicalAggregate secondDisassemble( + LogicalAggregate aggregate, + Map globalOutputSMap, + Map globalGroupBySMap) { LogicalAggregate local = aggregate.child(); // replace expression in globalOutputExprs and globalGroupByExprs List globalOutputExprs = local.getOutputExpressions().stream() - .map(e -> ExpressionUtils.replace(e, globalOutputSubstitutionMap)) + .map(e -> ExpressionUtils.replace(e, globalOutputSMap)) .map(NamedExpression.class::cast) .collect(Collectors.toList()); List globalGroupByExprs = local.getGroupByExpressions().stream() - .map(e -> ExpressionUtils.replace(e, globalGroupBySubstitutionMap)) + .map(e -> ExpressionUtils.replace(e, globalGroupBySMap)) .collect(Collectors.toList()); // generate new plan @@ -119,7 +123,11 @@ public class AggregateDisassemble extends OneRewriteRuleFactory { ); } - private LogicalAggregate firstDisassemble(LogicalAggregate aggregate) { + private Pair firstDisassemble( + LogicalAggregate aggregate, + Map globalOutputSMap, + Map globalGroupBySMap) { + Boolean hasDistinct = Boolean.FALSE; List originOutputExprs = aggregate.getOutputExpressions(); List originGroupByExprs = aggregate.getGroupByExpressions(); Map inputSubstitutionMap = Maps.newHashMap(); @@ -149,14 +157,14 @@ public class AggregateDisassemble extends OneRewriteRuleFactory { } if (originGroupByExpr instanceof SlotReference) { inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr); - globalOutputSubstitutionMap.put(originGroupByExpr, originGroupByExpr); - globalGroupBySubstitutionMap.put(originGroupByExpr, originGroupByExpr); + globalOutputSMap.put(originGroupByExpr, originGroupByExpr); + globalGroupBySMap.put(originGroupByExpr, originGroupByExpr); localOutputExprs.add((SlotReference) originGroupByExpr); } else { NamedExpression localOutputExpr = new Alias(originGroupByExpr, originGroupByExpr.toSql()); inputSubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot()); - globalOutputSubstitutionMap.put(localOutputExpr, localOutputExpr.toSlot()); - globalGroupBySubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot()); + globalOutputSMap.put(localOutputExpr, localOutputExpr.toSlot()); + globalGroupBySMap.put(originGroupByExpr, localOutputExpr.toSlot()); localOutputExprs.add(localOutputExpr); } } @@ -170,22 +178,22 @@ public class AggregateDisassemble extends OneRewriteRuleFactory { continue; } if (aggregateFunction.isDistinct()) { - hasDistinctAgg = true; + hasDistinct = Boolean.TRUE; for (Expression expr : aggregateFunction.children()) { if (expr instanceof SlotReference) { distinctExprsForLocalOutput.add((SlotReference) expr); if (!inputSubstitutionMap.containsKey(expr)) { inputSubstitutionMap.put(expr, expr); - globalOutputSubstitutionMap.put(expr, expr); - globalGroupBySubstitutionMap.put(expr, expr); + globalOutputSMap.put(expr, expr); + globalGroupBySMap.put(expr, expr); } } else { NamedExpression globalOutputExpr = new Alias(expr, expr.toSql()); distinctExprsForLocalOutput.add(globalOutputExpr); if (!inputSubstitutionMap.containsKey(expr)) { inputSubstitutionMap.put(expr, globalOutputExpr.toSlot()); - globalOutputSubstitutionMap.put(globalOutputExpr, globalOutputExpr.toSlot()); - globalGroupBySubstitutionMap.put(expr, globalOutputExpr.toSlot()); + globalOutputSMap.put(globalOutputExpr, globalOutputExpr.toSlot()); + globalGroupBySMap.put(expr, globalOutputExpr.toSlot()); } } distinctExprsForLocalGroupBy.add(expr); @@ -196,7 +204,7 @@ public class AggregateDisassemble extends OneRewriteRuleFactory { Expression substitutionValue = aggregateFunction.withChildren( Lists.newArrayList(localOutputExpr.toSlot())); inputSubstitutionMap.put(aggregateFunction, substitutionValue); - globalOutputSubstitutionMap.put(aggregateFunction, substitutionValue); + globalOutputSMap.put(aggregateFunction, substitutionValue); localOutputExprs.add(localOutputExpr); } } @@ -222,7 +230,7 @@ public class AggregateDisassemble extends OneRewriteRuleFactory { AggPhase.LOCAL, aggregate.child() ); - return new LogicalAggregate<>( + return Pair.of(new LogicalAggregate<>( globalGroupByExprs, globalOutputExprs, true, @@ -230,6 +238,6 @@ public class AggregateDisassemble extends OneRewriteRuleFactory { true, AggPhase.GLOBAL, localAggregate - ); + ), hasDistinct); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjects.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjects.java index 7738ee2772..626d6d22cd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjects.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjects.java @@ -56,6 +56,14 @@ public class MergeConsecutiveProjects extends OneRewriteRuleFactory { private static class ExpressionReplacer extends DefaultExpressionRewriter> { public static final ExpressionReplacer INSTANCE = new ExpressionReplacer(); + public Expression replace(Expression expr, Map substitutionMap) { + if (expr instanceof SlotReference) { + Slot ref = ((SlotReference) expr).withQualifier(Collections.emptyList()); + return substitutionMap.getOrDefault(ref, expr); + } + return visit(expr, substitutionMap); + } + /** * case 1: * project(alias(c) as d, alias(x) as y) @@ -84,7 +92,7 @@ public class MergeConsecutiveProjects extends OneRewriteRuleFactory { Slot ref = ((SlotReference) expr).withQualifier(Collections.emptyList()); if (substitutionMap.containsKey(ref)) { Alias res = (Alias) substitutionMap.get(ref); - return (res.child() instanceof SlotReference) ? res : res.child(); + return res.child(); } } else if (substitutionMap.containsKey(expr)) { return substitutionMap.get(expr).child(0); @@ -106,7 +114,7 @@ public class MergeConsecutiveProjects extends OneRewriteRuleFactory { ); projectExpressions = projectExpressions.stream() - .map(e -> MergeConsecutiveProjects.ExpressionReplacer.INSTANCE.visit(e, childAliasMap)) + .map(e -> MergeConsecutiveProjects.ExpressionReplacer.INSTANCE.replace(e, childAliasMap)) .map(NamedExpression.class::cast) .collect(Collectors.toList()); return new LogicalProject<>(projectExpressions, childProject.children().get(0)); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java index 5a6243809e..a888e06ce4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; @@ -61,7 +62,7 @@ import java.util.stream.Collectors; public class NormalizeAggregate extends OneRewriteRuleFactory { @Override public Rule build() { - return logicalAggregate().when(aggregate -> !aggregate.isNormalized()).then(aggregate -> { + return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> { // substitution map used to substitute expression in aggregate's output to use it as top projections Map substitutionMap = Maps.newHashMap(); List keys = aggregate.getGroupByExpressions(); @@ -102,25 +103,40 @@ public class NormalizeAggregate extends OneRewriteRuleFactory { List outputs = aggregate.getOutputExpressions(); Map> partitionedOutputs = outputs.stream() .collect(Collectors.groupingBy(e -> e.anyMatch(AggregateFunction.class::isInstance))); + + boolean needBottomProjects = partitionedKeys.containsKey(false); if (partitionedOutputs.containsKey(true)) { // process expressions that contain aggregate function Set aggregateFunctions = partitionedOutputs.get(true).stream() .flatMap(e -> e.>collect(AggregateFunction.class::isInstance).stream()) .collect(Collectors.toSet()); - newOutputs.addAll(aggregateFunctions.stream() - .map(f -> new Alias(f, f.toSql())) - .peek(a -> substitutionMap.put(a.child(), a.toSlot())) - .collect(Collectors.toList())); - // add slot references in aggregate function to bottom projections - bottomProjections.addAll(aggregateFunctions.stream() - .flatMap(f -> f.>collect(SlotReference.class::isInstance).stream()) - .map(SlotReference.class::cast) - .collect(Collectors.toSet())); + + // replace all non-slot expression in aggregate functions children. + for (AggregateFunction aggregateFunction : aggregateFunctions) { + List newChildren = Lists.newArrayList(); + for (Expression child : aggregateFunction.getArguments()) { + if (child instanceof SlotReference || child instanceof Literal) { + newChildren.add(child); + if (child instanceof SlotReference) { + bottomProjections.add((SlotReference) child); + } + } else { + needBottomProjects = true; + Alias alias = new Alias(child, child.toSql()); + bottomProjections.add(alias); + newChildren.add(alias.toSlot()); + } + } + AggregateFunction newFunction = (AggregateFunction) aggregateFunction.withChildren(newChildren); + Alias alias = new Alias(newFunction, newFunction.toSql()); + newOutputs.add(alias); + substitutionMap.put(aggregateFunction, alias.toSlot()); + } } // assemble LogicalPlan root = aggregate.child(); - if (partitionedKeys.containsKey(false)) { + if (needBottomProjects) { root = new LogicalProject<>(bottomProjections, root); } root = new LogicalAggregate<>(newKeys, newOutputs, aggregate.isDisassembled(), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java index 34e807a1c5..13ccd1ca95 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.analysis.ArithmeticExpr.Operator; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.coercion.AbstractDataType; import org.apache.doris.nereids.types.coercion.NumericType; @@ -41,6 +42,11 @@ public class Add extends BinaryArithmetic { return new Add(children.get(0), children.get(1)); } + @Override + public DataType getDataType() { + return left().getDataType().promotion(); + } + @Override public R accept(ExpressionVisitor visitor, C context) { return visitor.visitAdd(this, context); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExecutableFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExecutableFunctions.java index 4c65d19883..2fab50e38e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExecutableFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExecutableFunctions.java @@ -44,24 +44,24 @@ public class ExecutableFunctions { * Executable arithmetic functions */ - @ExecFunction(name = "add", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT") - public static TinyIntLiteral addTinyint(TinyIntLiteral first, TinyIntLiteral second) { - byte result = (byte) Math.addExact(first.getValue(), second.getValue()); - return new TinyIntLiteral(result); - } - - @ExecFunction(name = "add", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT") - public static SmallIntLiteral addSmallint(SmallIntLiteral first, SmallIntLiteral second) { + @ExecFunction(name = "add", argTypes = {"TINYINT", "TINYINT"}, returnType = "SMALLINT") + public static SmallIntLiteral addTinyint(TinyIntLiteral first, TinyIntLiteral second) { short result = (short) Math.addExact(first.getValue(), second.getValue()); return new SmallIntLiteral(result); } - @ExecFunction(name = "add", argTypes = {"INT", "INT"}, returnType = "INT") - public static IntegerLiteral addInt(IntegerLiteral first, IntegerLiteral second) { + @ExecFunction(name = "add", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "INT") + public static IntegerLiteral addSmallint(SmallIntLiteral first, SmallIntLiteral second) { int result = Math.addExact(first.getValue(), second.getValue()); return new IntegerLiteral(result); } + @ExecFunction(name = "add", argTypes = {"INT", "INT"}, returnType = "BIGINT") + public static BigIntLiteral addInt(IntegerLiteral first, IntegerLiteral second) { + long result = Math.addExact(first.getValue(), second.getValue()); + return new BigIntLiteral(result); + } + @ExecFunction(name = "add", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT") public static BigIntLiteral addBigint(BigIntLiteral first, BigIntLiteral second) { long result = Math.addExact(first.getValue(), second.getValue()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java index 0cca04950d..06df06b92b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java @@ -60,9 +60,10 @@ public class LogicalAggregate extends LogicalUnary extends LogicalUnary groupByExpressions, List outputExpressions, CHILD_TYPE child) { - this(groupByExpressions, outputExpressions, false, false, true, AggPhase.GLOBAL, child); + this(groupByExpressions, outputExpressions, false, false, true, AggPhase.LOCAL, child); } public LogicalAggregate( diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java index 05d256f395..ca96dcceb3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java @@ -34,6 +34,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.TinyIntType; @@ -256,7 +257,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat NamedExpressionUtil.clear(); sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + 3) > 0"; - Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new TinyIntLiteral((byte) 3))), + Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new SmallIntLiteral((byte) 3))), "sum(((a1 + a2) + 3))"); PlanChecker.from(connectContext).analyze(sql) .matchesFromRoot( @@ -360,7 +361,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat ImmutableList.of("default_cluster:test_having", "t1") ); Alias pk1 = new Alias(new ExprId(7), new Add(pk, Literal.of((byte) 1)), "(pk + 1)"); - Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)"); + Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((short) 1)), "((pk + 1) + 1)"); Alias pk2 = new Alias(new ExprId(9), new Add(pk, Literal.of((byte) 2)), "(pk + 2)"); Alias sumA1 = new Alias(new ExprId(10), new Sum(a1), "SUM(a1)"); Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1, false), Literal.of(1L)), "(COUNT(a1) + 1)"); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java index c179823a92..c133bf8ed3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java @@ -150,7 +150,7 @@ public class RequestPropertyDeriverTest { Lists.newArrayList(key), AggPhase.LOCAL, true, - true, + false, logicalProperties, groupPlan ); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/FoldConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/FoldConstantTest.java index ec21777e20..4769edcf2c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/FoldConstantTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/FoldConstantTest.java @@ -171,9 +171,9 @@ public class FoldConstantTest { @Test public void testArithmeticFold() { executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE)); - assertRewrite("1 + 1", Literal.of((byte) 2)); + assertRewrite("1 + 1", Literal.of((short) 2)); assertRewrite("1 - 1", Literal.of((byte) 0)); - assertRewrite("100 + 100", Literal.of((byte) 200)); + assertRewrite("100 + 100", Literal.of((short) 200)); assertRewrite("1 - 2", Literal.of((byte) -1)); assertRewrite("1 - 2 > 1", "false"); @@ -284,9 +284,9 @@ public class FoldConstantTest { private void assertRewrite(String expression, String expected) { Map mem = Maps.newHashMap(); Expression needRewriteExpression = PARSER.parseExpression(expression); - needRewriteExpression = replaceUnboundSlot(needRewriteExpression, mem); + needRewriteExpression = typeCoercion(replaceUnboundSlot(needRewriteExpression, mem)); Expression expectedExpression = PARSER.parseExpression(expected); - expectedExpression = replaceUnboundSlot(expectedExpression, mem); + expectedExpression = typeCoercion(replaceUnboundSlot(expectedExpression, mem)); Expression rewrittenExpression = executor.rewrite(needRewriteExpression); Assertions.assertEquals(expectedExpression, rewrittenExpression); } @@ -320,6 +320,10 @@ public class FoldConstantTest { return hasNewChildren ? expression.withChildren(children) : expression; } + private Expression typeCoercion(Expression expression) { + return TypeCoercion.INSTANCE.visit(expression, null); + } + private DataType getType(char t) { switch (t) { case 'T': diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjectsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjectsTest.java index 0816e7401b..08c9af2946 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjectsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjectsTest.java @@ -89,7 +89,8 @@ public class MergeConsecutiveProjectsTest { relation); LogicalProject project2 = new LogicalProject<>( Lists.newArrayList( - new Alias(new Add(aliasRef, new IntegerLiteral(2)), "Y") + new Alias(new Add(aliasRef, new IntegerLiteral(2)), "Y"), + aliasRef ), project1); @@ -100,11 +101,12 @@ public class MergeConsecutiveProjectsTest { System.out.println(plan.treeString()); Assertions.assertTrue(plan instanceof LogicalProject); LogicalProject finalProject = (LogicalProject) plan; - Add finalExpression = new Add( + Add aPlus1Plus2 = new Add( new Add(colA, new IntegerLiteral(1)), new IntegerLiteral(2) ); - Assertions.assertEquals(1, finalProject.getProjects().size()); - Assertions.assertEquals(((Alias) finalProject.getProjects().get(0)).child(), finalExpression); + Assertions.assertEquals(2, finalProject.getProjects().size()); + Assertions.assertEquals(aPlus1Plus2, ((Alias) finalProject.getProjects().get(0)).child()); + Assertions.assertEquals(alias, finalProject.getProjects().get(1)); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java index 9efa81376c..ef819a66f6 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Multiply; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.plans.Plan; @@ -93,15 +94,17 @@ public class NormalizeAggregateTest implements PatternMatchSupported { * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3]) * * after rewrite: - * LogicalProject ((sum((id * 1))#5 + 2) AS `(sum((id * 1)) + 2)`#4) - * +--LogicalAggregate (phase: [GLOBAL], output: [sum((id#0 * 1)) AS `sum((id * 1))`#5], groupBy: [name#2]) - * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3]) + * LogicalProject ( projects=[(sum((id * 1))#6 + 2) AS `(sum((id * 1)) + 2)`#4] ) + * +--LogicalAggregate ( phase=LOCAL, outputExpr=[sum((id * 1)#5) AS `sum((id * 1))`#6], groupByExpr=[name#2] ) + * +--LogicalProject ( projects=[name#2, (id#0 * 1) AS `(id * 1)`#5] ) + * +--GroupPlan( GroupId#0 ) */ @Test public void testComplexFuncWithComplexOutputOfFunc() { NamedExpression key = rStudent.getOutput().get(2).toSlot(); List groupExpressionList = Lists.newArrayList(key); - Expression aggregateFunction = new Sum(new Multiply(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(1))); + Expression multiply = new Multiply(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(1)); + Expression aggregateFunction = new Sum(multiply); Expression complexOutput = new Add(aggregateFunction, new IntegerLiteral(2)); Alias output = new Alias(complexOutput, complexOutput.toSql()); List outputExpressionList = Lists.newArrayList(output); @@ -112,10 +115,14 @@ public class NormalizeAggregateTest implements PatternMatchSupported { .matchesFromRoot( logicalProject( logicalAggregate( - logicalOlapScan() + logicalProject( + logicalOlapScan() + ).when(project -> project.getProjects().size() == 2) + .when(project -> project.getProjects().get(0) instanceof SlotReference) + .when(project -> project.getProjects().get(1).child(0).equals(multiply)) ).when(FieldChecker.check("groupByExpressions", ImmutableList.of(key))) .when(aggregate -> aggregate.getOutputExpressions().size() == 1) - .when(aggregate -> aggregate.getOutputExpressions().get(0).child(0).equals(aggregateFunction)) + .when(aggregate -> aggregate.getOutputExpressions().get(0).child(0) instanceof AggregateFunction) ).when(project -> project.getProjects().size() == 1) .when(project -> project.getProjects().get(0) instanceof Alias) .when(project -> project.getProjects().get(0).getExprId().equals(output.getExprId())) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java index 508c6f45da..803523eb0c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java @@ -57,7 +57,7 @@ public class PlanToStringTest { new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), child); Assertions.assertTrue(plan.toString() - .matches("LogicalAggregate \\( phase=GLOBAL, outputExpr=\\[a#\\d+], groupByExpr=\\[] \\)")); + .matches("LogicalAggregate \\( phase=LOCAL, outputExpr=\\[a#\\d+], groupByExpr=\\[] \\)")); } @Test diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java index ce3c7392bf..60bba6ba38 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleSet; +import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble; import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnAgg; import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnProjectUnderAgg; import org.apache.doris.nereids.rules.rewrite.logical.ExistsApplyToJoin; @@ -120,7 +121,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte new MockUp() { @Mock public List getExplorationRules() { - return Lists.newArrayList(); + return Lists.newArrayList(new AggregateDisassemble().build()); } };