[refactor](Nereids): refactor PredicatePropagation & support to infer Equal Condition (#29644)

This commit is contained in:
jakevin
2024-01-08 19:42:12 +08:00
committed by yiguolei
parent 8fc9c18c85
commit a2da434e3b
6 changed files with 135 additions and 139 deletions

View File

@ -36,7 +36,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ImmutableEquivalenceSet;
import org.apache.doris.nereids.util.ImmutableEqualSet;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
@ -145,7 +145,7 @@ public class EliminateJoinByFK extends DefaultPlanRewriter<JobContext> implement
return project;
}
private @Nullable Map<Slot, Slot> mapPrimaryToForeign(ImmutableEquivalenceSet<Slot> equivalenceSet,
private @Nullable Map<Slot, Slot> mapPrimaryToForeign(ImmutableEqualSet<Slot> equivalenceSet,
Set<Slot> foreignKeys) {
ImmutableMap.Builder<Slot, Slot> builder = new ImmutableMap.Builder<>();
for (Slot foreignSlot : foreignKeys) {
@ -164,7 +164,7 @@ public class EliminateJoinByFK extends DefaultPlanRewriter<JobContext> implement
// 4. if foreign key is null, add a isNotNull predicate for null-reject join condition
private Plan eliminateJoin(LogicalProject<LogicalJoin<?, ?>> project, ForeignKeyContext context) {
LogicalJoin<?, ?> join = project.child();
ImmutableEquivalenceSet<Slot> equalSet = join.getEqualSlots();
ImmutableEqualSet<Slot> equalSet = join.getEqualSlots();
Set<Slot> leftSlots = Sets.intersection(join.left().getOutputSet(), equalSet.getAllItemSet());
Set<Slot> rightSlots = Sets.intersection(join.right().getOutputSet(), equalSet.getAllItemSet());
if (context.isForeignKey(leftSlots) && context.isPrimaryKey(rightSlots)) {

View File

@ -17,16 +17,14 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.rules.expression.rules.DateFunctionRewrite;
import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
@ -35,11 +33,15 @@ import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.types.coercion.IntegralType;
import org.apache.doris.nereids.util.ImmutableEqualSet;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.Sets;
import java.util.Objects;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@ -65,59 +67,62 @@ public class PredicatePropagation {
}
}
private static class EqualInferInfo {
public final InferType inferType;
public final Expression left;
public final Expression right;
public final ComparisonPredicate comparisonPredicate;
public EqualInferInfo(InferType inferType,
Expression left, Expression right,
ComparisonPredicate comparisonPredicate) {
this.inferType = inferType;
this.left = left;
this.right = right;
this.comparisonPredicate = comparisonPredicate;
}
}
/**
* infer additional predicates.
*/
public static Set<Expression> infer(Set<Expression> predicates) {
Set<Expression> inferred = Sets.newHashSet();
ImmutableEqualSet.Builder<Slot> equalSetBuilder = new ImmutableEqualSet.Builder<>();
Map<Slot, List<Expression>> slotPredicates = new HashMap<>();
Set<Pair<Slot, Slot>> equalPairs = new HashSet<>();
for (Expression predicate : predicates) {
// if we support more infer predicate expression type, we should impl withInferred() method.
// And should add inferred props in withChildren() method just like ComparisonPredicate,
// and it's subclass, to mark the predicate is from infer.
if (!(predicate instanceof ComparisonPredicate
|| (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren()))) {
Set<Slot> inputSlots = predicate.getInputSlots();
if (inputSlots.size() == 1) {
if (predicate instanceof ComparisonPredicate
|| (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren())) {
slotPredicates.computeIfAbsent(inputSlots.iterator().next(), k -> new ArrayList<>()).add(predicate);
}
continue;
}
if (predicate instanceof InPredicate) {
continue;
if (predicate instanceof EqualTo) {
getEqualSlot(equalSetBuilder, equalPairs, (EqualTo) predicate);
}
EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate) predicate);
if (equalInfo.inferType == InferType.NONE) {
continue;
}
Set<Expression> newInferred = predicates.stream()
.filter(p -> !p.equals(predicate))
.filter(p -> p instanceof ComparisonPredicate || p instanceof InPredicate)
.map(predicateInfo -> doInferPredicate(equalInfo, predicateInfo))
.filter(Objects::nonNull)
.collect(Collectors.toSet());
inferred.addAll(newInferred);
}
inferred.removeAll(predicates);
ImmutableEqualSet<Slot> equalSet = equalSetBuilder.build();
Set<Expression> inferred = new HashSet<>();
slotPredicates.forEach((left, exprs) -> {
for (Slot right : equalSet.calEqualSet(left)) {
for (Expression expr : exprs) {
inferred.add(doInferPredicate(left, right, expr));
}
}
});
// infer equal to equal like a = b & b = c -> a = c
// a b c | e f g
// get (a b) (a c) (b c) | (e f) (e g) (f g)
List<Set<Slot>> equalSetList = equalSet.calEqualSetList();
for (Set<Slot> es : equalSetList) {
List<Slot> el = es.stream().sorted(Comparator.comparingInt(s -> s.getExprId().asInt()))
.collect(Collectors.toList());
for (int i = 0; i < el.size(); i++) {
Slot left = el.get(i);
for (int j = i + 1; j < el.size(); j++) {
Slot right = el.get(j);
if (!equalPairs.contains(Pair.of(left, right))) {
inferred.add(TypeCoercionUtils.processComparisonPredicate(new EqualTo(left, right))
.withInferred(true));
}
}
}
}
return inferred;
}
private static Expression doInferPredicate(EqualInferInfo equalInfo, Expression predicate) {
Expression equalLeft = equalInfo.left;
Expression equalRight = equalInfo.right;
private static Expression doInferPredicate(Expression equalLeft, Expression equalRight, Expression predicate) {
DataType leftType = predicate.child(0).getDataType();
InferType inferType;
if (leftType instanceof CharacterType) {
@ -160,47 +165,6 @@ public class PredicatePropagation {
}
}
/**
* Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression`
* Now only support infer `ComparisonPredicate`.
* TODO: We should determine whether `expression` satisfies the condition for replacement
* eg: Satisfy `expression` is non-deterministic
*/
private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo predicateInfo) {
Expression equalLeft = equalInfo.left;
Expression equalRight = equalInfo.right;
Expression predicateLeft = predicateInfo.left;
Expression predicateRight = predicateInfo.right;
Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight);
Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight);
if (newLeft == null || newRight == null) {
return null;
}
ComparisonPredicate newPredicate = (ComparisonPredicate) predicateInfo
.comparisonPredicate.withChildren(newLeft, newRight);
Expression expr = SimplifyComparisonPredicate.INSTANCE
.rewrite(TypeCoercionUtils.processComparisonPredicate(newPredicate), null);
return DateFunctionRewrite.INSTANCE.rewrite(expr, null).withInferred(true);
}
private static Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
if (predicateOneSide instanceof SlotReference) {
if (predicateOneSide.equals(equalLeft)) {
return equalRight;
} else if (predicateOneSide.equals(equalRight)) {
return equalLeft;
}
} else if (predicateOneSide.isConstant()) {
if (predicateOneSide instanceof IntegerLikeLiteral) {
return new NereidsParser().parseExpression(((IntegerLikeLiteral) predicateOneSide).toSql());
} else {
return predicateOneSide;
}
}
return null;
}
private static Optional<Expression> validForInfer(Expression expression, InferType inferType) {
if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
return Optional.empty();
@ -249,7 +213,7 @@ public class PredicatePropagation {
return Optional.empty();
}
private static EqualInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
private static Optional<Pair<Expression, Expression>> inferInferInfo(ComparisonPredicate comparisonPredicate) {
DataType leftType = comparisonPredicate.left().getDataType();
InferType inferType;
if (leftType instanceof CharacterType) {
@ -264,29 +228,21 @@ public class PredicatePropagation {
Optional<Expression> left = validForInfer(comparisonPredicate.left(), inferType);
Optional<Expression> right = validForInfer(comparisonPredicate.right(), inferType);
if (!left.isPresent() || !right.isPresent()) {
inferType = InferType.NONE;
return Optional.empty();
}
return new EqualInferInfo(inferType, left.orElse(comparisonPredicate.left()),
right.orElse(comparisonPredicate.right()), comparisonPredicate);
return Optional.of(Pair.of(left.get(), right.get()));
}
/**
* Currently only equivalence derivation is supported
* and requires that the left and right sides of an expression must be slot
* <p>
* TODO: NullSafeEqual
*/
private static EqualInferInfo getEqualInferInfo(ComparisonPredicate predicate) {
if (!(predicate instanceof EqualTo)) {
return new EqualInferInfo(InferType.NONE, predicate.left(), predicate.right(), predicate);
}
EqualInferInfo info = inferInferInfo(predicate);
if (info.inferType == InferType.NONE) {
return info;
}
if (info.left instanceof SlotReference && info.right instanceof SlotReference) {
return info;
}
return new EqualInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
private static void getEqualSlot(ImmutableEqualSet.Builder<Slot> equalSlots, Set<Pair<Slot, Slot>> equalPairs,
EqualTo predicate) {
inferInferInfo(predicate)
.filter(info -> info.first instanceof Slot && info.second instanceof Slot)
.ifPresent(pair -> {
Slot left = (Slot) pair.first;
Slot right = (Slot) pair.second;
equalSlots.addEqualPair(left, right);
equalPairs.add(left.getExprId().asInt() <= right.getExprId().asInt()
? Pair.of(left, right) : Pair.of(right, left));
});
}
}

View File

@ -38,7 +38,7 @@ import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.ImmutableEquivalenceSet;
import org.apache.doris.nereids.util.ImmutableEqualSet;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.nereids.util.Utils;
@ -467,12 +467,12 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
/**
* get Equal slot from join
*/
public ImmutableEquivalenceSet<Slot> getEqualSlots() {
public ImmutableEqualSet<Slot> getEqualSlots() {
// TODO: Use fd in the future
if (!joinType.isInnerJoin() && !joinType.isSemiJoin()) {
return ImmutableEquivalenceSet.of();
return ImmutableEqualSet.empty();
}
ImmutableEquivalenceSet.Builder<Slot> builder = new ImmutableEquivalenceSet.Builder<>();
ImmutableEqualSet.Builder<Slot> builder = new ImmutableEqualSet.Builder<>();
hashJoinConjuncts.stream()
.filter(e -> e instanceof EqualPredicate
&& e.child(0) instanceof Slot

View File

@ -17,64 +17,73 @@
package org.apache.doris.nereids.util;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* EquivalenceSet
* A class representing an immutable set of elements with equivalence relations.
*/
public class ImmutableEquivalenceSet<T> {
final Map<T, T> root;
public class ImmutableEqualSet<T> {
private final Map<T, T> root;
ImmutableEquivalenceSet(Map<T, T> root) {
ImmutableEqualSet(Map<T, T> root) {
this.root = ImmutableMap.copyOf(root);
}
public static <T> ImmutableEquivalenceSet<T> of() {
return new ImmutableEquivalenceSet<>(ImmutableMap.of());
public static <T> ImmutableEqualSet<T> empty() {
return new ImmutableEqualSet<>(ImmutableMap.of());
}
/**
* Builder of ImmutableEquivalenceSet
* Builder for ImmutableEqualSet.
*/
public static class Builder<T> {
final Map<T, T> parent = new HashMap<>();
private final Map<T, T> parent = new HashMap<>();
private final Map<T, Integer> size = new HashMap<>();
/**
* Add a equal pair
*/
public void addEqualPair(T a, T b) {
parent.computeIfAbsent(b, v -> v);
parent.computeIfAbsent(a, v -> v);
union(a, b);
}
private void union(T a, T b) {
T root1 = findRoot(a);
T root2 = findRoot(b);
if (root1 != root2) {
parent.put(b, root1);
findRoot(b);
// merge by size
if (size.get(root1) < size.get(root2)) {
parent.put(root1, root2);
size.put(root2, size.get(root2) + size.get(root1));
} else {
parent.put(root2, root1);
size.put(root1, size.get(root1) + size.get(root2));
}
}
}
private T findRoot(T a) {
parent.putIfAbsent(a, a); // Ensure that the element is added
size.putIfAbsent(a, 1); // Initialize size to 1
if (!parent.get(a).equals(a)) {
parent.put(a, findRoot(parent.get(a)));
parent.put(a, findRoot(parent.get(a))); // Path compression
}
return parent.get(a);
}
public ImmutableEquivalenceSet<T> build() {
public ImmutableEqualSet<T> build() {
parent.keySet().forEach(this::findRoot);
return new ImmutableEquivalenceSet<>(parent);
return new ImmutableEqualSet<>(parent);
}
}
/**
* cal equal set for a except self
* Calculate equal set for a except self
*/
public Set<T> calEqualSet(T a) {
T ra = root.get(a);
@ -83,6 +92,21 @@ public class ImmutableEquivalenceSet<T> {
.collect(ImmutableSet.toImmutableSet());
}
/**
* Calculate all equal set
*/
public List<Set<T>> calEqualSetList() {
return root.values()
.stream()
.distinct()
.map(a -> {
T ra = root.get(a);
return root.keySet().stream()
.filter(t -> root.get(t).equals(ra))
.collect(ImmutableSet.toImmutableSet());
}).collect(ImmutableList.toImmutableList());
}
public Set<T> getAllItemSet() {
return ImmutableSet.copyOf(root.keySet());
}