[Fix](nereids) fix normalize repeat alias rewrite (#38166) (#38454)

cherry-pick #38166 to branch-2.1
This commit is contained in:
feiniaofeiafei
2024-07-31 10:59:15 +08:00
committed by GitHub
parent 182bf4d323
commit 94111da2a9
4 changed files with 684 additions and 13 deletions

View File

@ -51,6 +51,8 @@ import com.google.common.collect.Sets.SetView;
import org.jetbrains.annotations.NotNull;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@ -133,9 +135,9 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
Set<Expression> needToSlotsGroupingExpr = collectNeedToSlotGroupingExpr(repeat);
NormalizeToSlotContext groupingExprContext = buildContext(repeat, needToSlotsGroupingExpr);
Map<Expression, NormalizeToSlotTriplet> groupingExprMap = groupingExprContext.getNormalizeToSlotMap();
Set<Alias> existsAlias = getExistsAlias(repeat, groupingExprMap);
Map<Expression, Alias> existsAlias = getExistsAlias(repeat, groupingExprMap);
Set<Expression> needToSlotsArgs = collectNeedToSlotArgsOfGroupingScalarFuncAndAggFunc(repeat);
NormalizeToSlotContext argsContext = NormalizeToSlotContext.buildContext(existsAlias, needToSlotsArgs);
NormalizeToSlotContext argsContext = buildContextWithAlias(repeat, existsAlias, needToSlotsArgs);
// normalize grouping sets to List<List<Slot>>
ImmutableList.Builder<List<Slot>> normalizedGroupingSetBuilder = ImmutableList.builder();
@ -254,12 +256,27 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
/** buildContext */
public static NormalizeToSlotContext buildContext(Repeat<? extends Plan> repeat,
Set<? extends Expression> sourceExpressions) {
Set<Alias> aliases = ExpressionUtils.collect(repeat.getOutputExpressions(), Alias.class::isInstance);
List<Alias> aliases = ExpressionUtils.collectToList(repeat.getOutputExpressions(), Alias.class::isInstance);
Map<Expression, Alias> existsAliasMap = Maps.newLinkedHashMap();
for (Alias existsAlias : aliases) {
if (existsAliasMap.containsKey(existsAlias.child())) {
continue;
}
existsAliasMap.put(existsAlias.child(), existsAlias);
}
Map<Expression, NormalizeToSlotTriplet> normalizeToSlotMap = Maps.newLinkedHashMap();
for (Expression expression : sourceExpressions) {
Optional<NormalizeToSlotTriplet> pushDownTriplet =
toGroupingSetExpressionPushDownTriplet(expression, existsAliasMap.get(expression));
pushDownTriplet.ifPresent(
normalizeToSlotTriplet -> normalizeToSlotMap.put(expression, normalizeToSlotTriplet));
}
return new NormalizeToSlotContext(normalizeToSlotMap);
}
private static NormalizeToSlotContext buildContextWithAlias(Repeat<? extends Plan> repeat,
Map<Expression, Alias> existsAliasMap, Collection<? extends Expression> sourceExpressions) {
List<Expression> groupingSetExpressions = ExpressionUtils.flatExpressions(repeat.getGroupingSets());
Map<Expression, NormalizeToSlotTriplet> normalizeToSlotMap = Maps.newLinkedHashMap();
for (Expression expression : sourceExpressions) {
@ -270,10 +287,8 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
pushDownTriplet = Optional.of(
NormalizeToSlotTriplet.toTriplet(expression, existsAliasMap.get(expression)));
}
if (pushDownTriplet.isPresent()) {
normalizeToSlotMap.put(expression, pushDownTriplet.get());
}
pushDownTriplet.ifPresent(
normalizeToSlotTriplet -> normalizeToSlotMap.put(expression, normalizeToSlotTriplet));
}
return new NormalizeToSlotContext(normalizeToSlotMap);
}
@ -304,18 +319,23 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
}
}
private static Set<Alias> getExistsAlias(LogicalRepeat<Plan> repeat,
private static Map<Expression, Alias> getExistsAlias(LogicalRepeat<Plan> repeat,
Map<Expression, NormalizeToSlotTriplet> groupingExprMap) {
Set<Alias> existsAlias = Sets.newHashSet();
Set<Alias> aliases = ExpressionUtils.collect(repeat.getOutputExpressions(), Alias.class::isInstance);
existsAlias.addAll(aliases);
Map<Expression, Alias> existsAliasMap = new HashMap<>();
for (NormalizeToSlotTriplet triplet : groupingExprMap.values()) {
if (triplet.pushedExpr instanceof Alias) {
Alias alias = (Alias) triplet.pushedExpr;
existsAlias.add(alias);
existsAliasMap.put(triplet.originExpr, alias);
}
}
return existsAlias;
List<Alias> aliases = ExpressionUtils.collectToList(repeat.getOutputExpressions(), Alias.class::isInstance);
for (Alias alias : aliases) {
if (existsAliasMap.containsKey(alias.child())) {
continue;
}
existsAliasMap.put(alias.child(), alias);
}
return existsAliasMap;
}
/*

View File

@ -699,6 +699,15 @@ public class ExpressionUtils {
return set.build();
}
public static <E> List<E> collectToList(Collection<? extends Expression> expressions,
Predicate<TreeNode<Expression>> predicate) {
ImmutableList.Builder<E> list = ImmutableList.builder();
for (Expression expr : expressions) {
list.addAll(expr.collectToList(predicate));
}
return list.build();
}
/**
* extract uniform slot for the given predicate, such as a = 1 and b = 2
*/