[fix](Nereids) result nullable of sum distinct in scalar agg is wrong (#30221)
This commit is contained in:
@ -90,8 +90,7 @@ public class AdjustAggregateNullableForEmptySet implements RewriteRuleFactory {
|
||||
@Override
|
||||
public Expression visitNullableAggregateFunction(NullableAggregateFunction nullableAggregateFunction,
|
||||
Boolean alwaysNullable) {
|
||||
return nullableAggregateFunction.isDistinct() ? nullableAggregateFunction
|
||||
: nullableAggregateFunction.withAlwaysNullable(alwaysNullable);
|
||||
return nullableAggregateFunction.withAlwaysNullable(alwaysNullable);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -50,7 +50,6 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctSum;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
@ -108,9 +107,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
logicalAggregate(
|
||||
logicalFilter(
|
||||
logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable)
|
||||
).when(filter -> filter.getConjuncts().size() > 0))
|
||||
).when(filter -> !filter.getConjuncts().isEmpty()))
|
||||
.when(agg -> enablePushDownCountOnIndex())
|
||||
.when(agg -> agg.getGroupByExpressions().size() == 0)
|
||||
.when(agg -> agg.getGroupByExpressions().isEmpty())
|
||||
.when(agg -> {
|
||||
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
|
||||
return !funcs.isEmpty() && funcs.stream()
|
||||
@ -128,9 +127,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
logicalProject(
|
||||
logicalFilter(
|
||||
logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable)
|
||||
).when(filter -> filter.getConjuncts().size() > 0)))
|
||||
).when(filter -> !filter.getConjuncts().isEmpty())))
|
||||
.when(agg -> enablePushDownCountOnIndex())
|
||||
.when(agg -> agg.getGroupByExpressions().size() == 0)
|
||||
.when(agg -> agg.getGroupByExpressions().isEmpty())
|
||||
.when(agg -> {
|
||||
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
|
||||
return !funcs.isEmpty() && funcs.stream().allMatch(f -> f instanceof Count && !f.isDistinct());
|
||||
@ -154,7 +153,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
Expression childExpr = filter.getConjuncts().iterator().next().children().get(0);
|
||||
if (childExpr instanceof SlotReference) {
|
||||
Optional<Column> column = ((SlotReference) childExpr).getColumn();
|
||||
return column.isPresent() ? column.get().isDeleteSignColumn() : false;
|
||||
return column.map(Column::isDeleteSignColumn).orElse(false);
|
||||
}
|
||||
return false;
|
||||
})
|
||||
@ -187,8 +186,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
.children().get(0);
|
||||
if (childExpr instanceof SlotReference) {
|
||||
Optional<Column> column = ((SlotReference) childExpr).getColumn();
|
||||
return column.isPresent() ? column.get().isDeleteSignColumn()
|
||||
: false;
|
||||
return column.map(Column::isDeleteSignColumn).orElse(false);
|
||||
}
|
||||
return false;
|
||||
}))
|
||||
@ -253,12 +251,12 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
),
|
||||
RuleType.ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT.build(
|
||||
basePattern
|
||||
.when(agg -> agg.getDistinctArguments().size() == 0)
|
||||
.when(agg -> agg.getDistinctArguments().isEmpty())
|
||||
.thenApplyMulti(ctx -> onePhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
|
||||
),
|
||||
RuleType.TWO_PHASE_AGGREGATE_WITHOUT_DISTINCT.build(
|
||||
basePattern
|
||||
.when(agg -> agg.getDistinctArguments().size() == 0)
|
||||
.when(agg -> agg.getDistinctArguments().isEmpty())
|
||||
.thenApplyMulti(ctx -> twoPhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
|
||||
),
|
||||
// RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
|
||||
@ -435,12 +433,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
boolean onlyContainsSlotOrNumericCastSlot = aggregateFunctions.stream()
|
||||
.map(ExpressionTrait::getArguments)
|
||||
.flatMap(List::stream)
|
||||
.allMatch(argument -> {
|
||||
if (argument instanceof SlotReference) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
.allMatch(argument -> argument instanceof SlotReference);
|
||||
if (!onlyContainsSlotOrNumericCastSlot) {
|
||||
return false;
|
||||
}
|
||||
@ -457,19 +450,13 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
}
|
||||
onlyContainsSlotOrNumericCastSlot = argumentsOfAggregateFunction
|
||||
.stream()
|
||||
.allMatch(argument -> {
|
||||
if (argument instanceof SlotReference) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
.allMatch(argument -> argument instanceof SlotReference);
|
||||
if (!onlyContainsSlotOrNumericCastSlot) {
|
||||
return false;
|
||||
}
|
||||
Set<SlotReference> aggUsedSlots = ExpressionUtils.collect(argumentsOfAggregateFunction,
|
||||
SlotReference.class::isInstance);
|
||||
List<SlotReference> usedSlotInTable = (List<SlotReference>) Project.findProject(aggUsedSlots,
|
||||
outPutSlots);
|
||||
List<SlotReference> usedSlotInTable = (List<SlotReference>) Project.findProject(aggUsedSlots, outPutSlots);
|
||||
for (SlotReference slot : usedSlotInTable) {
|
||||
Column column = slot.getColumn().get();
|
||||
PrimitiveType colType = column.getType().getPrimitiveType();
|
||||
@ -630,7 +617,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
if (logicalScan instanceof LogicalOlapScan) {
|
||||
PhysicalOlapScan physicalScan = (PhysicalOlapScan) new LogicalOlapScanToPhysicalOlapScan()
|
||||
.build()
|
||||
.transform((LogicalOlapScan) logicalScan, cascadesContext)
|
||||
.transform(logicalScan, cascadesContext)
|
||||
.get(0);
|
||||
|
||||
if (project != null) {
|
||||
@ -647,7 +634,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
} else if (logicalScan instanceof LogicalFileScan) {
|
||||
PhysicalFileScan physicalScan = (PhysicalFileScan) new LogicalFileScanToPhysicalFileScan()
|
||||
.build()
|
||||
.transform((LogicalFileScan) logicalScan, cascadesContext)
|
||||
.transform(logicalScan, cascadesContext)
|
||||
.get(0);
|
||||
if (project != null) {
|
||||
return aggregate.withChildren(ImmutableList.of(
|
||||
@ -1193,8 +1180,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
return new AggregateExpression(nonDistinct, AggregateParam.LOCAL_RESULT);
|
||||
} else {
|
||||
Alias alias = nonDistinctAggFunctionToAliasPhase1.get(outputChild);
|
||||
return new AggregateExpression(
|
||||
aggregateFunction, bufferToResultParam, alias.toSlot());
|
||||
return new AggregateExpression(aggregateFunction, bufferToResultParam, alias.toSlot());
|
||||
}
|
||||
} else {
|
||||
return outputChild;
|
||||
@ -1582,7 +1568,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
|
||||
return new MultiDistinctCount(function.getArgument(0),
|
||||
function.getArguments().subList(1, function.arity()).toArray(new Expression[0]));
|
||||
} else if (function instanceof Sum && function.isDistinct()) {
|
||||
return new MultiDistinctSum(function.getArgument(0));
|
||||
return ((Sum) function).convertToMultiDistinct();
|
||||
} else if (function instanceof GroupConcat && function.isDistinct()) {
|
||||
return ((GroupConcat) function).convertToMultiDistinct();
|
||||
}
|
||||
|
||||
@ -78,6 +78,12 @@ public class Sum extends NullableAggregateFunction
|
||||
super("sum", distinct, alwaysNullable, arg);
|
||||
}
|
||||
|
||||
public MultiDistinctSum convertToMultiDistinct() {
|
||||
Preconditions.checkArgument(distinct,
|
||||
"can't convert to multi_distinct_sum because there is no distinct args");
|
||||
return new MultiDistinctSum(false, alwaysNullable, child());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkLegalityBeforeTypeCoercion() {
|
||||
DataType argType = child().getDataType();
|
||||
|
||||
Reference in New Issue
Block a user