[feature](nereids) deal the slots that appear both in agg func and grouping sets (#31318)
this PR support slot appearing both in agg func and grouping sets. sql like below: select sum(a) from t group by grouping sets ((a)); Before this PR, Nereids throw exception like below: col_int_undef_signed cannot both in select list and aggregate functions when using GROUPING SETS/CUBE/ROLLUP, please use union instead. This PR removes the restriction and supports this situation.
This commit is contained in:
@ -23,7 +23,6 @@ import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
|
||||
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotTriplet;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.OrderExpression;
|
||||
@ -44,8 +43,10 @@ import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Maps;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.google.common.collect.Sets.SetView;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
@ -80,35 +81,16 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
|
||||
logicalRepeat(any()).when(LogicalRepeat::canBindVirtualSlot).then(repeat -> {
|
||||
checkRepeatLegality(repeat);
|
||||
// add virtual slot, LogicalAggregate and LogicalProject for normalize
|
||||
return normalizeRepeat(repeat);
|
||||
LogicalAggregate<Plan> agg = normalizeRepeat(repeat);
|
||||
return dealSlotAppearBothInAggFuncAndGroupingSets(agg);
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
private void checkRepeatLegality(LogicalRepeat<Plan> repeat) {
|
||||
checkIfAggFuncSlotInGroupingSets(repeat);
|
||||
checkGroupingSetsSize(repeat);
|
||||
}
|
||||
|
||||
private void checkIfAggFuncSlotInGroupingSets(LogicalRepeat<Plan> repeat) {
|
||||
Set<Slot> aggUsedSlots = repeat.getOutputExpressions().stream()
|
||||
.flatMap(e -> e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
|
||||
.flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
Set<ExprId> groupingSetsUsedSlotExprIds = repeat.getGroupingSets().stream()
|
||||
.flatMap(Collection::stream)
|
||||
.flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
|
||||
.map(SlotReference::getExprId)
|
||||
.collect(Collectors.toSet());
|
||||
for (Slot slot : aggUsedSlots) {
|
||||
if (groupingSetsUsedSlotExprIds.contains(slot.getExprId())) {
|
||||
throw new AnalysisException("column: " + slot.toSql() + " cannot both in select "
|
||||
+ "list and aggregate functions when using GROUPING SETS/CUBE/ROLLUP, "
|
||||
+ "please use union instead.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void checkGroupingSetsSize(LogicalRepeat<Plan> repeat) {
|
||||
Set<Expression> flattenGroupingSetExpr = ImmutableSet.copyOf(
|
||||
ExpressionUtils.flatExpressions(repeat.getGroupingSets()));
|
||||
@ -265,4 +247,78 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
|
||||
return expr;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* compute slots that appear both in agg func and grouping sets,
|
||||
* copy the slots and output in the project below the repeat as new copied slots,
|
||||
* and refer the new copied slots in aggregate parameters.
|
||||
* eg: original plan after normalizedRepeat
|
||||
* LogicalAggregate (groupByExpr=[a#0, GROUPING_ID#1], outputExpr=[a#0, GROUPING_ID#1, sum(a#0) as `sum(a)`#2])
|
||||
* +--LogicalRepeat (groupingSets=[[a#0]], outputExpr=[a#0, GROUPING_ID#1]
|
||||
* +--LogicalProject (projects =[a#0])
|
||||
* After:
|
||||
* LogicalAggregate (groupByExpr=[a#0, GROUPING_ID#1], outputExpr=[a#0, GROUPING_ID#1, sum(a#3) as `sum(a)`#2])
|
||||
* +--LogicalRepeat (groupingSets=[[a#0]], outputExpr=[a#0, a#3, GROUPING_ID#1]
|
||||
* +--LogicalProject (projects =[a#0, a#0 as `a`#3])
|
||||
*/
|
||||
private LogicalAggregate<Plan> dealSlotAppearBothInAggFuncAndGroupingSets(
|
||||
@NotNull LogicalAggregate<Plan> aggregate) {
|
||||
LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
|
||||
Set<Slot> aggUsedSlots = aggregate.getOutputExpressions().stream()
|
||||
.flatMap(e -> e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
|
||||
.flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
Set<Slot> groupingSetsUsedSlot = repeat.getGroupingSets().stream()
|
||||
.flatMap(Collection::stream)
|
||||
.flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
Set<Slot> resSet = new HashSet<>(aggUsedSlots);
|
||||
resSet.retainAll(groupingSetsUsedSlot);
|
||||
if (resSet.isEmpty()) {
|
||||
return aggregate;
|
||||
}
|
||||
Map<Slot, Alias> slotMapping = resSet.stream().collect(
|
||||
Collectors.toMap(key -> key, Alias::new)
|
||||
);
|
||||
Set<Alias> newAliases = new HashSet<>(slotMapping.values());
|
||||
List<Slot> newSlots = newAliases.stream()
|
||||
.map(Alias::toSlot)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// modify repeat child to a new project with more projections
|
||||
List<Slot> originSlots = repeat.child().getOutput();
|
||||
ImmutableList<NamedExpression> immList =
|
||||
ImmutableList.<NamedExpression>builder().addAll(originSlots).addAll(newAliases).build();
|
||||
LogicalProject<Plan> newProject = new LogicalProject<>(immList, repeat.child());
|
||||
repeat = repeat.withChildren(ImmutableList.of(newProject));
|
||||
|
||||
// modify repeat outputs
|
||||
List<Slot> originRepeatSlots = repeat.getOutput();
|
||||
repeat = repeat.withAggOutput(ImmutableList
|
||||
.<NamedExpression>builder()
|
||||
.addAll(originRepeatSlots.stream().filter(slot -> ! (slot instanceof VirtualSlotReference))
|
||||
.collect(Collectors.toList()))
|
||||
.addAll(newSlots)
|
||||
.addAll(originRepeatSlots.stream().filter(slot -> (slot instanceof VirtualSlotReference))
|
||||
.collect(Collectors.toList()))
|
||||
.build());
|
||||
aggregate = aggregate.withChildren(ImmutableList.of(repeat));
|
||||
|
||||
// modify aggregate functions' parameter slot reference to new copied slots
|
||||
List<NamedExpression> newOutputExpressions = aggregate.getOutputExpressions().stream()
|
||||
.map(output -> (NamedExpression) output.rewriteDownShortCircuit(expr -> {
|
||||
if (expr instanceof AggregateFunction) {
|
||||
return expr.rewriteDownShortCircuit(e -> {
|
||||
if (e instanceof Slot && slotMapping.containsKey(e)) {
|
||||
return slotMapping.get(e).toSlot();
|
||||
}
|
||||
return e;
|
||||
});
|
||||
}
|
||||
return expr;
|
||||
})
|
||||
).collect(Collectors.toList());
|
||||
return aggregate.withAggOutput(newOutputExpressions);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user