[refactor](Nereids): unify all replaceNamedExpressions (#28228)

Use a unified function `replaceNamedExpressions ` instead of implementing it yourself repeatedly.
This commit is contained in:
jakevin
2024-01-09 13:32:55 +08:00
committed by yiguolei
parent 0c7c9485b6
commit 028e59efab
16 changed files with 405 additions and 443 deletions

View File

@ -21,9 +21,8 @@ import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
@ -36,7 +35,6 @@ import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* Eliminate filter which is FALSE or TRUE.
@ -68,11 +66,9 @@ public class EliminateFilter implements RewriteRuleFactory {
.toRule(RuleType.ELIMINATE_FILTER),
logicalFilter(logicalOneRowRelation()).thenApply(ctx -> {
LogicalFilter<LogicalOneRowRelation> filter = ctx.root;
Map<Expression, Expression> replaceMap =
filter.child().getOutputs().stream().filter(e -> e instanceof Alias)
.collect(Collectors.toMap(NamedExpression::toSlot, e -> ((Alias) e).child()));
Map<Slot, Expression> replaceMap = ExpressionUtils.generateReplaceMap(filter.child().getOutputs());
ImmutableSet.Builder newConjuncts = ImmutableSet.builder();
ImmutableSet.Builder<Expression> newConjuncts = ImmutableSet.builder();
ExpressionRewriteContext context =
new ExpressionRewriteContext(ctx.cascadesContext);
for (Expression expression : filter.getConjuncts()) {

View File

@ -19,16 +19,11 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import java.util.Map;
import java.util.List;
/**
* Project(OneRowRelation) -> OneRowRelation
@ -36,26 +31,11 @@ import java.util.Map;
public class PushProjectIntoOneRowRelation extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalProject(logicalOneRowRelation()).then(p -> {
Map<Expression, Expression> replaceMap = Maps.newHashMap();
Map<Expression, NamedExpression> replaceRootMap = Maps.newHashMap();
p.child().getProjects().forEach(ne -> {
if (ne instanceof Alias) {
replaceMap.put(ne.toSlot(), ((Alias) ne).child());
} else {
replaceMap.put(ne, ne);
}
replaceRootMap.put(ne.toSlot(), ne);
});
ImmutableList.Builder<NamedExpression> newProjections = ImmutableList.builder();
for (NamedExpression old : p.getProjects()) {
if (old instanceof SlotReference) {
newProjections.add(replaceRootMap.get(old));
} else {
newProjections.add((NamedExpression) ExpressionUtils.replace(old, replaceMap));
}
}
return p.child().withProjects(newProjections.build());
return logicalProject(logicalOneRowRelation()).then(project -> {
LogicalOneRowRelation oneRowRelation = project.child();
List<NamedExpression> namedExpressions = PlanUtils.mergeProjections(oneRowRelation.getProjects(),
project.getProjects());
return oneRowRelation.withProjects(namedExpressions);
}).toRule(RuleType.PUSH_PROJECT_INTO_ONE_ROW_RELATION);
}

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.plans.algebra;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@ -31,7 +30,6 @@ import com.google.common.collect.ImmutableMap;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* Common interface for logical/physical project.
@ -51,16 +49,7 @@ public interface Project {
* </pre>
*/
default Map<Slot, Expression> getAliasToProducer() {
return getProjects()
.stream()
.filter(Alias.class::isInstance)
.collect(
Collectors.toMap(
NamedExpression::toSlot,
// Avoid cast to alias, retrieving the first child expression.
alias -> alias.child(0)
)
);
return ExpressionUtils.generateReplaceMap(getProjects());
}
/**

View File

@ -304,6 +304,22 @@ public class ExpressionUtils {
}
}
/**
* Generate replaceMap Slot -> Expression from NamedExpression[Expression as name]
*/
public static Map<Slot, Expression> generateReplaceMap(List<NamedExpression> namedExpressions) {
return namedExpressions
.stream()
.filter(Alias.class::isInstance)
.collect(
Collectors.toMap(
NamedExpression::toSlot,
// Avoid cast to alias, retrieving the first child expression.
alias -> alias.child(0)
)
);
}
/**
* Replace expression node in the expression tree by `replaceMap` in top-down manner.
* For example.
@ -346,6 +362,23 @@ public class ExpressionUtils {
.collect(ImmutableSet.toImmutableSet());
}
/**
* Replace expression node in the expression tree by `replaceMap` in top-down manner.
*/
public static List<NamedExpression> replaceNamedExpressions(List<NamedExpression> namedExpressions,
Map<? extends Expression, ? extends Expression> replaceMap) {
return namedExpressions.stream()
.map(namedExpression -> {
NamedExpression newExpr = replace(namedExpression, replaceMap);
if (newExpr.getExprId().equals(namedExpression.getExprId())) {
return newExpr;
} else {
return new Alias(namedExpression.getExprId(), newExpr, namedExpression.getName());
}
})
.collect(ImmutableList.toImmutableList());
}
public static <E extends Expression> List<E> rewriteDownShortCircuit(
Collection<E> exprs, Function<Expression, Expression> rewriteFunction) {
return exprs.stream()

View File

@ -53,16 +53,9 @@ public class ImmutableEqualSet<T> {
public void addEqualPair(T a, T b) {
T root1 = findRoot(a);
T root2 = findRoot(b);
if (root1 != root2) {
// merge by size
if (size.get(root1) < size.get(root2)) {
parent.put(root1, root2);
size.put(root2, size.get(root2) + size.get(root1));
} else {
parent.put(root2, root1);
size.put(root1, size.get(root1) + size.get(root2));
}
parent.put(b, root1);
findRoot(b);
}
}

View File

@ -17,7 +17,6 @@
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@ -27,14 +26,12 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Util for plan
@ -78,24 +75,8 @@ public class PlanUtils {
*/
public static List<NamedExpression> mergeProjections(List<NamedExpression> childProjects,
List<NamedExpression> parentProjects) {
Map<Expression, Alias> replaceMap =
childProjects.stream().filter(e -> e instanceof Alias).collect(
Collectors.toMap(NamedExpression::toSlot, e -> (Alias) e, (v1, v2) -> v1));
return parentProjects.stream().map(expr -> {
if (expr instanceof Alias) {
Alias alias = (Alias) expr;
Expression insideExpr = alias.child();
Expression newInsideExpr = insideExpr.rewriteUp(e -> {
Alias getAlias = replaceMap.get(e);
return getAlias == null ? e : getAlias.child();
});
return newInsideExpr == insideExpr ? expr
: alias.withChildren(ImmutableList.of(newInsideExpr));
} else {
Alias getAlias = replaceMap.get(expr);
return getAlias == null ? expr : getAlias;
}
}).collect(ImmutableList.toImmutableList());
Map<Slot, Expression> replaceMap = ExpressionUtils.generateReplaceMap(childProjects);
return ExpressionUtils.replaceNamedExpressions(parentProjects, replaceMap);
}
public static Plan skipProjectFilterLimit(Plan plan) {