[Fix](nereids) Only rewrite the slots that appear both in trival-agg func and grouping sets (#31600)
* [Fix](nereids) Only rewrite the slots that appear both in trival-agg func and grouping sets * [Fix](nereids) Only rewrite the slots that appear both in trival-agg func and grouping sets --------- Co-authored-by: feiniaofeiafei <moailing@selectdb.com>
This commit is contained in:
@ -29,8 +29,10 @@ import org.apache.doris.nereids.trees.expressions.OrderExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.WindowExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.Repeat;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
@ -47,6 +49,7 @@ import com.google.common.collect.Sets;
|
||||
import com.google.common.collect.Sets.SetView;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@ -267,8 +270,11 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
|
||||
private LogicalAggregate<Plan> dealSlotAppearBothInAggFuncAndGroupingSets(
|
||||
@NotNull LogicalAggregate<Plan> aggregate) {
|
||||
LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
|
||||
Set<Slot> aggUsedSlots = aggregate.getOutputExpressions().stream()
|
||||
.flatMap(e -> e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
|
||||
|
||||
List<AggregateFunction> aggregateFunctions = Lists.newArrayList();
|
||||
aggregate.getOutputExpressions().forEach(
|
||||
o -> o.accept(PlanUtils.CollectNonWindowedAggFuncs.INSTANCE, aggregateFunctions));
|
||||
Set<Slot> aggUsedSlots = aggregateFunctions.stream()
|
||||
.flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
Set<Slot> groupingSetsUsedSlot = repeat.getGroupingSets().stream()
|
||||
@ -308,20 +314,68 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
|
||||
.build());
|
||||
aggregate = aggregate.withChildren(ImmutableList.of(repeat));
|
||||
|
||||
// modify aggregate functions' parameter slot reference to new copied slots
|
||||
List<NamedExpression> newOutputExpressions = aggregate.getOutputExpressions().stream()
|
||||
.map(output -> (NamedExpression) output.rewriteDownShortCircuit(expr -> {
|
||||
if (expr instanceof AggregateFunction) {
|
||||
return expr.rewriteDownShortCircuit(e -> {
|
||||
if (e instanceof Slot && slotMapping.containsKey(e)) {
|
||||
return slotMapping.get(e).toSlot();
|
||||
}
|
||||
return e;
|
||||
});
|
||||
}
|
||||
return expr;
|
||||
})
|
||||
).collect(Collectors.toList());
|
||||
.map(e -> (NamedExpression) e.accept(RewriteAggFuncWithoutWindowAggFunc.INSTANCE,
|
||||
slotMapping))
|
||||
.collect(Collectors.toList());
|
||||
return aggregate.withAggOutput(newOutputExpressions);
|
||||
}
|
||||
|
||||
/**
|
||||
* This class use the map(slotMapping) to rewrite all slots in trival-agg.
|
||||
* The purpose of this class is to only rewrite the slots in trival-agg and not to rewrite the slots in window-agg.
|
||||
*/
|
||||
private static class RewriteAggFuncWithoutWindowAggFunc
|
||||
extends DefaultExpressionRewriter<Map<Slot, Alias>> {
|
||||
|
||||
private static final RewriteAggFuncWithoutWindowAggFunc
|
||||
INSTANCE = new RewriteAggFuncWithoutWindowAggFunc();
|
||||
|
||||
private RewriteAggFuncWithoutWindowAggFunc() {}
|
||||
|
||||
@Override
|
||||
public Expression visitAggregateFunction(AggregateFunction aggregateFunction, Map<Slot, Alias> slotMapping) {
|
||||
return aggregateFunction.rewriteDownShortCircuit(e -> {
|
||||
if (e instanceof Slot && slotMapping.containsKey(e)) {
|
||||
return slotMapping.get(e).toSlot();
|
||||
}
|
||||
return e;
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitWindow(WindowExpression windowExpression, Map<Slot, Alias> slotMapping) {
|
||||
List<Expression> newChildren = new ArrayList<>();
|
||||
Expression function = windowExpression.getFunction();
|
||||
Expression oldFuncChild = function.child(0);
|
||||
boolean hasNewChildren = false;
|
||||
if (oldFuncChild != null) {
|
||||
Expression newFuncChild;
|
||||
newFuncChild = function.child(0).accept(this, slotMapping);
|
||||
hasNewChildren = (newFuncChild != oldFuncChild);
|
||||
newChildren.add(hasNewChildren
|
||||
? function.withChildren(ImmutableList.of(newFuncChild)) : function);
|
||||
} else {
|
||||
newChildren.add(function);
|
||||
}
|
||||
for (Expression partitionKey : windowExpression.getPartitionKeys()) {
|
||||
Expression newChild = partitionKey.accept(this, slotMapping);
|
||||
if (newChild != partitionKey) {
|
||||
hasNewChildren = true;
|
||||
}
|
||||
newChildren.add(newChild);
|
||||
}
|
||||
for (Expression orderKey : windowExpression.getOrderKeys()) {
|
||||
Expression newChild = orderKey.accept(this, slotMapping);
|
||||
if (newChild != orderKey) {
|
||||
hasNewChildren = true;
|
||||
}
|
||||
newChildren.add(newChild);
|
||||
}
|
||||
if (windowExpression.getWindowFrame().isPresent()) {
|
||||
newChildren.add(windowExpression.getWindowFrame().get());
|
||||
}
|
||||
return hasNewChildren ? windowExpression.withChildren(newChildren) : windowExpression;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user