[feature](nereids) deal the slots that appear both in agg func and grouping sets (#31318)

this PR support slot appearing both in agg func and grouping sets.
sql like below:
select sum(a) from t group by grouping sets ((a)); 

Before this PR, Nereids throw exception like below:
col_int_undef_signed cannot both in select list and aggregate functions when using GROUPING SETS/CUBE/ROLLUP, please use union instead.

This PR removes the restriction and supports this situation.
This commit is contained in:
feiniaofeiafei
2024-02-26 19:59:51 +08:00
committed by yiguolei
parent dd229b77b1
commit 481d94c3fc
8 changed files with 248 additions and 80 deletions

View File

@ -23,7 +23,6 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotTriplet;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
@ -44,8 +43,10 @@ import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.collect.Sets.SetView;
import org.jetbrains.annotations.NotNull;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@ -80,35 +81,16 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
logicalRepeat(any()).when(LogicalRepeat::canBindVirtualSlot).then(repeat -> {
checkRepeatLegality(repeat);
// add virtual slot, LogicalAggregate and LogicalProject for normalize
return normalizeRepeat(repeat);
LogicalAggregate<Plan> agg = normalizeRepeat(repeat);
return dealSlotAppearBothInAggFuncAndGroupingSets(agg);
})
);
}
private void checkRepeatLegality(LogicalRepeat<Plan> repeat) {
checkIfAggFuncSlotInGroupingSets(repeat);
checkGroupingSetsSize(repeat);
}
private void checkIfAggFuncSlotInGroupingSets(LogicalRepeat<Plan> repeat) {
Set<Slot> aggUsedSlots = repeat.getOutputExpressions().stream()
.flatMap(e -> e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
.flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
.collect(ImmutableSet.toImmutableSet());
Set<ExprId> groupingSetsUsedSlotExprIds = repeat.getGroupingSets().stream()
.flatMap(Collection::stream)
.flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
.map(SlotReference::getExprId)
.collect(Collectors.toSet());
for (Slot slot : aggUsedSlots) {
if (groupingSetsUsedSlotExprIds.contains(slot.getExprId())) {
throw new AnalysisException("column: " + slot.toSql() + " cannot both in select "
+ "list and aggregate functions when using GROUPING SETS/CUBE/ROLLUP, "
+ "please use union instead.");
}
}
}
private void checkGroupingSetsSize(LogicalRepeat<Plan> repeat) {
Set<Expression> flattenGroupingSetExpr = ImmutableSet.copyOf(
ExpressionUtils.flatExpressions(repeat.getGroupingSets()));
@ -265,4 +247,78 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
return expr;
}
}
/*
* compute slots that appear both in agg func and grouping sets,
* copy the slots and output in the project below the repeat as new copied slots,
* and refer the new copied slots in aggregate parameters.
* eg: original plan after normalizedRepeat
* LogicalAggregate (groupByExpr=[a#0, GROUPING_ID#1], outputExpr=[a#0, GROUPING_ID#1, sum(a#0) as `sum(a)`#2])
* +--LogicalRepeat (groupingSets=[[a#0]], outputExpr=[a#0, GROUPING_ID#1]
* +--LogicalProject (projects =[a#0])
* After:
* LogicalAggregate (groupByExpr=[a#0, GROUPING_ID#1], outputExpr=[a#0, GROUPING_ID#1, sum(a#3) as `sum(a)`#2])
* +--LogicalRepeat (groupingSets=[[a#0]], outputExpr=[a#0, a#3, GROUPING_ID#1]
* +--LogicalProject (projects =[a#0, a#0 as `a`#3])
*/
private LogicalAggregate<Plan> dealSlotAppearBothInAggFuncAndGroupingSets(
@NotNull LogicalAggregate<Plan> aggregate) {
LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
Set<Slot> aggUsedSlots = aggregate.getOutputExpressions().stream()
.flatMap(e -> e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
.flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
.collect(ImmutableSet.toImmutableSet());
Set<Slot> groupingSetsUsedSlot = repeat.getGroupingSets().stream()
.flatMap(Collection::stream)
.flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
.collect(Collectors.toSet());
Set<Slot> resSet = new HashSet<>(aggUsedSlots);
resSet.retainAll(groupingSetsUsedSlot);
if (resSet.isEmpty()) {
return aggregate;
}
Map<Slot, Alias> slotMapping = resSet.stream().collect(
Collectors.toMap(key -> key, Alias::new)
);
Set<Alias> newAliases = new HashSet<>(slotMapping.values());
List<Slot> newSlots = newAliases.stream()
.map(Alias::toSlot)
.collect(Collectors.toList());
// modify repeat child to a new project with more projections
List<Slot> originSlots = repeat.child().getOutput();
ImmutableList<NamedExpression> immList =
ImmutableList.<NamedExpression>builder().addAll(originSlots).addAll(newAliases).build();
LogicalProject<Plan> newProject = new LogicalProject<>(immList, repeat.child());
repeat = repeat.withChildren(ImmutableList.of(newProject));
// modify repeat outputs
List<Slot> originRepeatSlots = repeat.getOutput();
repeat = repeat.withAggOutput(ImmutableList
.<NamedExpression>builder()
.addAll(originRepeatSlots.stream().filter(slot -> ! (slot instanceof VirtualSlotReference))
.collect(Collectors.toList()))
.addAll(newSlots)
.addAll(originRepeatSlots.stream().filter(slot -> (slot instanceof VirtualSlotReference))
.collect(Collectors.toList()))
.build());
aggregate = aggregate.withChildren(ImmutableList.of(repeat));
// modify aggregate functions' parameter slot reference to new copied slots
List<NamedExpression> newOutputExpressions = aggregate.getOutputExpressions().stream()
.map(output -> (NamedExpression) output.rewriteDownShortCircuit(expr -> {
if (expr instanceof AggregateFunction) {
return expr.rewriteDownShortCircuit(e -> {
if (e instanceof Slot && slotMapping.containsKey(e)) {
return slotMapping.get(e).toSlot();
}
return e;
});
}
return expr;
})
).collect(Collectors.toList());
return aggregate.withAggOutput(newOutputExpressions);
}
}