[refactor](Nereids): unify all replaceNamedExpressions (#28228)
Use a unified function `replaceNamedExpressions ` instead of implementing it yourself repeatedly.
This commit is contained in:
@ -21,9 +21,8 @@ import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
|
||||
@ -36,7 +35,6 @@ import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Eliminate filter which is FALSE or TRUE.
|
||||
@ -68,11 +66,9 @@ public class EliminateFilter implements RewriteRuleFactory {
|
||||
.toRule(RuleType.ELIMINATE_FILTER),
|
||||
logicalFilter(logicalOneRowRelation()).thenApply(ctx -> {
|
||||
LogicalFilter<LogicalOneRowRelation> filter = ctx.root;
|
||||
Map<Expression, Expression> replaceMap =
|
||||
filter.child().getOutputs().stream().filter(e -> e instanceof Alias)
|
||||
.collect(Collectors.toMap(NamedExpression::toSlot, e -> ((Alias) e).child()));
|
||||
Map<Slot, Expression> replaceMap = ExpressionUtils.generateReplaceMap(filter.child().getOutputs());
|
||||
|
||||
ImmutableSet.Builder newConjuncts = ImmutableSet.builder();
|
||||
ImmutableSet.Builder<Expression> newConjuncts = ImmutableSet.builder();
|
||||
ExpressionRewriteContext context =
|
||||
new ExpressionRewriteContext(ctx.cascadesContext);
|
||||
for (Expression expression : filter.getConjuncts()) {
|
||||
|
||||
@ -19,16 +19,11 @@ package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
|
||||
import org.apache.doris.nereids.util.PlanUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Maps;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Project(OneRowRelation) -> OneRowRelation
|
||||
@ -36,26 +31,11 @@ import java.util.Map;
|
||||
public class PushProjectIntoOneRowRelation extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalProject(logicalOneRowRelation()).then(p -> {
|
||||
Map<Expression, Expression> replaceMap = Maps.newHashMap();
|
||||
Map<Expression, NamedExpression> replaceRootMap = Maps.newHashMap();
|
||||
p.child().getProjects().forEach(ne -> {
|
||||
if (ne instanceof Alias) {
|
||||
replaceMap.put(ne.toSlot(), ((Alias) ne).child());
|
||||
} else {
|
||||
replaceMap.put(ne, ne);
|
||||
}
|
||||
replaceRootMap.put(ne.toSlot(), ne);
|
||||
});
|
||||
ImmutableList.Builder<NamedExpression> newProjections = ImmutableList.builder();
|
||||
for (NamedExpression old : p.getProjects()) {
|
||||
if (old instanceof SlotReference) {
|
||||
newProjections.add(replaceRootMap.get(old));
|
||||
} else {
|
||||
newProjections.add((NamedExpression) ExpressionUtils.replace(old, replaceMap));
|
||||
}
|
||||
}
|
||||
return p.child().withProjects(newProjections.build());
|
||||
return logicalProject(logicalOneRowRelation()).then(project -> {
|
||||
LogicalOneRowRelation oneRowRelation = project.child();
|
||||
List<NamedExpression> namedExpressions = PlanUtils.mergeProjections(oneRowRelation.getProjects(),
|
||||
project.getProjects());
|
||||
return oneRowRelation.withProjects(namedExpressions);
|
||||
|
||||
}).toRule(RuleType.PUSH_PROJECT_INTO_ONE_ROW_RELATION);
|
||||
}
|
||||
|
||||
@ -18,7 +18,6 @@
|
||||
package org.apache.doris.nereids.trees.plans.algebra;
|
||||
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
@ -31,7 +30,6 @@ import com.google.common.collect.ImmutableMap;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Common interface for logical/physical project.
|
||||
@ -51,16 +49,7 @@ public interface Project {
|
||||
* </pre>
|
||||
*/
|
||||
default Map<Slot, Expression> getAliasToProducer() {
|
||||
return getProjects()
|
||||
.stream()
|
||||
.filter(Alias.class::isInstance)
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
NamedExpression::toSlot,
|
||||
// Avoid cast to alias, retrieving the first child expression.
|
||||
alias -> alias.child(0)
|
||||
)
|
||||
);
|
||||
return ExpressionUtils.generateReplaceMap(getProjects());
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -304,6 +304,22 @@ public class ExpressionUtils {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate replaceMap Slot -> Expression from NamedExpression[Expression as name]
|
||||
*/
|
||||
public static Map<Slot, Expression> generateReplaceMap(List<NamedExpression> namedExpressions) {
|
||||
return namedExpressions
|
||||
.stream()
|
||||
.filter(Alias.class::isInstance)
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
NamedExpression::toSlot,
|
||||
// Avoid cast to alias, retrieving the first child expression.
|
||||
alias -> alias.child(0)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Replace expression node in the expression tree by `replaceMap` in top-down manner.
|
||||
* For example.
|
||||
@ -346,6 +362,23 @@ public class ExpressionUtils {
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
}
|
||||
|
||||
/**
|
||||
* Replace expression node in the expression tree by `replaceMap` in top-down manner.
|
||||
*/
|
||||
public static List<NamedExpression> replaceNamedExpressions(List<NamedExpression> namedExpressions,
|
||||
Map<? extends Expression, ? extends Expression> replaceMap) {
|
||||
return namedExpressions.stream()
|
||||
.map(namedExpression -> {
|
||||
NamedExpression newExpr = replace(namedExpression, replaceMap);
|
||||
if (newExpr.getExprId().equals(namedExpression.getExprId())) {
|
||||
return newExpr;
|
||||
} else {
|
||||
return new Alias(namedExpression.getExprId(), newExpr, namedExpression.getName());
|
||||
}
|
||||
})
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
|
||||
public static <E extends Expression> List<E> rewriteDownShortCircuit(
|
||||
Collection<E> exprs, Function<Expression, Expression> rewriteFunction) {
|
||||
return exprs.stream()
|
||||
|
||||
@ -53,16 +53,9 @@ public class ImmutableEqualSet<T> {
|
||||
public void addEqualPair(T a, T b) {
|
||||
T root1 = findRoot(a);
|
||||
T root2 = findRoot(b);
|
||||
|
||||
if (root1 != root2) {
|
||||
// merge by size
|
||||
if (size.get(root1) < size.get(root2)) {
|
||||
parent.put(root1, root2);
|
||||
size.put(root2, size.get(root2) + size.get(root1));
|
||||
} else {
|
||||
parent.put(root2, root1);
|
||||
size.put(root1, size.get(root1) + size.get(root2));
|
||||
}
|
||||
parent.put(b, root1);
|
||||
findRoot(b);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
|
||||
package org.apache.doris.nereids.util;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
@ -27,14 +26,12 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Util for plan
|
||||
@ -78,24 +75,8 @@ public class PlanUtils {
|
||||
*/
|
||||
public static List<NamedExpression> mergeProjections(List<NamedExpression> childProjects,
|
||||
List<NamedExpression> parentProjects) {
|
||||
Map<Expression, Alias> replaceMap =
|
||||
childProjects.stream().filter(e -> e instanceof Alias).collect(
|
||||
Collectors.toMap(NamedExpression::toSlot, e -> (Alias) e, (v1, v2) -> v1));
|
||||
return parentProjects.stream().map(expr -> {
|
||||
if (expr instanceof Alias) {
|
||||
Alias alias = (Alias) expr;
|
||||
Expression insideExpr = alias.child();
|
||||
Expression newInsideExpr = insideExpr.rewriteUp(e -> {
|
||||
Alias getAlias = replaceMap.get(e);
|
||||
return getAlias == null ? e : getAlias.child();
|
||||
});
|
||||
return newInsideExpr == insideExpr ? expr
|
||||
: alias.withChildren(ImmutableList.of(newInsideExpr));
|
||||
} else {
|
||||
Alias getAlias = replaceMap.get(expr);
|
||||
return getAlias == null ? expr : getAlias;
|
||||
}
|
||||
}).collect(ImmutableList.toImmutableList());
|
||||
Map<Slot, Expression> replaceMap = ExpressionUtils.generateReplaceMap(childProjects);
|
||||
return ExpressionUtils.replaceNamedExpressions(parentProjects, replaceMap);
|
||||
}
|
||||
|
||||
public static Plan skipProjectFilterLimit(Plan plan) {
|
||||
|
||||
Reference in New Issue
Block a user