diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 7d924a0838..60e31fc904 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -193,9 +193,8 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor slotList = Lists.newArrayList(); TupleDescriptor outputTupleDesc; - if (aggregate.getAggPhase() == AggPhase.LOCAL) { - outputTupleDesc = generateTupleDesc(aggregate.getOutput(), null, context); - } else if ((aggregate.getAggPhase() == AggPhase.GLOBAL && aggregate.isFinalPhase()) + if (aggregate.getAggPhase() == AggPhase.LOCAL + || (aggregate.getAggPhase() == AggPhase.GLOBAL && aggregate.isFinalPhase()) || aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) { slotList.addAll(groupSlotList); slotList.addAll(aggFunctionOutput); @@ -222,10 +221,12 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor { return enforceCost; } + @Override + public Double visitPhysicalAggregate(PhysicalAggregate agg, Void context) { + if (agg.isFinalPhase() + && agg.getAggPhase() == AggPhase.LOCAL + && children.get(0).getPlan() instanceof PhysicalDistribute) { + return -1.0; + } + return 0.0; + } + @Override public Double visitPhysicalHashJoin(PhysicalHashJoin hashJoin, Void context) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java index 5063909d21..48024ee091 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java @@ -80,7 +80,7 @@ public class RequestPropertyDeriver extends PlanVisitor { @Override public Void visitPhysicalAggregate(PhysicalAggregate agg, PlanContext context) { // 1. first phase agg just return any - if (agg.getAggPhase().isLocal()) { + if (agg.getAggPhase().isLocal() && !agg.isFinalPhase()) { addToRequestPropertyToChildren(PhysicalProperties.ANY); return null; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index 69a56cd30c..8456b73d64 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -37,6 +37,7 @@ import org.apache.doris.nereids.rules.implementation.LogicalOneRowRelationToPhys import org.apache.doris.nereids.rules.implementation.LogicalProjectToPhysicalProject; import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickSort; import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN; +import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble; import org.apache.doris.nereids.rules.rewrite.logical.EliminateOuter; import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters; import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits; @@ -65,6 +66,7 @@ public class RuleSet { .add(SemiJoinLogicalJoinTranspose.LEFT_DEEP) .add(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP) .add(SemiJoinSemiJoinTranspose.INSTANCE) + .add(new AggregateDisassemble()) .add(new PushdownFilterThroughProject()) .add(new MergeConsecutiveProjects()) .build(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java index 13efb1e877..77135074d4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; @@ -66,36 +67,39 @@ import java.util.stream.Collectors; * 2. if instance count is 1, shouldn't disassemble the agg plan */ public class AggregateDisassemble extends OneRewriteRuleFactory { - // used in secondDisassemble to transform local expressions into global - private final Map globalOutputSubstitutionMap = Maps.newHashMap(); - // used in secondDisassemble to transform local expressions into global - private final Map globalGroupBySubstitutionMap = Maps.newHashMap(); - // used to indicate the existence of a distinct function for the entire phase - private boolean hasDistinctAgg = false; @Override public Rule build() { - return logicalAggregate().when(agg -> !agg.isDisassembled()).thenApply(ctx -> { - LogicalAggregate aggregate = ctx.root; - LogicalAggregate firstAggregate = firstDisassemble(aggregate); - if (!hasDistinctAgg) { - return firstAggregate; - } - return secondDisassemble(firstAggregate); - }).toRule(RuleType.AGGREGATE_DISASSEMBLE); + return logicalAggregate() + .whenNot(LogicalAggregate::isDisassembled) + .then(aggregate -> { + // used in secondDisassemble to transform local expressions into global + final Map globalOutputSMap = Maps.newHashMap(); + // used in secondDisassemble to transform local expressions into global + final Map globalGroupBySMap = Maps.newHashMap(); + Pair ret = firstDisassemble(aggregate, globalOutputSMap, + globalGroupBySMap); + if (!ret.second) { + return ret.first; + } + return secondDisassemble(ret.first, globalOutputSMap, globalGroupBySMap); + }).toRule(RuleType.AGGREGATE_DISASSEMBLE); } // only support distinct function with group by // TODO: support distinct function without group by. (add second global phase) - private LogicalAggregate secondDisassemble(LogicalAggregate aggregate) { + private LogicalAggregate secondDisassemble( + LogicalAggregate aggregate, + Map globalOutputSMap, + Map globalGroupBySMap) { LogicalAggregate local = aggregate.child(); // replace expression in globalOutputExprs and globalGroupByExprs List globalOutputExprs = local.getOutputExpressions().stream() - .map(e -> ExpressionUtils.replace(e, globalOutputSubstitutionMap)) + .map(e -> ExpressionUtils.replace(e, globalOutputSMap)) .map(NamedExpression.class::cast) .collect(Collectors.toList()); List globalGroupByExprs = local.getGroupByExpressions().stream() - .map(e -> ExpressionUtils.replace(e, globalGroupBySubstitutionMap)) + .map(e -> ExpressionUtils.replace(e, globalGroupBySMap)) .collect(Collectors.toList()); // generate new plan @@ -119,7 +123,11 @@ public class AggregateDisassemble extends OneRewriteRuleFactory { ); } - private LogicalAggregate firstDisassemble(LogicalAggregate aggregate) { + private Pair firstDisassemble( + LogicalAggregate aggregate, + Map globalOutputSMap, + Map globalGroupBySMap) { + Boolean hasDistinct = Boolean.FALSE; List originOutputExprs = aggregate.getOutputExpressions(); List originGroupByExprs = aggregate.getGroupByExpressions(); Map inputSubstitutionMap = Maps.newHashMap(); @@ -149,14 +157,14 @@ public class AggregateDisassemble extends OneRewriteRuleFactory { } if (originGroupByExpr instanceof SlotReference) { inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr); - globalOutputSubstitutionMap.put(originGroupByExpr, originGroupByExpr); - globalGroupBySubstitutionMap.put(originGroupByExpr, originGroupByExpr); + globalOutputSMap.put(originGroupByExpr, originGroupByExpr); + globalGroupBySMap.put(originGroupByExpr, originGroupByExpr); localOutputExprs.add((SlotReference) originGroupByExpr); } else { NamedExpression localOutputExpr = new Alias(originGroupByExpr, originGroupByExpr.toSql()); inputSubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot()); - globalOutputSubstitutionMap.put(localOutputExpr, localOutputExpr.toSlot()); - globalGroupBySubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot()); + globalOutputSMap.put(localOutputExpr, localOutputExpr.toSlot()); + globalGroupBySMap.put(originGroupByExpr, localOutputExpr.toSlot()); localOutputExprs.add(localOutputExpr); } } @@ -170,22 +178,22 @@ public class AggregateDisassemble extends OneRewriteRuleFactory { continue; } if (aggregateFunction.isDistinct()) { - hasDistinctAgg = true; + hasDistinct = Boolean.TRUE; for (Expression expr : aggregateFunction.children()) { if (expr instanceof SlotReference) { distinctExprsForLocalOutput.add((SlotReference) expr); if (!inputSubstitutionMap.containsKey(expr)) { inputSubstitutionMap.put(expr, expr); - globalOutputSubstitutionMap.put(expr, expr); - globalGroupBySubstitutionMap.put(expr, expr); + globalOutputSMap.put(expr, expr); + globalGroupBySMap.put(expr, expr); } } else { NamedExpression globalOutputExpr = new Alias(expr, expr.toSql()); distinctExprsForLocalOutput.add(globalOutputExpr); if (!inputSubstitutionMap.containsKey(expr)) { inputSubstitutionMap.put(expr, globalOutputExpr.toSlot()); - globalOutputSubstitutionMap.put(globalOutputExpr, globalOutputExpr.toSlot()); - globalGroupBySubstitutionMap.put(expr, globalOutputExpr.toSlot()); + globalOutputSMap.put(globalOutputExpr, globalOutputExpr.toSlot()); + globalGroupBySMap.put(expr, globalOutputExpr.toSlot()); } } distinctExprsForLocalGroupBy.add(expr); @@ -196,7 +204,7 @@ public class AggregateDisassemble extends OneRewriteRuleFactory { Expression substitutionValue = aggregateFunction.withChildren( Lists.newArrayList(localOutputExpr.toSlot())); inputSubstitutionMap.put(aggregateFunction, substitutionValue); - globalOutputSubstitutionMap.put(aggregateFunction, substitutionValue); + globalOutputSMap.put(aggregateFunction, substitutionValue); localOutputExprs.add(localOutputExpr); } } @@ -222,7 +230,7 @@ public class AggregateDisassemble extends OneRewriteRuleFactory { AggPhase.LOCAL, aggregate.child() ); - return new LogicalAggregate<>( + return Pair.of(new LogicalAggregate<>( globalGroupByExprs, globalOutputExprs, true, @@ -230,6 +238,6 @@ public class AggregateDisassemble extends OneRewriteRuleFactory { true, AggPhase.GLOBAL, localAggregate - ); + ), hasDistinct); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjects.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjects.java index 7738ee2772..626d6d22cd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjects.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjects.java @@ -56,6 +56,14 @@ public class MergeConsecutiveProjects extends OneRewriteRuleFactory { private static class ExpressionReplacer extends DefaultExpressionRewriter> { public static final ExpressionReplacer INSTANCE = new ExpressionReplacer(); + public Expression replace(Expression expr, Map substitutionMap) { + if (expr instanceof SlotReference) { + Slot ref = ((SlotReference) expr).withQualifier(Collections.emptyList()); + return substitutionMap.getOrDefault(ref, expr); + } + return visit(expr, substitutionMap); + } + /** * case 1: * project(alias(c) as d, alias(x) as y) @@ -84,7 +92,7 @@ public class MergeConsecutiveProjects extends OneRewriteRuleFactory { Slot ref = ((SlotReference) expr).withQualifier(Collections.emptyList()); if (substitutionMap.containsKey(ref)) { Alias res = (Alias) substitutionMap.get(ref); - return (res.child() instanceof SlotReference) ? res : res.child(); + return res.child(); } } else if (substitutionMap.containsKey(expr)) { return substitutionMap.get(expr).child(0); @@ -106,7 +114,7 @@ public class MergeConsecutiveProjects extends OneRewriteRuleFactory { ); projectExpressions = projectExpressions.stream() - .map(e -> MergeConsecutiveProjects.ExpressionReplacer.INSTANCE.visit(e, childAliasMap)) + .map(e -> MergeConsecutiveProjects.ExpressionReplacer.INSTANCE.replace(e, childAliasMap)) .map(NamedExpression.class::cast) .collect(Collectors.toList()); return new LogicalProject<>(projectExpressions, childProject.children().get(0)); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java index 5a6243809e..a888e06ce4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; @@ -61,7 +62,7 @@ import java.util.stream.Collectors; public class NormalizeAggregate extends OneRewriteRuleFactory { @Override public Rule build() { - return logicalAggregate().when(aggregate -> !aggregate.isNormalized()).then(aggregate -> { + return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> { // substitution map used to substitute expression in aggregate's output to use it as top projections Map substitutionMap = Maps.newHashMap(); List keys = aggregate.getGroupByExpressions(); @@ -102,25 +103,40 @@ public class NormalizeAggregate extends OneRewriteRuleFactory { List outputs = aggregate.getOutputExpressions(); Map> partitionedOutputs = outputs.stream() .collect(Collectors.groupingBy(e -> e.anyMatch(AggregateFunction.class::isInstance))); + + boolean needBottomProjects = partitionedKeys.containsKey(false); if (partitionedOutputs.containsKey(true)) { // process expressions that contain aggregate function Set aggregateFunctions = partitionedOutputs.get(true).stream() .flatMap(e -> e.>collect(AggregateFunction.class::isInstance).stream()) .collect(Collectors.toSet()); - newOutputs.addAll(aggregateFunctions.stream() - .map(f -> new Alias(f, f.toSql())) - .peek(a -> substitutionMap.put(a.child(), a.toSlot())) - .collect(Collectors.toList())); - // add slot references in aggregate function to bottom projections - bottomProjections.addAll(aggregateFunctions.stream() - .flatMap(f -> f.>collect(SlotReference.class::isInstance).stream()) - .map(SlotReference.class::cast) - .collect(Collectors.toSet())); + + // replace all non-slot expression in aggregate functions children. + for (AggregateFunction aggregateFunction : aggregateFunctions) { + List newChildren = Lists.newArrayList(); + for (Expression child : aggregateFunction.getArguments()) { + if (child instanceof SlotReference || child instanceof Literal) { + newChildren.add(child); + if (child instanceof SlotReference) { + bottomProjections.add((SlotReference) child); + } + } else { + needBottomProjects = true; + Alias alias = new Alias(child, child.toSql()); + bottomProjections.add(alias); + newChildren.add(alias.toSlot()); + } + } + AggregateFunction newFunction = (AggregateFunction) aggregateFunction.withChildren(newChildren); + Alias alias = new Alias(newFunction, newFunction.toSql()); + newOutputs.add(alias); + substitutionMap.put(aggregateFunction, alias.toSlot()); + } } // assemble LogicalPlan root = aggregate.child(); - if (partitionedKeys.containsKey(false)) { + if (needBottomProjects) { root = new LogicalProject<>(bottomProjections, root); } root = new LogicalAggregate<>(newKeys, newOutputs, aggregate.isDisassembled(), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java index 34e807a1c5..13ccd1ca95 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.analysis.ArithmeticExpr.Operator; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.coercion.AbstractDataType; import org.apache.doris.nereids.types.coercion.NumericType; @@ -41,6 +42,11 @@ public class Add extends BinaryArithmetic { return new Add(children.get(0), children.get(1)); } + @Override + public DataType getDataType() { + return left().getDataType().promotion(); + } + @Override public R accept(ExpressionVisitor visitor, C context) { return visitor.visitAdd(this, context); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExecutableFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExecutableFunctions.java index 4c65d19883..2fab50e38e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExecutableFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ExecutableFunctions.java @@ -44,24 +44,24 @@ public class ExecutableFunctions { * Executable arithmetic functions */ - @ExecFunction(name = "add", argTypes = {"TINYINT", "TINYINT"}, returnType = "TINYINT") - public static TinyIntLiteral addTinyint(TinyIntLiteral first, TinyIntLiteral second) { - byte result = (byte) Math.addExact(first.getValue(), second.getValue()); - return new TinyIntLiteral(result); - } - - @ExecFunction(name = "add", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "SMALLINT") - public static SmallIntLiteral addSmallint(SmallIntLiteral first, SmallIntLiteral second) { + @ExecFunction(name = "add", argTypes = {"TINYINT", "TINYINT"}, returnType = "SMALLINT") + public static SmallIntLiteral addTinyint(TinyIntLiteral first, TinyIntLiteral second) { short result = (short) Math.addExact(first.getValue(), second.getValue()); return new SmallIntLiteral(result); } - @ExecFunction(name = "add", argTypes = {"INT", "INT"}, returnType = "INT") - public static IntegerLiteral addInt(IntegerLiteral first, IntegerLiteral second) { + @ExecFunction(name = "add", argTypes = {"SMALLINT", "SMALLINT"}, returnType = "INT") + public static IntegerLiteral addSmallint(SmallIntLiteral first, SmallIntLiteral second) { int result = Math.addExact(first.getValue(), second.getValue()); return new IntegerLiteral(result); } + @ExecFunction(name = "add", argTypes = {"INT", "INT"}, returnType = "BIGINT") + public static BigIntLiteral addInt(IntegerLiteral first, IntegerLiteral second) { + long result = Math.addExact(first.getValue(), second.getValue()); + return new BigIntLiteral(result); + } + @ExecFunction(name = "add", argTypes = {"BIGINT", "BIGINT"}, returnType = "BIGINT") public static BigIntLiteral addBigint(BigIntLiteral first, BigIntLiteral second) { long result = Math.addExact(first.getValue(), second.getValue()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java index 0cca04950d..06df06b92b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java @@ -60,9 +60,10 @@ public class LogicalAggregate extends LogicalUnary extends LogicalUnary groupByExpressions, List outputExpressions, CHILD_TYPE child) { - this(groupByExpressions, outputExpressions, false, false, true, AggPhase.GLOBAL, child); + this(groupByExpressions, outputExpressions, false, false, true, AggPhase.LOCAL, child); } public LogicalAggregate( diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java index 05d256f395..ca96dcceb3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java @@ -34,6 +34,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.TinyIntType; @@ -256,7 +257,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat NamedExpressionUtil.clear(); sql = "SELECT a1, SUM(a1 + a2) FROM t1 GROUP BY a1 HAVING SUM(a1 + a2 + 3) > 0"; - Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new TinyIntLiteral((byte) 3))), + Alias sumA1A23 = new Alias(new ExprId(4), new Sum(new Add(new Add(a1, a2), new SmallIntLiteral((byte) 3))), "sum(((a1 + a2) + 3))"); PlanChecker.from(connectContext).analyze(sql) .matchesFromRoot( @@ -360,7 +361,7 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat ImmutableList.of("default_cluster:test_having", "t1") ); Alias pk1 = new Alias(new ExprId(7), new Add(pk, Literal.of((byte) 1)), "(pk + 1)"); - Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)"); + Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((short) 1)), "((pk + 1) + 1)"); Alias pk2 = new Alias(new ExprId(9), new Add(pk, Literal.of((byte) 2)), "(pk + 2)"); Alias sumA1 = new Alias(new ExprId(10), new Sum(a1), "SUM(a1)"); Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1, false), Literal.of(1L)), "(COUNT(a1) + 1)"); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java index c179823a92..c133bf8ed3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java @@ -150,7 +150,7 @@ public class RequestPropertyDeriverTest { Lists.newArrayList(key), AggPhase.LOCAL, true, - true, + false, logicalProperties, groupPlan ); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/FoldConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/FoldConstantTest.java index ec21777e20..4769edcf2c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/FoldConstantTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/FoldConstantTest.java @@ -171,9 +171,9 @@ public class FoldConstantTest { @Test public void testArithmeticFold() { executor = new ExpressionRuleExecutor(ImmutableList.of(TypeCoercion.INSTANCE, FoldConstantRuleOnFE.INSTANCE)); - assertRewrite("1 + 1", Literal.of((byte) 2)); + assertRewrite("1 + 1", Literal.of((short) 2)); assertRewrite("1 - 1", Literal.of((byte) 0)); - assertRewrite("100 + 100", Literal.of((byte) 200)); + assertRewrite("100 + 100", Literal.of((short) 200)); assertRewrite("1 - 2", Literal.of((byte) -1)); assertRewrite("1 - 2 > 1", "false"); @@ -284,9 +284,9 @@ public class FoldConstantTest { private void assertRewrite(String expression, String expected) { Map mem = Maps.newHashMap(); Expression needRewriteExpression = PARSER.parseExpression(expression); - needRewriteExpression = replaceUnboundSlot(needRewriteExpression, mem); + needRewriteExpression = typeCoercion(replaceUnboundSlot(needRewriteExpression, mem)); Expression expectedExpression = PARSER.parseExpression(expected); - expectedExpression = replaceUnboundSlot(expectedExpression, mem); + expectedExpression = typeCoercion(replaceUnboundSlot(expectedExpression, mem)); Expression rewrittenExpression = executor.rewrite(needRewriteExpression); Assertions.assertEquals(expectedExpression, rewrittenExpression); } @@ -320,6 +320,10 @@ public class FoldConstantTest { return hasNewChildren ? expression.withChildren(children) : expression; } + private Expression typeCoercion(Expression expression) { + return TypeCoercion.INSTANCE.visit(expression, null); + } + private DataType getType(char t) { switch (t) { case 'T': diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjectsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjectsTest.java index 0816e7401b..08c9af2946 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjectsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeConsecutiveProjectsTest.java @@ -89,7 +89,8 @@ public class MergeConsecutiveProjectsTest { relation); LogicalProject project2 = new LogicalProject<>( Lists.newArrayList( - new Alias(new Add(aliasRef, new IntegerLiteral(2)), "Y") + new Alias(new Add(aliasRef, new IntegerLiteral(2)), "Y"), + aliasRef ), project1); @@ -100,11 +101,12 @@ public class MergeConsecutiveProjectsTest { System.out.println(plan.treeString()); Assertions.assertTrue(plan instanceof LogicalProject); LogicalProject finalProject = (LogicalProject) plan; - Add finalExpression = new Add( + Add aPlus1Plus2 = new Add( new Add(colA, new IntegerLiteral(1)), new IntegerLiteral(2) ); - Assertions.assertEquals(1, finalProject.getProjects().size()); - Assertions.assertEquals(((Alias) finalProject.getProjects().get(0)).child(), finalExpression); + Assertions.assertEquals(2, finalProject.getProjects().size()); + Assertions.assertEquals(aPlus1Plus2, ((Alias) finalProject.getProjects().get(0)).child()); + Assertions.assertEquals(alias, finalProject.getProjects().get(1)); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java index 9efa81376c..ef819a66f6 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Multiply; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.plans.Plan; @@ -93,15 +94,17 @@ public class NormalizeAggregateTest implements PatternMatchSupported { * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3]) * * after rewrite: - * LogicalProject ((sum((id * 1))#5 + 2) AS `(sum((id * 1)) + 2)`#4) - * +--LogicalAggregate (phase: [GLOBAL], output: [sum((id#0 * 1)) AS `sum((id * 1))`#5], groupBy: [name#2]) - * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, age#3]) + * LogicalProject ( projects=[(sum((id * 1))#6 + 2) AS `(sum((id * 1)) + 2)`#4] ) + * +--LogicalAggregate ( phase=LOCAL, outputExpr=[sum((id * 1)#5) AS `sum((id * 1))`#6], groupByExpr=[name#2] ) + * +--LogicalProject ( projects=[name#2, (id#0 * 1) AS `(id * 1)`#5] ) + * +--GroupPlan( GroupId#0 ) */ @Test public void testComplexFuncWithComplexOutputOfFunc() { NamedExpression key = rStudent.getOutput().get(2).toSlot(); List groupExpressionList = Lists.newArrayList(key); - Expression aggregateFunction = new Sum(new Multiply(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(1))); + Expression multiply = new Multiply(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(1)); + Expression aggregateFunction = new Sum(multiply); Expression complexOutput = new Add(aggregateFunction, new IntegerLiteral(2)); Alias output = new Alias(complexOutput, complexOutput.toSql()); List outputExpressionList = Lists.newArrayList(output); @@ -112,10 +115,14 @@ public class NormalizeAggregateTest implements PatternMatchSupported { .matchesFromRoot( logicalProject( logicalAggregate( - logicalOlapScan() + logicalProject( + logicalOlapScan() + ).when(project -> project.getProjects().size() == 2) + .when(project -> project.getProjects().get(0) instanceof SlotReference) + .when(project -> project.getProjects().get(1).child(0).equals(multiply)) ).when(FieldChecker.check("groupByExpressions", ImmutableList.of(key))) .when(aggregate -> aggregate.getOutputExpressions().size() == 1) - .when(aggregate -> aggregate.getOutputExpressions().get(0).child(0).equals(aggregateFunction)) + .when(aggregate -> aggregate.getOutputExpressions().get(0).child(0) instanceof AggregateFunction) ).when(project -> project.getProjects().size() == 1) .when(project -> project.getProjects().get(0) instanceof Alias) .when(project -> project.getProjects().get(0).getExprId().equals(output.getExprId())) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java index 508c6f45da..803523eb0c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java @@ -57,7 +57,7 @@ public class PlanToStringTest { new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), child); Assertions.assertTrue(plan.toString() - .matches("LogicalAggregate \\( phase=GLOBAL, outputExpr=\\[a#\\d+], groupByExpr=\\[] \\)")); + .matches("LogicalAggregate \\( phase=LOCAL, outputExpr=\\[a#\\d+], groupByExpr=\\[] \\)")); } @Test diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java index ce3c7392bf..60bba6ba38 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleSet; +import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble; import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnAgg; import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnProjectUnderAgg; import org.apache.doris.nereids.rules.rewrite.logical.ExistsApplyToJoin; @@ -120,7 +121,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte new MockUp() { @Mock public List getExplorationRules() { - return Lists.newArrayList(); + return Lists.newArrayList(new AggregateDisassemble().build()); } };