[fix](Nereids) result nullable of sum distinct in scalar agg is wrong (#30221)

This commit is contained in:
morrySnow
2024-01-22 17:53:25 +08:00
committed by yiguolei
parent 60ce22f15e
commit ce47354d59
5 changed files with 26 additions and 31 deletions

View File

@ -90,8 +90,7 @@ public class AdjustAggregateNullableForEmptySet implements RewriteRuleFactory {
/**
 * Rewrites a nullable aggregate function so that it carries the requested
 * always-nullable flag.
 *
 * <p>The flag is applied to every {@link NullableAggregateFunction}, whether or not
 * it is DISTINCT. Returning a distinct function unchanged left its nullability
 * unadjusted.
 *
 * @param nullableAggregateFunction the aggregate function being visited
 * @param alwaysNullable the flag forwarded to {@code withAlwaysNullable}
 * @return the result of {@code nullableAggregateFunction.withAlwaysNullable(alwaysNullable)}
 */
@Override
public Expression visitNullableAggregateFunction(NullableAggregateFunction nullableAggregateFunction,
        Boolean alwaysNullable) {
    // Single exit point: no isDistinct() short-circuit, so distinct and non-distinct
    // aggregates receive the same nullability adjustment.
    return nullableAggregateFunction.withAlwaysNullable(alwaysNullable);
}
}
}

View File

@ -50,7 +50,6 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctSum;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
@ -108,9 +107,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
logicalAggregate(
logicalFilter(
logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable)
).when(filter -> filter.getConjuncts().size() > 0))
).when(filter -> !filter.getConjuncts().isEmpty()))
.when(agg -> enablePushDownCountOnIndex())
.when(agg -> agg.getGroupByExpressions().size() == 0)
.when(agg -> agg.getGroupByExpressions().isEmpty())
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
@ -128,9 +127,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
logicalProject(
logicalFilter(
logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable)
).when(filter -> filter.getConjuncts().size() > 0)))
).when(filter -> !filter.getConjuncts().isEmpty())))
.when(agg -> enablePushDownCountOnIndex())
.when(agg -> agg.getGroupByExpressions().size() == 0)
.when(agg -> agg.getGroupByExpressions().isEmpty())
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream().allMatch(f -> f instanceof Count && !f.isDistinct());
@ -154,7 +153,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
Expression childExpr = filter.getConjuncts().iterator().next().children().get(0);
if (childExpr instanceof SlotReference) {
Optional<Column> column = ((SlotReference) childExpr).getColumn();
return column.isPresent() ? column.get().isDeleteSignColumn() : false;
return column.map(Column::isDeleteSignColumn).orElse(false);
}
return false;
})
@ -187,8 +186,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
.children().get(0);
if (childExpr instanceof SlotReference) {
Optional<Column> column = ((SlotReference) childExpr).getColumn();
return column.isPresent() ? column.get().isDeleteSignColumn()
: false;
return column.map(Column::isDeleteSignColumn).orElse(false);
}
return false;
}))
@ -253,12 +251,12 @@ public class AggregateStrategies implements ImplementationRuleFactory {
),
RuleType.ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 0)
.when(agg -> agg.getDistinctArguments().isEmpty())
.thenApplyMulti(ctx -> onePhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
),
RuleType.TWO_PHASE_AGGREGATE_WITHOUT_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 0)
.when(agg -> agg.getDistinctArguments().isEmpty())
.thenApplyMulti(ctx -> twoPhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
),
// RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
@ -435,12 +433,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
boolean onlyContainsSlotOrNumericCastSlot = aggregateFunctions.stream()
.map(ExpressionTrait::getArguments)
.flatMap(List::stream)
.allMatch(argument -> {
if (argument instanceof SlotReference) {
return true;
}
return false;
});
.allMatch(argument -> argument instanceof SlotReference);
if (!onlyContainsSlotOrNumericCastSlot) {
return false;
}
@ -457,19 +450,13 @@ public class AggregateStrategies implements ImplementationRuleFactory {
}
onlyContainsSlotOrNumericCastSlot = argumentsOfAggregateFunction
.stream()
.allMatch(argument -> {
if (argument instanceof SlotReference) {
return true;
}
return false;
});
.allMatch(argument -> argument instanceof SlotReference);
if (!onlyContainsSlotOrNumericCastSlot) {
return false;
}
Set<SlotReference> aggUsedSlots = ExpressionUtils.collect(argumentsOfAggregateFunction,
SlotReference.class::isInstance);
List<SlotReference> usedSlotInTable = (List<SlotReference>) Project.findProject(aggUsedSlots,
outPutSlots);
List<SlotReference> usedSlotInTable = (List<SlotReference>) Project.findProject(aggUsedSlots, outPutSlots);
for (SlotReference slot : usedSlotInTable) {
Column column = slot.getColumn().get();
PrimitiveType colType = column.getType().getPrimitiveType();
@ -630,7 +617,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
if (logicalScan instanceof LogicalOlapScan) {
PhysicalOlapScan physicalScan = (PhysicalOlapScan) new LogicalOlapScanToPhysicalOlapScan()
.build()
.transform((LogicalOlapScan) logicalScan, cascadesContext)
.transform(logicalScan, cascadesContext)
.get(0);
if (project != null) {
@ -647,7 +634,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
} else if (logicalScan instanceof LogicalFileScan) {
PhysicalFileScan physicalScan = (PhysicalFileScan) new LogicalFileScanToPhysicalFileScan()
.build()
.transform((LogicalFileScan) logicalScan, cascadesContext)
.transform(logicalScan, cascadesContext)
.get(0);
if (project != null) {
return aggregate.withChildren(ImmutableList.of(
@ -1193,8 +1180,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
return new AggregateExpression(nonDistinct, AggregateParam.LOCAL_RESULT);
} else {
Alias alias = nonDistinctAggFunctionToAliasPhase1.get(outputChild);
return new AggregateExpression(
aggregateFunction, bufferToResultParam, alias.toSlot());
return new AggregateExpression(aggregateFunction, bufferToResultParam, alias.toSlot());
}
} else {
return outputChild;
@ -1582,7 +1568,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
return new MultiDistinctCount(function.getArgument(0),
function.getArguments().subList(1, function.arity()).toArray(new Expression[0]));
} else if (function instanceof Sum && function.isDistinct()) {
return new MultiDistinctSum(function.getArgument(0));
return ((Sum) function).convertToMultiDistinct();
} else if (function instanceof GroupConcat && function.isDistinct()) {
return ((GroupConcat) function).convertToMultiDistinct();
}

View File

@ -78,6 +78,12 @@ public class Sum extends NullableAggregateFunction
super("sum", distinct, alwaysNullable, arg);
}
/**
 * Builds the {@code MultiDistinctSum} counterpart of this distinct sum.
 *
 * <p>The new function receives this function's {@code alwaysNullable} flag and its
 * single child expression. The first constructor argument is passed as {@code false}
 * (presumably the distinct flag; confirm against {@code MultiDistinctSum}).
 *
 * @return a {@code MultiDistinctSum} over the same child with the same always-nullable flag
 * @throws IllegalArgumentException if this sum is not distinct
 */
public MultiDistinctSum convertToMultiDistinct() {
    if (!distinct) {
        throw new IllegalArgumentException(
                "can't convert to multi_distinct_sum because there is no distinct args");
    }
    return new MultiDistinctSum(false, alwaysNullable, child());
}
@Override
public void checkLegalityBeforeTypeCoercion() {
DataType argType = child().getDataType();