[Fix](nereids) fix NormalizeRepeat, change the outputExpression rewrite logic (#34196)
In NormalizeRepeat, three parts of the outputExpression of LogicalRepeat need to be pushed down and outputted by bottom project: flattenGroupingSetExpr, argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction. In the original code, use these three parts to rewrite the outputExpressions of LogicalRepeat to slots.This can cause problems in some cases, for example: ```sql SELECT ROUND( SUM(pk + 1) - 3) col_alias1, pk + 1 AS col_alias3 FROM table_20_undef_partitions2_keys3_properties4_distributed_by53 GROUP BY GROUPING SETS ((pk), ()) ; ``` The three parts expression needed to be pushed down are: pk, pk+1. The original code use pk+1 to rewrite the pk + 1 AS col_alias3 to slot. But the pk+1 is not in the list of grouping outputs, and then report error. This pr change the rewrite process, divide the expression needed to be pushed down into 2 parts: one is (flattenGroupingSetExpr) and the other one is (argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction). and use the flattenGroupingSetExpr rewrite all LogicalRepeat outputExpressions, and use the argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction to rewrite only the agg function arguments and the grouping scalar function. So, in the above sql, the pk + 1 AS col_alias3 will not be rewritten to slot, and can be computed.
This commit is contained in:
@ -43,19 +43,18 @@ import org.apache.doris.nereids.util.PlanUtils.CollectNonWindowedAggFuncs;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.google.common.collect.Sets.SetView;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/** NormalizeRepeat
|
||||
@ -117,19 +116,34 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
|
||||
}
|
||||
|
||||
private LogicalAggregate<Plan> normalizeRepeat(LogicalRepeat<Plan> repeat) {
|
||||
Set<Expression> needToSlots = collectNeedToSlotExpressions(repeat);
|
||||
NormalizeToSlotContext context = buildContext(repeat, needToSlots);
|
||||
Set<Expression> needToSlotsGroupingExpr = collectNeedToSlotGroupingExpr(repeat);
|
||||
NormalizeToSlotContext groupingExprContext = buildContext(repeat, needToSlotsGroupingExpr);
|
||||
Map<Expression, NormalizeToSlotTriplet> groupingExprMap = groupingExprContext.getNormalizeToSlotMap();
|
||||
Set<Alias> existsAlias = getExistsAlias(repeat, groupingExprMap);
|
||||
Set<Expression> needToSlotsArgs = collectNeedToSlotArgsOfGroupingScalarFuncAndAggFunc(repeat);
|
||||
NormalizeToSlotContext argsContext = NormalizeToSlotContext.buildContext(existsAlias, needToSlotsArgs);
|
||||
|
||||
// normalize grouping sets to List<List<Slot>>
|
||||
List<List<Slot>> normalizedGroupingSets = repeat.getGroupingSets()
|
||||
.stream()
|
||||
.map(groupingSet -> (List<Slot>) (List) context.normalizeToUseSlotRef(groupingSet))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
ImmutableList.Builder<List<Slot>> normalizedGroupingSetBuilder = ImmutableList.builder();
|
||||
for (List<Expression> groupingSet : repeat.getGroupingSets()) {
|
||||
List<Slot> normalizedSet = (List<Slot>) (List) groupingExprContext.normalizeToUseSlotRef(groupingSet);
|
||||
normalizedGroupingSetBuilder.add(normalizedSet);
|
||||
}
|
||||
List<List<Slot>> normalizedGroupingSets = normalizedGroupingSetBuilder.build();
|
||||
|
||||
// replace the arguments of grouping scalar function to virtual slots
|
||||
// replace some complex expression to slot, e.g. `a + 1`
|
||||
List<NamedExpression> normalizedAggOutput = context.normalizeToUseSlotRef(
|
||||
repeat.getOutputExpressions(), this::normalizeGroupingScalarFunction);
|
||||
// use argsContext
|
||||
// rewrite the arguments of grouping scalar function to slots
|
||||
// rewrite grouping scalar function to virtual slots
|
||||
// rewrite the arguments of agg function to slots
|
||||
List<NamedExpression> normalizedAggOutput = Lists.newArrayList();
|
||||
for (Expression expr : repeat.getOutputExpressions()) {
|
||||
Expression rewrittenExpr = expr.rewriteDownShortCircuit(
|
||||
e -> normalizeAggFuncChildrenAndGroupingScalarFunc(argsContext, e));
|
||||
normalizedAggOutput.add((NamedExpression) rewrittenExpr);
|
||||
}
|
||||
|
||||
// use groupingExprContext rewrite the normalizedAggOutput
|
||||
normalizedAggOutput = groupingExprContext.normalizeToUseSlotRef(normalizedAggOutput);
|
||||
|
||||
Set<VirtualSlotReference> virtualSlotsInFunction =
|
||||
ExpressionUtils.collect(normalizedAggOutput, VirtualSlotReference.class::isInstance);
|
||||
@ -156,7 +170,12 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
|
||||
.addAll(allVirtualSlots)
|
||||
.build();
|
||||
|
||||
Set<NamedExpression> pushedProject = context.pushDownToNamedExpression(needToSlots);
|
||||
// 3 parts need push down:
|
||||
// flattenGroupingSetExpr, argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction
|
||||
Set<Expression> needToSlots = Sets.union(needToSlotsArgs, needToSlotsGroupingExpr);
|
||||
NormalizeToSlotContext fullContext = argsContext.mergeContext(groupingExprContext);
|
||||
Set<NamedExpression> pushedProject = fullContext.pushDownToNamedExpression(needToSlots);
|
||||
|
||||
Plan normalizedChild = pushDownProject(pushedProject, repeat.child());
|
||||
|
||||
LogicalRepeat<Plan> normalizedRepeat = repeat.withNormalizedExpr(
|
||||
@ -170,42 +189,43 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
|
||||
Optional.of(normalizedRepeat), normalizedRepeat);
|
||||
}
|
||||
|
||||
private Set<Expression> collectNeedToSlotExpressions(LogicalRepeat<Plan> repeat) {
|
||||
// 3 parts need push down:
|
||||
// flattenGroupingSetExpr, argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction
|
||||
|
||||
Set<Expression> flattenGroupingSetExpr = ImmutableSet.copyOf(
|
||||
private Set<Expression> collectNeedToSlotGroupingExpr(LogicalRepeat<Plan> repeat) {
|
||||
// grouping sets should be pushed down, e.g. grouping sets((k + 1)),
|
||||
// we should push down the `k + 1` to the bottom plan
|
||||
return ImmutableSet.copyOf(
|
||||
ExpressionUtils.flatExpressions(repeat.getGroupingSets()));
|
||||
}
|
||||
|
||||
private Set<Expression> collectNeedToSlotArgsOfGroupingScalarFuncAndAggFunc(LogicalRepeat<Plan> repeat) {
|
||||
Set<GroupingScalarFunction> groupingScalarFunctions = ExpressionUtils.collect(
|
||||
repeat.getOutputExpressions(), GroupingScalarFunction.class::isInstance);
|
||||
|
||||
ImmutableSet<Expression> argumentsOfGroupingScalarFunction = groupingScalarFunctions.stream()
|
||||
.flatMap(function -> function.getArguments().stream())
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
ImmutableSet.Builder<Expression> argumentsSetBuilder = ImmutableSet.builder();
|
||||
for (GroupingScalarFunction function : groupingScalarFunctions) {
|
||||
argumentsSetBuilder.addAll(function.getArguments());
|
||||
}
|
||||
ImmutableSet<Expression> argumentsOfGroupingScalarFunction = argumentsSetBuilder.build();
|
||||
|
||||
List<AggregateFunction> aggregateFunctions = CollectNonWindowedAggFuncs.collect(repeat.getOutputExpressions());
|
||||
ImmutableSet.Builder<Expression> argumentsOfAggregateFunctionBuilder = ImmutableSet.builder();
|
||||
for (AggregateFunction function : aggregateFunctions) {
|
||||
for (Expression arg : function.getArguments()) {
|
||||
if (arg instanceof OrderExpression) {
|
||||
argumentsOfAggregateFunctionBuilder.add(arg.child(0));
|
||||
} else {
|
||||
argumentsOfAggregateFunctionBuilder.add(arg);
|
||||
}
|
||||
}
|
||||
}
|
||||
ImmutableSet<Expression> argumentsOfAggregateFunction = argumentsOfAggregateFunctionBuilder.build();
|
||||
|
||||
ImmutableSet<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream()
|
||||
.flatMap(function -> function.getArguments().stream().map(arg -> {
|
||||
if (arg instanceof OrderExpression) {
|
||||
return arg.child(0);
|
||||
} else {
|
||||
return arg;
|
||||
}
|
||||
}))
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
|
||||
ImmutableSet<Expression> needPushDown = ImmutableSet.<Expression>builder()
|
||||
return ImmutableSet.<Expression>builder()
|
||||
// grouping sets should be pushed down, e.g. grouping sets((k + 1)),
|
||||
// we should push down the `k + 1` to the bottom plan
|
||||
.addAll(flattenGroupingSetExpr)
|
||||
// e.g. grouping_id(k + 1), we should push down the `k + 1` to the bottom plan
|
||||
.addAll(argumentsOfGroupingScalarFunction)
|
||||
// e.g. sum(k + 1), we should push down the `k + 1` to the bottom plan
|
||||
.addAll(argumentsOfAggregateFunction)
|
||||
.build();
|
||||
return needPushDown;
|
||||
}
|
||||
|
||||
private Plan pushDownProject(Set<NamedExpression> pushedExprs, Plan originBottomPlan) {
|
||||
@ -250,8 +270,13 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
|
||||
return Optional.of(new NormalizeToSlotTriplet(expression, newSlot, originTriplet.pushedExpr));
|
||||
}
|
||||
|
||||
private Expression normalizeGroupingScalarFunction(NormalizeToSlotContext context, Expression expr) {
|
||||
if (expr instanceof GroupingScalarFunction) {
|
||||
private Expression normalizeAggFuncChildrenAndGroupingScalarFunc(NormalizeToSlotContext context, Expression expr) {
|
||||
if (expr instanceof AggregateFunction) {
|
||||
AggregateFunction function = (AggregateFunction) expr;
|
||||
List<Expression> normalizedRealExpressions = context.normalizeToUseSlotRef(function.getArguments());
|
||||
function = function.withChildren(normalizedRealExpressions);
|
||||
return function;
|
||||
} else if (expr instanceof GroupingScalarFunction) {
|
||||
GroupingScalarFunction function = (GroupingScalarFunction) expr;
|
||||
List<Expression> normalizedRealExpressions = context.normalizeToUseSlotRef(function.getArguments());
|
||||
function = function.withChildren(normalizedRealExpressions);
|
||||
@ -262,6 +287,20 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
|
||||
}
|
||||
}
|
||||
|
||||
private Set<Alias> getExistsAlias(LogicalRepeat<Plan> repeat,
|
||||
Map<Expression, NormalizeToSlotTriplet> groupingExprMap) {
|
||||
Set<Alias> existsAlias = Sets.newHashSet();
|
||||
Set<Alias> aliases = ExpressionUtils.collect(repeat.getOutputExpressions(), Alias.class::isInstance);
|
||||
existsAlias.addAll(aliases);
|
||||
for (NormalizeToSlotTriplet triplet : groupingExprMap.values()) {
|
||||
if (triplet.pushedExpr instanceof Alias) {
|
||||
Alias alias = (Alias) triplet.pushedExpr;
|
||||
existsAlias.add(alias);
|
||||
}
|
||||
}
|
||||
return existsAlias;
|
||||
}
|
||||
|
||||
/*
|
||||
* compute slots that appear both in agg func and grouping sets,
|
||||
* copy the slots and output in the project below the repeat as new copied slots,
|
||||
@ -278,56 +317,77 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
|
||||
private LogicalAggregate<Plan> dealSlotAppearBothInAggFuncAndGroupingSets(
|
||||
@NotNull LogicalAggregate<Plan> aggregate) {
|
||||
LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
|
||||
|
||||
List<AggregateFunction> aggregateFunctions =
|
||||
CollectNonWindowedAggFuncs.collect(aggregate.getOutputExpressions());
|
||||
Set<Slot> aggUsedSlots = aggregateFunctions.stream()
|
||||
.flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
Set<Slot> groupingSetsUsedSlot = repeat.getGroupingSets().stream()
|
||||
.flatMap(Collection::stream)
|
||||
.flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
Set<Slot> resSet = new HashSet<>(aggUsedSlots);
|
||||
resSet.retainAll(groupingSetsUsedSlot);
|
||||
if (resSet.isEmpty()) {
|
||||
Map<Slot, Alias> commonSlotToAliasMap = getCommonSlotToAliasMap(repeat, aggregate);
|
||||
if (commonSlotToAliasMap.isEmpty()) {
|
||||
return aggregate;
|
||||
}
|
||||
Map<Slot, Alias> slotMapping = resSet.stream().collect(
|
||||
Collectors.toMap(key -> key, Alias::new)
|
||||
);
|
||||
Set<Alias> newAliases = new HashSet<>(slotMapping.values());
|
||||
List<Slot> newSlots = newAliases.stream()
|
||||
.map(Alias::toSlot)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// modify repeat child to a new project with more projections
|
||||
Set<Alias> newAliases = new HashSet<>(commonSlotToAliasMap.values());
|
||||
List<Slot> originSlots = repeat.child().getOutput();
|
||||
ImmutableList<NamedExpression> immList =
|
||||
ImmutableList<NamedExpression> newProjects =
|
||||
ImmutableList.<NamedExpression>builder().addAll(originSlots).addAll(newAliases).build();
|
||||
LogicalProject<Plan> newProject = new LogicalProject<>(immList, repeat.child());
|
||||
repeat = repeat.withChildren(ImmutableList.of(newProject));
|
||||
LogicalProject<Plan> newLogicalProject = new LogicalProject<>(newProjects, repeat.child());
|
||||
repeat = repeat.withChildren(ImmutableList.of(newLogicalProject));
|
||||
|
||||
// modify repeat outputs
|
||||
List<Slot> originRepeatSlots = repeat.getOutput();
|
||||
repeat = repeat.withAggOutput(ImmutableList
|
||||
.<NamedExpression>builder()
|
||||
.addAll(originRepeatSlots.stream().filter(slot -> ! (slot instanceof VirtualSlotReference))
|
||||
.collect(Collectors.toList()))
|
||||
List<Slot> virtualSlots = Lists.newArrayList();
|
||||
List<Slot> nonVirtualSlots = Lists.newArrayList();
|
||||
for (Slot slot : originRepeatSlots) {
|
||||
if (slot instanceof VirtualSlotReference) {
|
||||
virtualSlots.add(slot);
|
||||
} else {
|
||||
nonVirtualSlots.add(slot);
|
||||
}
|
||||
}
|
||||
List<Slot> newSlots = Lists.newArrayList();
|
||||
for (Alias alias : newAliases) {
|
||||
newSlots.add(alias.toSlot());
|
||||
}
|
||||
repeat = repeat.withAggOutput(ImmutableList.<NamedExpression>builder()
|
||||
.addAll(nonVirtualSlots)
|
||||
.addAll(newSlots)
|
||||
.addAll(originRepeatSlots.stream().filter(slot -> (slot instanceof VirtualSlotReference))
|
||||
.collect(Collectors.toList()))
|
||||
.addAll(virtualSlots)
|
||||
.build());
|
||||
aggregate = aggregate.withChildren(ImmutableList.of(repeat));
|
||||
|
||||
List<NamedExpression> newOutputExpressions = aggregate.getOutputExpressions().stream()
|
||||
.map(e -> (NamedExpression) e.accept(RewriteAggFuncWithoutWindowAggFunc.INSTANCE,
|
||||
slotMapping))
|
||||
.collect(Collectors.toList());
|
||||
ImmutableList.Builder<NamedExpression> newOutputExpressionBuilder = ImmutableList.builder();
|
||||
for (NamedExpression expression : aggregate.getOutputExpressions()) {
|
||||
NamedExpression newExpression = (NamedExpression) expression
|
||||
.accept(RewriteAggFuncWithoutWindowAggFunc.INSTANCE, commonSlotToAliasMap);
|
||||
newOutputExpressionBuilder.add(newExpression);
|
||||
}
|
||||
List<NamedExpression> newOutputExpressions = newOutputExpressionBuilder.build();
|
||||
return aggregate.withAggOutput(newOutputExpressions);
|
||||
}
|
||||
|
||||
private Map<Slot, Alias> getCommonSlotToAliasMap(LogicalRepeat<Plan> repeat, LogicalAggregate<Plan> aggregate) {
|
||||
List<AggregateFunction> aggregateFunctions =
|
||||
CollectNonWindowedAggFuncs.collect(aggregate.getOutputExpressions());
|
||||
ImmutableSet.Builder<Slot> aggUsedSlotBuilder = ImmutableSet.builder();
|
||||
for (AggregateFunction function : aggregateFunctions) {
|
||||
aggUsedSlotBuilder.addAll(function.<Set<SlotReference>>collect(SlotReference.class::isInstance));
|
||||
}
|
||||
ImmutableSet<Slot> aggUsedSlots = aggUsedSlotBuilder.build();
|
||||
|
||||
ImmutableSet.Builder<Slot> groupingSetsUsedSlotBuilder = ImmutableSet.builder();
|
||||
for (List<Expression> groupingSet : repeat.getGroupingSets()) {
|
||||
for (Expression expr : groupingSet) {
|
||||
groupingSetsUsedSlotBuilder.addAll(expr.<Set<SlotReference>>collect(SlotReference.class::isInstance));
|
||||
}
|
||||
}
|
||||
ImmutableSet<Slot> groupingSetsUsedSlot = groupingSetsUsedSlotBuilder.build();
|
||||
|
||||
Set<Slot> resSet = new HashSet<>(aggUsedSlots);
|
||||
resSet.retainAll(groupingSetsUsedSlot);
|
||||
Map<Slot, Alias> commonSlotToAliasMap = Maps.newHashMap();
|
||||
for (Slot key : resSet) {
|
||||
Alias alias = new Alias(key);
|
||||
commonSlotToAliasMap.put(key, alias);
|
||||
}
|
||||
return commonSlotToAliasMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* This class use the map(slotMapping) to rewrite all slots in trival-agg.
|
||||
* The purpose of this class is to only rewrite the slots in trival-agg and not to rewrite the slots in window-agg.
|
||||
|
||||
@ -51,6 +51,17 @@ public interface NormalizeToSlot {
|
||||
this.normalizeToSlotMap = normalizeToSlotMap;
|
||||
}
|
||||
|
||||
public Map<Expression, NormalizeToSlotTriplet> getNormalizeToSlotMap() {
|
||||
return normalizeToSlotMap;
|
||||
}
|
||||
|
||||
public NormalizeToSlotContext mergeContext(NormalizeToSlotContext context) {
|
||||
Map<Expression, NormalizeToSlotTriplet> newMap = Maps.newHashMap();
|
||||
newMap.putAll(this.normalizeToSlotMap);
|
||||
newMap.putAll(context.getNormalizeToSlotMap());
|
||||
return new NormalizeToSlotContext(newMap);
|
||||
}
|
||||
|
||||
/**
|
||||
* build normalization context by follow step.
|
||||
* 1. collect all exists alias by input parameters existsAliases build a reverted map: expr -> alias
|
||||
|
||||
@ -145,4 +145,5 @@ public class BuiltinFunctionBuilder extends FunctionBuilder {
|
||||
.map(constructor -> new BuiltinFunctionBuilder(functionClass, (Constructor<BoundFunction>) constructor))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user