[Fix](nereids) fix NormalizeRepeat, change the outputExpression rewrite logic (#34196)

In NormalizeRepeat, three parts of the outputExpression of LogicalRepeat need to be pushed down and outputted by bottom project: flattenGroupingSetExpr, argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction.
In the original code, use these three parts to rewrite the outputExpressions of LogicalRepeat to slots.This can cause problems in some cases, for example:
```sql
SELECT
	ROUND( SUM(pk + 1) - 3) col_alias1,
	pk + 1 AS col_alias3 
	FROM
	table_20_undef_partitions2_keys3_properties4_distributed_by53
GROUP BY
	GROUPING SETS ((pk), ()) ;
```
The three parts expression needed to be pushed down are: pk, pk+1. The original code use pk+1 to rewrite the pk + 1 AS col_alias3  to slot. But the pk+1 is not in the list of grouping outputs, and then report error.
This pr change the rewrite process,  divide the expression needed to be pushed down  into 2 parts: one is (flattenGroupingSetExpr) and the other one is (argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction).
 and use the flattenGroupingSetExpr rewrite all LogicalRepeat outputExpressions, and use the argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction to rewrite only the agg function arguments and the grouping scalar function.
So, in the above sql, the pk + 1 AS col_alias3  will not be rewritten to slot, and can be computed.
This commit is contained in:
feiniaofeiafei
2024-05-08 17:06:40 +08:00
committed by yiguolei
parent c0cca6103b
commit 7c56c17ecc
5 changed files with 195 additions and 72 deletions

View File

@ -43,19 +43,18 @@ import org.apache.doris.nereids.util.PlanUtils.CollectNonWindowedAggFuncs;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.collect.Sets.SetView;
import org.jetbrains.annotations.NotNull;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
/** NormalizeRepeat
@ -117,19 +116,34 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
}
private LogicalAggregate<Plan> normalizeRepeat(LogicalRepeat<Plan> repeat) {
Set<Expression> needToSlots = collectNeedToSlotExpressions(repeat);
NormalizeToSlotContext context = buildContext(repeat, needToSlots);
Set<Expression> needToSlotsGroupingExpr = collectNeedToSlotGroupingExpr(repeat);
NormalizeToSlotContext groupingExprContext = buildContext(repeat, needToSlotsGroupingExpr);
Map<Expression, NormalizeToSlotTriplet> groupingExprMap = groupingExprContext.getNormalizeToSlotMap();
Set<Alias> existsAlias = getExistsAlias(repeat, groupingExprMap);
Set<Expression> needToSlotsArgs = collectNeedToSlotArgsOfGroupingScalarFuncAndAggFunc(repeat);
NormalizeToSlotContext argsContext = NormalizeToSlotContext.buildContext(existsAlias, needToSlotsArgs);
// normalize grouping sets to List<List<Slot>>
List<List<Slot>> normalizedGroupingSets = repeat.getGroupingSets()
.stream()
.map(groupingSet -> (List<Slot>) (List) context.normalizeToUseSlotRef(groupingSet))
.collect(ImmutableList.toImmutableList());
ImmutableList.Builder<List<Slot>> normalizedGroupingSetBuilder = ImmutableList.builder();
for (List<Expression> groupingSet : repeat.getGroupingSets()) {
List<Slot> normalizedSet = (List<Slot>) (List) groupingExprContext.normalizeToUseSlotRef(groupingSet);
normalizedGroupingSetBuilder.add(normalizedSet);
}
List<List<Slot>> normalizedGroupingSets = normalizedGroupingSetBuilder.build();
// replace the arguments of grouping scalar function to virtual slots
// replace some complex expression to slot, e.g. `a + 1`
List<NamedExpression> normalizedAggOutput = context.normalizeToUseSlotRef(
repeat.getOutputExpressions(), this::normalizeGroupingScalarFunction);
// use argsContext
// rewrite the arguments of grouping scalar function to slots
// rewrite grouping scalar function to virtual slots
// rewrite the arguments of agg function to slots
List<NamedExpression> normalizedAggOutput = Lists.newArrayList();
for (Expression expr : repeat.getOutputExpressions()) {
Expression rewrittenExpr = expr.rewriteDownShortCircuit(
e -> normalizeAggFuncChildrenAndGroupingScalarFunc(argsContext, e));
normalizedAggOutput.add((NamedExpression) rewrittenExpr);
}
// use groupingExprContext rewrite the normalizedAggOutput
normalizedAggOutput = groupingExprContext.normalizeToUseSlotRef(normalizedAggOutput);
Set<VirtualSlotReference> virtualSlotsInFunction =
ExpressionUtils.collect(normalizedAggOutput, VirtualSlotReference.class::isInstance);
@ -156,7 +170,12 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
.addAll(allVirtualSlots)
.build();
Set<NamedExpression> pushedProject = context.pushDownToNamedExpression(needToSlots);
// 3 parts need push down:
// flattenGroupingSetExpr, argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction
Set<Expression> needToSlots = Sets.union(needToSlotsArgs, needToSlotsGroupingExpr);
NormalizeToSlotContext fullContext = argsContext.mergeContext(groupingExprContext);
Set<NamedExpression> pushedProject = fullContext.pushDownToNamedExpression(needToSlots);
Plan normalizedChild = pushDownProject(pushedProject, repeat.child());
LogicalRepeat<Plan> normalizedRepeat = repeat.withNormalizedExpr(
@ -170,42 +189,43 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
Optional.of(normalizedRepeat), normalizedRepeat);
}
private Set<Expression> collectNeedToSlotExpressions(LogicalRepeat<Plan> repeat) {
// 3 parts need push down:
// flattenGroupingSetExpr, argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction
Set<Expression> flattenGroupingSetExpr = ImmutableSet.copyOf(
private Set<Expression> collectNeedToSlotGroupingExpr(LogicalRepeat<Plan> repeat) {
// grouping sets should be pushed down, e.g. grouping sets((k + 1)),
// we should push down the `k + 1` to the bottom plan
return ImmutableSet.copyOf(
ExpressionUtils.flatExpressions(repeat.getGroupingSets()));
}
private Set<Expression> collectNeedToSlotArgsOfGroupingScalarFuncAndAggFunc(LogicalRepeat<Plan> repeat) {
Set<GroupingScalarFunction> groupingScalarFunctions = ExpressionUtils.collect(
repeat.getOutputExpressions(), GroupingScalarFunction.class::isInstance);
ImmutableSet<Expression> argumentsOfGroupingScalarFunction = groupingScalarFunctions.stream()
.flatMap(function -> function.getArguments().stream())
.collect(ImmutableSet.toImmutableSet());
ImmutableSet.Builder<Expression> argumentsSetBuilder = ImmutableSet.builder();
for (GroupingScalarFunction function : groupingScalarFunctions) {
argumentsSetBuilder.addAll(function.getArguments());
}
ImmutableSet<Expression> argumentsOfGroupingScalarFunction = argumentsSetBuilder.build();
List<AggregateFunction> aggregateFunctions = CollectNonWindowedAggFuncs.collect(repeat.getOutputExpressions());
ImmutableSet.Builder<Expression> argumentsOfAggregateFunctionBuilder = ImmutableSet.builder();
for (AggregateFunction function : aggregateFunctions) {
for (Expression arg : function.getArguments()) {
if (arg instanceof OrderExpression) {
argumentsOfAggregateFunctionBuilder.add(arg.child(0));
} else {
argumentsOfAggregateFunctionBuilder.add(arg);
}
}
}
ImmutableSet<Expression> argumentsOfAggregateFunction = argumentsOfAggregateFunctionBuilder.build();
ImmutableSet<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream()
.flatMap(function -> function.getArguments().stream().map(arg -> {
if (arg instanceof OrderExpression) {
return arg.child(0);
} else {
return arg;
}
}))
.collect(ImmutableSet.toImmutableSet());
ImmutableSet<Expression> needPushDown = ImmutableSet.<Expression>builder()
return ImmutableSet.<Expression>builder()
// grouping sets should be pushed down, e.g. grouping sets((k + 1)),
// we should push down the `k + 1` to the bottom plan
.addAll(flattenGroupingSetExpr)
// e.g. grouping_id(k + 1), we should push down the `k + 1` to the bottom plan
.addAll(argumentsOfGroupingScalarFunction)
// e.g. sum(k + 1), we should push down the `k + 1` to the bottom plan
.addAll(argumentsOfAggregateFunction)
.build();
return needPushDown;
}
private Plan pushDownProject(Set<NamedExpression> pushedExprs, Plan originBottomPlan) {
@ -250,8 +270,13 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
return Optional.of(new NormalizeToSlotTriplet(expression, newSlot, originTriplet.pushedExpr));
}
private Expression normalizeGroupingScalarFunction(NormalizeToSlotContext context, Expression expr) {
if (expr instanceof GroupingScalarFunction) {
private Expression normalizeAggFuncChildrenAndGroupingScalarFunc(NormalizeToSlotContext context, Expression expr) {
if (expr instanceof AggregateFunction) {
AggregateFunction function = (AggregateFunction) expr;
List<Expression> normalizedRealExpressions = context.normalizeToUseSlotRef(function.getArguments());
function = function.withChildren(normalizedRealExpressions);
return function;
} else if (expr instanceof GroupingScalarFunction) {
GroupingScalarFunction function = (GroupingScalarFunction) expr;
List<Expression> normalizedRealExpressions = context.normalizeToUseSlotRef(function.getArguments());
function = function.withChildren(normalizedRealExpressions);
@ -262,6 +287,20 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
}
}
private Set<Alias> getExistsAlias(LogicalRepeat<Plan> repeat,
Map<Expression, NormalizeToSlotTriplet> groupingExprMap) {
Set<Alias> existsAlias = Sets.newHashSet();
Set<Alias> aliases = ExpressionUtils.collect(repeat.getOutputExpressions(), Alias.class::isInstance);
existsAlias.addAll(aliases);
for (NormalizeToSlotTriplet triplet : groupingExprMap.values()) {
if (triplet.pushedExpr instanceof Alias) {
Alias alias = (Alias) triplet.pushedExpr;
existsAlias.add(alias);
}
}
return existsAlias;
}
/*
* compute slots that appear both in agg func and grouping sets,
* copy the slots and output in the project below the repeat as new copied slots,
@ -278,56 +317,77 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
private LogicalAggregate<Plan> dealSlotAppearBothInAggFuncAndGroupingSets(
@NotNull LogicalAggregate<Plan> aggregate) {
LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
List<AggregateFunction> aggregateFunctions =
CollectNonWindowedAggFuncs.collect(aggregate.getOutputExpressions());
Set<Slot> aggUsedSlots = aggregateFunctions.stream()
.flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
.collect(ImmutableSet.toImmutableSet());
Set<Slot> groupingSetsUsedSlot = repeat.getGroupingSets().stream()
.flatMap(Collection::stream)
.flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
.collect(Collectors.toSet());
Set<Slot> resSet = new HashSet<>(aggUsedSlots);
resSet.retainAll(groupingSetsUsedSlot);
if (resSet.isEmpty()) {
Map<Slot, Alias> commonSlotToAliasMap = getCommonSlotToAliasMap(repeat, aggregate);
if (commonSlotToAliasMap.isEmpty()) {
return aggregate;
}
Map<Slot, Alias> slotMapping = resSet.stream().collect(
Collectors.toMap(key -> key, Alias::new)
);
Set<Alias> newAliases = new HashSet<>(slotMapping.values());
List<Slot> newSlots = newAliases.stream()
.map(Alias::toSlot)
.collect(Collectors.toList());
// modify repeat child to a new project with more projections
Set<Alias> newAliases = new HashSet<>(commonSlotToAliasMap.values());
List<Slot> originSlots = repeat.child().getOutput();
ImmutableList<NamedExpression> immList =
ImmutableList<NamedExpression> newProjects =
ImmutableList.<NamedExpression>builder().addAll(originSlots).addAll(newAliases).build();
LogicalProject<Plan> newProject = new LogicalProject<>(immList, repeat.child());
repeat = repeat.withChildren(ImmutableList.of(newProject));
LogicalProject<Plan> newLogicalProject = new LogicalProject<>(newProjects, repeat.child());
repeat = repeat.withChildren(ImmutableList.of(newLogicalProject));
// modify repeat outputs
List<Slot> originRepeatSlots = repeat.getOutput();
repeat = repeat.withAggOutput(ImmutableList
.<NamedExpression>builder()
.addAll(originRepeatSlots.stream().filter(slot -> ! (slot instanceof VirtualSlotReference))
.collect(Collectors.toList()))
List<Slot> virtualSlots = Lists.newArrayList();
List<Slot> nonVirtualSlots = Lists.newArrayList();
for (Slot slot : originRepeatSlots) {
if (slot instanceof VirtualSlotReference) {
virtualSlots.add(slot);
} else {
nonVirtualSlots.add(slot);
}
}
List<Slot> newSlots = Lists.newArrayList();
for (Alias alias : newAliases) {
newSlots.add(alias.toSlot());
}
repeat = repeat.withAggOutput(ImmutableList.<NamedExpression>builder()
.addAll(nonVirtualSlots)
.addAll(newSlots)
.addAll(originRepeatSlots.stream().filter(slot -> (slot instanceof VirtualSlotReference))
.collect(Collectors.toList()))
.addAll(virtualSlots)
.build());
aggregate = aggregate.withChildren(ImmutableList.of(repeat));
List<NamedExpression> newOutputExpressions = aggregate.getOutputExpressions().stream()
.map(e -> (NamedExpression) e.accept(RewriteAggFuncWithoutWindowAggFunc.INSTANCE,
slotMapping))
.collect(Collectors.toList());
ImmutableList.Builder<NamedExpression> newOutputExpressionBuilder = ImmutableList.builder();
for (NamedExpression expression : aggregate.getOutputExpressions()) {
NamedExpression newExpression = (NamedExpression) expression
.accept(RewriteAggFuncWithoutWindowAggFunc.INSTANCE, commonSlotToAliasMap);
newOutputExpressionBuilder.add(newExpression);
}
List<NamedExpression> newOutputExpressions = newOutputExpressionBuilder.build();
return aggregate.withAggOutput(newOutputExpressions);
}
private Map<Slot, Alias> getCommonSlotToAliasMap(LogicalRepeat<Plan> repeat, LogicalAggregate<Plan> aggregate) {
List<AggregateFunction> aggregateFunctions =
CollectNonWindowedAggFuncs.collect(aggregate.getOutputExpressions());
ImmutableSet.Builder<Slot> aggUsedSlotBuilder = ImmutableSet.builder();
for (AggregateFunction function : aggregateFunctions) {
aggUsedSlotBuilder.addAll(function.<Set<SlotReference>>collect(SlotReference.class::isInstance));
}
ImmutableSet<Slot> aggUsedSlots = aggUsedSlotBuilder.build();
ImmutableSet.Builder<Slot> groupingSetsUsedSlotBuilder = ImmutableSet.builder();
for (List<Expression> groupingSet : repeat.getGroupingSets()) {
for (Expression expr : groupingSet) {
groupingSetsUsedSlotBuilder.addAll(expr.<Set<SlotReference>>collect(SlotReference.class::isInstance));
}
}
ImmutableSet<Slot> groupingSetsUsedSlot = groupingSetsUsedSlotBuilder.build();
Set<Slot> resSet = new HashSet<>(aggUsedSlots);
resSet.retainAll(groupingSetsUsedSlot);
Map<Slot, Alias> commonSlotToAliasMap = Maps.newHashMap();
for (Slot key : resSet) {
Alias alias = new Alias(key);
commonSlotToAliasMap.put(key, alias);
}
return commonSlotToAliasMap;
}
/**
* This class use the map(slotMapping) to rewrite all slots in trival-agg.
* The purpose of this class is to only rewrite the slots in trival-agg and not to rewrite the slots in window-agg.

View File

@ -51,6 +51,17 @@ public interface NormalizeToSlot {
this.normalizeToSlotMap = normalizeToSlotMap;
}
public Map<Expression, NormalizeToSlotTriplet> getNormalizeToSlotMap() {
return normalizeToSlotMap;
}
public NormalizeToSlotContext mergeContext(NormalizeToSlotContext context) {
Map<Expression, NormalizeToSlotTriplet> newMap = Maps.newHashMap();
newMap.putAll(this.normalizeToSlotMap);
newMap.putAll(context.getNormalizeToSlotMap());
return new NormalizeToSlotContext(newMap);
}
/**
* build normalization context by follow step.
* 1. collect all exists alias by input parameters existsAliases build a reverted map: expr -> alias

View File

@ -145,4 +145,5 @@ public class BuiltinFunctionBuilder extends FunctionBuilder {
.map(constructor -> new BuiltinFunctionBuilder(functionClass, (Constructor<BoundFunction>) constructor))
.collect(ImmutableList.toImmutableList());
}
}