From ce47354d598b099fb0ebe2784b38978c13a4c808 Mon Sep 17 00:00:00 2001 From: morrySnow <101034200+morrySnow@users.noreply.github.com> Date: Mon, 22 Jan 2024 17:53:25 +0800 Subject: [PATCH] [fix](Nereids) result nullable of sum distinct in scalar agg is wrong (#30221) --- .../AdjustAggregateNullableForEmptySet.java | 3 +- .../implementation/AggregateStrategies.java | 44 +++++++------------ .../trees/expressions/functions/agg/Sum.java | 6 +++ .../nereids_syntax_p0/agg_with_empty_set.out | 3 ++ .../agg_with_empty_set.groovy | 1 + 5 files changed, 26 insertions(+), 31 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java index 75400a8e6b..86a70d35cc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java @@ -90,8 +90,7 @@ public class AdjustAggregateNullableForEmptySet implements RewriteRuleFactory { @Override public Expression visitNullableAggregateFunction(NullableAggregateFunction nullableAggregateFunction, Boolean alwaysNullable) { - return nullableAggregateFunction.isDistinct() ? nullableAggregateFunction - : nullableAggregateFunction.withAlwaysNullable(alwaysNullable); + return nullableAggregateFunction.withAlwaysNullable(alwaysNullable); } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index a0eb011ba9..c9907ae7c3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -50,7 +50,6 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount; -import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctSum; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.functions.scalar.If; import org.apache.doris.nereids.trees.expressions.literal.Literal; @@ -108,9 +107,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { logicalAggregate( logicalFilter( logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable) - ).when(filter -> filter.getConjuncts().size() > 0)) + ).when(filter -> !filter.getConjuncts().isEmpty())) .when(agg -> enablePushDownCountOnIndex()) - .when(agg -> agg.getGroupByExpressions().size() == 0) + .when(agg -> agg.getGroupByExpressions().isEmpty()) .when(agg -> { Set funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && funcs.stream() @@ -128,9 +127,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { logicalProject( logicalFilter( logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable) - ).when(filter -> filter.getConjuncts().size() > 0))) + ).when(filter -> !filter.getConjuncts().isEmpty()))) .when(agg -> enablePushDownCountOnIndex()) - .when(agg -> agg.getGroupByExpressions().size() == 0) + .when(agg -> agg.getGroupByExpressions().isEmpty()) .when(agg -> { Set funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && funcs.stream().allMatch(f -> f instanceof Count && !f.isDistinct()); @@ -154,7 +153,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { Expression childExpr = filter.getConjuncts().iterator().next().children().get(0); if (childExpr instanceof SlotReference) { Optional column = ((SlotReference) childExpr).getColumn(); - return column.isPresent() ? column.get().isDeleteSignColumn() : false; + return column.map(Column::isDeleteSignColumn).orElse(false); } return false; }) @@ -187,8 +186,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { .children().get(0); if (childExpr instanceof SlotReference) { Optional column = ((SlotReference) childExpr).getColumn(); - return column.isPresent() ? column.get().isDeleteSignColumn() - : false; + return column.map(Column::isDeleteSignColumn).orElse(false); } return false; })) @@ -253,12 +251,12 @@ public class AggregateStrategies implements ImplementationRuleFactory { ), RuleType.ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT.build( basePattern - .when(agg -> agg.getDistinctArguments().size() == 0) + .when(agg -> agg.getDistinctArguments().isEmpty()) .thenApplyMulti(ctx -> onePhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext)) ), RuleType.TWO_PHASE_AGGREGATE_WITHOUT_DISTINCT.build( basePattern - .when(agg -> agg.getDistinctArguments().size() == 0) + .when(agg -> agg.getDistinctArguments().isEmpty()) .thenApplyMulti(ctx -> twoPhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext)) ), // RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build( @@ -435,12 +433,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { boolean onlyContainsSlotOrNumericCastSlot = aggregateFunctions.stream() .map(ExpressionTrait::getArguments) .flatMap(List::stream) - .allMatch(argument -> { - if (argument instanceof SlotReference) { - return true; - } - return false; - }); + .allMatch(argument -> argument instanceof SlotReference); if (!onlyContainsSlotOrNumericCastSlot) { return false; } @@ -457,19 +450,13 @@ public class AggregateStrategies implements ImplementationRuleFactory { } onlyContainsSlotOrNumericCastSlot = argumentsOfAggregateFunction .stream() - .allMatch(argument -> { - if (argument instanceof SlotReference) { - return true; - } - return false; - }); + .allMatch(argument -> argument instanceof SlotReference); if (!onlyContainsSlotOrNumericCastSlot) { return false; } Set aggUsedSlots = ExpressionUtils.collect(argumentsOfAggregateFunction, SlotReference.class::isInstance); - List usedSlotInTable = (List) Project.findProject(aggUsedSlots, - outPutSlots); + List usedSlotInTable = (List) Project.findProject(aggUsedSlots, outPutSlots); for (SlotReference slot : usedSlotInTable) { Column column = slot.getColumn().get(); PrimitiveType colType = column.getType().getPrimitiveType(); @@ -630,7 +617,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { if (logicalScan instanceof LogicalOlapScan) { PhysicalOlapScan physicalScan = (PhysicalOlapScan) new LogicalOlapScanToPhysicalOlapScan() .build() - .transform((LogicalOlapScan) logicalScan, cascadesContext) + .transform(logicalScan, cascadesContext) .get(0); if (project != null) { @@ -647,7 +634,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { } else if (logicalScan instanceof LogicalFileScan) { PhysicalFileScan physicalScan = (PhysicalFileScan) new LogicalFileScanToPhysicalFileScan() .build() - .transform((LogicalFileScan) logicalScan, cascadesContext) + .transform(logicalScan, cascadesContext) .get(0); if (project != null) { return aggregate.withChildren(ImmutableList.of( @@ -1193,8 +1180,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { return new AggregateExpression(nonDistinct, AggregateParam.LOCAL_RESULT); } else { Alias alias = nonDistinctAggFunctionToAliasPhase1.get(outputChild); - return new AggregateExpression( - aggregateFunction, bufferToResultParam, alias.toSlot()); + return new AggregateExpression(aggregateFunction, bufferToResultParam, alias.toSlot()); } } else { return outputChild; @@ -1582,7 +1568,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { return new MultiDistinctCount(function.getArgument(0), function.getArguments().subList(1, function.arity()).toArray(new Expression[0])); } else if (function instanceof Sum && function.isDistinct()) { - return new MultiDistinctSum(function.getArgument(0)); + return ((Sum) function).convertToMultiDistinct(); } else if (function instanceof GroupConcat && function.isDistinct()) { return ((GroupConcat) function).convertToMultiDistinct(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java index f0dbd83958..0b00536d6a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java @@ -78,6 +78,12 @@ public class Sum extends NullableAggregateFunction super("sum", distinct, alwaysNullable, arg); } + public MultiDistinctSum convertToMultiDistinct() { + Preconditions.checkArgument(distinct, + "can't convert to multi_distinct_sum because there is no distinct args"); + return new MultiDistinctSum(false, alwaysNullable, child()); + } + @Override public void checkLegalityBeforeTypeCoercion() { DataType argType = child().getDataType(); diff --git a/regression-test/data/nereids_syntax_p0/agg_with_empty_set.out b/regression-test/data/nereids_syntax_p0/agg_with_empty_set.out index 1db851c70c..ffe8f93eb5 100644 --- a/regression-test/data/nereids_syntax_p0/agg_with_empty_set.out +++ b/regression-test/data/nereids_syntax_p0/agg_with_empty_set.out @@ -18,3 +18,6 @@ -- !select6 -- 0 \N \N \N \N +-- !ditinct_sum -- +\N + diff --git a/regression-test/suites/nereids_syntax_p0/agg_with_empty_set.groovy b/regression-test/suites/nereids_syntax_p0/agg_with_empty_set.groovy index 5fc117445a..bdcb4526a2 100644 --- a/regression-test/suites/nereids_syntax_p0/agg_with_empty_set.groovy +++ b/regression-test/suites/nereids_syntax_p0/agg_with_empty_set.groovy @@ -29,4 +29,5 @@ suite("agg_with_empty_set") { (select min(c_custkey) from customer)""" qt_select6 """select count(c_custkey), max(c_custkey), min(c_custkey), avg(c_custkey), sum(c_custkey) from customer where c_custkey < (select min(c_custkey) from customer) having min(c_custkey) is null""" + qt_ditinct_sum """select sum(distinct ifnull(c_custkey, 0)) from customer where 1 = 0""" } \ No newline at end of file