[Fix](nereids) Only rewrite the slots that appear both in trival-agg func and grouping sets (#31600)

* [Fix](nereids) Only rewrite the slots that appear both in trival-agg func and grouping sets

* [Fix](nereids) Only rewrite the slots that appear both in trival-agg func and grouping sets

---------

Co-authored-by: feiniaofeiafei <moailing@selectdb.com>
This commit is contained in:
feiniaofeiafei
2024-02-29 22:17:43 +08:00
committed by yiguolei
parent ae926d0d8c
commit 2047b9416f
3 changed files with 160 additions and 15 deletions

View File

@ -29,8 +29,10 @@ import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Repeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
@ -47,6 +49,7 @@ import com.google.common.collect.Sets;
import com.google.common.collect.Sets.SetView;
import org.jetbrains.annotations.NotNull;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
@ -267,8 +270,11 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
private LogicalAggregate<Plan> dealSlotAppearBothInAggFuncAndGroupingSets(
@NotNull LogicalAggregate<Plan> aggregate) {
LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
Set<Slot> aggUsedSlots = aggregate.getOutputExpressions().stream()
.flatMap(e -> e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
List<AggregateFunction> aggregateFunctions = Lists.newArrayList();
aggregate.getOutputExpressions().forEach(
o -> o.accept(PlanUtils.CollectNonWindowedAggFuncs.INSTANCE, aggregateFunctions));
Set<Slot> aggUsedSlots = aggregateFunctions.stream()
.flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
.collect(ImmutableSet.toImmutableSet());
Set<Slot> groupingSetsUsedSlot = repeat.getGroupingSets().stream()
@ -308,20 +314,68 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
.build());
aggregate = aggregate.withChildren(ImmutableList.of(repeat));
// modify aggregate functions' parameter slot reference to new copied slots
List<NamedExpression> newOutputExpressions = aggregate.getOutputExpressions().stream()
.map(output -> (NamedExpression) output.rewriteDownShortCircuit(expr -> {
if (expr instanceof AggregateFunction) {
return expr.rewriteDownShortCircuit(e -> {
if (e instanceof Slot && slotMapping.containsKey(e)) {
return slotMapping.get(e).toSlot();
}
return e;
});
}
return expr;
})
).collect(Collectors.toList());
.map(e -> (NamedExpression) e.accept(RewriteAggFuncWithoutWindowAggFunc.INSTANCE,
slotMapping))
.collect(Collectors.toList());
return aggregate.withAggOutput(newOutputExpressions);
}
/**
* This class use the map(slotMapping) to rewrite all slots in trival-agg.
* The purpose of this class is to only rewrite the slots in trival-agg and not to rewrite the slots in window-agg.
*/
private static class RewriteAggFuncWithoutWindowAggFunc
extends DefaultExpressionRewriter<Map<Slot, Alias>> {
private static final RewriteAggFuncWithoutWindowAggFunc
INSTANCE = new RewriteAggFuncWithoutWindowAggFunc();
private RewriteAggFuncWithoutWindowAggFunc() {}
@Override
public Expression visitAggregateFunction(AggregateFunction aggregateFunction, Map<Slot, Alias> slotMapping) {
return aggregateFunction.rewriteDownShortCircuit(e -> {
if (e instanceof Slot && slotMapping.containsKey(e)) {
return slotMapping.get(e).toSlot();
}
return e;
});
}
@Override
public Expression visitWindow(WindowExpression windowExpression, Map<Slot, Alias> slotMapping) {
List<Expression> newChildren = new ArrayList<>();
Expression function = windowExpression.getFunction();
Expression oldFuncChild = function.child(0);
boolean hasNewChildren = false;
if (oldFuncChild != null) {
Expression newFuncChild;
newFuncChild = function.child(0).accept(this, slotMapping);
hasNewChildren = (newFuncChild != oldFuncChild);
newChildren.add(hasNewChildren
? function.withChildren(ImmutableList.of(newFuncChild)) : function);
} else {
newChildren.add(function);
}
for (Expression partitionKey : windowExpression.getPartitionKeys()) {
Expression newChild = partitionKey.accept(this, slotMapping);
if (newChild != partitionKey) {
hasNewChildren = true;
}
newChildren.add(newChild);
}
for (Expression orderKey : windowExpression.getOrderKeys()) {
Expression newChild = orderKey.accept(this, slotMapping);
if (newChild != orderKey) {
hasNewChildren = true;
}
newChildren.add(newChild);
}
if (windowExpression.getWindowFrame().isPresent()) {
newChildren.add(windowExpression.getWindowFrame().get());
}
return hasNewChildren ? windowExpression.withChildren(newChildren) : windowExpression;
}
}
}