[refactor](Nereids): refactor PredicatePropagation & support to infer Equal Condition (#29644)
This commit is contained in:
@ -36,7 +36,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
|
||||
import org.apache.doris.nereids.util.ImmutableEquivalenceSet;
|
||||
import org.apache.doris.nereids.util.ImmutableEqualSet;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
@ -145,7 +145,7 @@ public class EliminateJoinByFK extends DefaultPlanRewriter<JobContext> implement
|
||||
return project;
|
||||
}
|
||||
|
||||
private @Nullable Map<Slot, Slot> mapPrimaryToForeign(ImmutableEquivalenceSet<Slot> equivalenceSet,
|
||||
private @Nullable Map<Slot, Slot> mapPrimaryToForeign(ImmutableEqualSet<Slot> equivalenceSet,
|
||||
Set<Slot> foreignKeys) {
|
||||
ImmutableMap.Builder<Slot, Slot> builder = new ImmutableMap.Builder<>();
|
||||
for (Slot foreignSlot : foreignKeys) {
|
||||
@ -164,7 +164,7 @@ public class EliminateJoinByFK extends DefaultPlanRewriter<JobContext> implement
|
||||
// 4. if foreign key is null, add a isNotNull predicate for null-reject join condition
|
||||
private Plan eliminateJoin(LogicalProject<LogicalJoin<?, ?>> project, ForeignKeyContext context) {
|
||||
LogicalJoin<?, ?> join = project.child();
|
||||
ImmutableEquivalenceSet<Slot> equalSet = join.getEqualSlots();
|
||||
ImmutableEqualSet<Slot> equalSet = join.getEqualSlots();
|
||||
Set<Slot> leftSlots = Sets.intersection(join.left().getOutputSet(), equalSet.getAllItemSet());
|
||||
Set<Slot> rightSlots = Sets.intersection(join.right().getOutputSet(), equalSet.getAllItemSet());
|
||||
if (context.isForeignKey(leftSlots) && context.isPrimaryKey(rightSlots)) {
|
||||
|
||||
@ -17,16 +17,14 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.parser.NereidsParser;
|
||||
import org.apache.doris.nereids.rules.expression.rules.DateFunctionRewrite;
|
||||
import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.InPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.DateTimeType;
|
||||
import org.apache.doris.nereids.types.DateTimeV2Type;
|
||||
@ -35,11 +33,15 @@ import org.apache.doris.nereids.types.DateV2Type;
|
||||
import org.apache.doris.nereids.types.coercion.CharacterType;
|
||||
import org.apache.doris.nereids.types.coercion.DateLikeType;
|
||||
import org.apache.doris.nereids.types.coercion.IntegralType;
|
||||
import org.apache.doris.nereids.util.ImmutableEqualSet;
|
||||
import org.apache.doris.nereids.util.TypeCoercionUtils;
|
||||
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.Objects;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
@ -65,59 +67,62 @@ public class PredicatePropagation {
|
||||
}
|
||||
}
|
||||
|
||||
private static class EqualInferInfo {
|
||||
|
||||
public final InferType inferType;
|
||||
public final Expression left;
|
||||
public final Expression right;
|
||||
public final ComparisonPredicate comparisonPredicate;
|
||||
|
||||
public EqualInferInfo(InferType inferType,
|
||||
Expression left, Expression right,
|
||||
ComparisonPredicate comparisonPredicate) {
|
||||
this.inferType = inferType;
|
||||
this.left = left;
|
||||
this.right = right;
|
||||
this.comparisonPredicate = comparisonPredicate;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* infer additional predicates.
|
||||
*/
|
||||
public static Set<Expression> infer(Set<Expression> predicates) {
|
||||
Set<Expression> inferred = Sets.newHashSet();
|
||||
ImmutableEqualSet.Builder<Slot> equalSetBuilder = new ImmutableEqualSet.Builder<>();
|
||||
Map<Slot, List<Expression>> slotPredicates = new HashMap<>();
|
||||
Set<Pair<Slot, Slot>> equalPairs = new HashSet<>();
|
||||
for (Expression predicate : predicates) {
|
||||
// if we support more infer predicate expression type, we should impl withInferred() method.
|
||||
// And should add inferred props in withChildren() method just like ComparisonPredicate,
|
||||
// and it's subclass, to mark the predicate is from infer.
|
||||
if (!(predicate instanceof ComparisonPredicate
|
||||
|| (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren()))) {
|
||||
Set<Slot> inputSlots = predicate.getInputSlots();
|
||||
if (inputSlots.size() == 1) {
|
||||
if (predicate instanceof ComparisonPredicate
|
||||
|| (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren())) {
|
||||
slotPredicates.computeIfAbsent(inputSlots.iterator().next(), k -> new ArrayList<>()).add(predicate);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (predicate instanceof InPredicate) {
|
||||
continue;
|
||||
|
||||
if (predicate instanceof EqualTo) {
|
||||
getEqualSlot(equalSetBuilder, equalPairs, (EqualTo) predicate);
|
||||
}
|
||||
EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate) predicate);
|
||||
if (equalInfo.inferType == InferType.NONE) {
|
||||
continue;
|
||||
}
|
||||
Set<Expression> newInferred = predicates.stream()
|
||||
.filter(p -> !p.equals(predicate))
|
||||
.filter(p -> p instanceof ComparisonPredicate || p instanceof InPredicate)
|
||||
.map(predicateInfo -> doInferPredicate(equalInfo, predicateInfo))
|
||||
.filter(Objects::nonNull)
|
||||
.collect(Collectors.toSet());
|
||||
inferred.addAll(newInferred);
|
||||
}
|
||||
inferred.removeAll(predicates);
|
||||
|
||||
ImmutableEqualSet<Slot> equalSet = equalSetBuilder.build();
|
||||
|
||||
Set<Expression> inferred = new HashSet<>();
|
||||
slotPredicates.forEach((left, exprs) -> {
|
||||
for (Slot right : equalSet.calEqualSet(left)) {
|
||||
for (Expression expr : exprs) {
|
||||
inferred.add(doInferPredicate(left, right, expr));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// infer equal to equal like a = b & b = c -> a = c
|
||||
// a b c | e f g
|
||||
// get (a b) (a c) (b c) | (e f) (e g) (f g)
|
||||
List<Set<Slot>> equalSetList = equalSet.calEqualSetList();
|
||||
for (Set<Slot> es : equalSetList) {
|
||||
List<Slot> el = es.stream().sorted(Comparator.comparingInt(s -> s.getExprId().asInt()))
|
||||
.collect(Collectors.toList());
|
||||
for (int i = 0; i < el.size(); i++) {
|
||||
Slot left = el.get(i);
|
||||
for (int j = i + 1; j < el.size(); j++) {
|
||||
Slot right = el.get(j);
|
||||
if (!equalPairs.contains(Pair.of(left, right))) {
|
||||
inferred.add(TypeCoercionUtils.processComparisonPredicate(new EqualTo(left, right))
|
||||
.withInferred(true));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return inferred;
|
||||
}
|
||||
|
||||
private static Expression doInferPredicate(EqualInferInfo equalInfo, Expression predicate) {
|
||||
Expression equalLeft = equalInfo.left;
|
||||
Expression equalRight = equalInfo.right;
|
||||
|
||||
private static Expression doInferPredicate(Expression equalLeft, Expression equalRight, Expression predicate) {
|
||||
DataType leftType = predicate.child(0).getDataType();
|
||||
InferType inferType;
|
||||
if (leftType instanceof CharacterType) {
|
||||
@ -160,47 +165,6 @@ public class PredicatePropagation {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression`
|
||||
* Now only support infer `ComparisonPredicate`.
|
||||
* TODO: We should determine whether `expression` satisfies the condition for replacement
|
||||
* eg: Satisfy `expression` is non-deterministic
|
||||
*/
|
||||
private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo predicateInfo) {
|
||||
Expression equalLeft = equalInfo.left;
|
||||
Expression equalRight = equalInfo.right;
|
||||
|
||||
Expression predicateLeft = predicateInfo.left;
|
||||
Expression predicateRight = predicateInfo.right;
|
||||
Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight);
|
||||
Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight);
|
||||
if (newLeft == null || newRight == null) {
|
||||
return null;
|
||||
}
|
||||
ComparisonPredicate newPredicate = (ComparisonPredicate) predicateInfo
|
||||
.comparisonPredicate.withChildren(newLeft, newRight);
|
||||
Expression expr = SimplifyComparisonPredicate.INSTANCE
|
||||
.rewrite(TypeCoercionUtils.processComparisonPredicate(newPredicate), null);
|
||||
return DateFunctionRewrite.INSTANCE.rewrite(expr, null).withInferred(true);
|
||||
}
|
||||
|
||||
private static Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
|
||||
if (predicateOneSide instanceof SlotReference) {
|
||||
if (predicateOneSide.equals(equalLeft)) {
|
||||
return equalRight;
|
||||
} else if (predicateOneSide.equals(equalRight)) {
|
||||
return equalLeft;
|
||||
}
|
||||
} else if (predicateOneSide.isConstant()) {
|
||||
if (predicateOneSide instanceof IntegerLikeLiteral) {
|
||||
return new NereidsParser().parseExpression(((IntegerLikeLiteral) predicateOneSide).toSql());
|
||||
} else {
|
||||
return predicateOneSide;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static Optional<Expression> validForInfer(Expression expression, InferType inferType) {
|
||||
if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
|
||||
return Optional.empty();
|
||||
@ -249,7 +213,7 @@ public class PredicatePropagation {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
private static EqualInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
|
||||
private static Optional<Pair<Expression, Expression>> inferInferInfo(ComparisonPredicate comparisonPredicate) {
|
||||
DataType leftType = comparisonPredicate.left().getDataType();
|
||||
InferType inferType;
|
||||
if (leftType instanceof CharacterType) {
|
||||
@ -264,29 +228,21 @@ public class PredicatePropagation {
|
||||
Optional<Expression> left = validForInfer(comparisonPredicate.left(), inferType);
|
||||
Optional<Expression> right = validForInfer(comparisonPredicate.right(), inferType);
|
||||
if (!left.isPresent() || !right.isPresent()) {
|
||||
inferType = InferType.NONE;
|
||||
return Optional.empty();
|
||||
}
|
||||
return new EqualInferInfo(inferType, left.orElse(comparisonPredicate.left()),
|
||||
right.orElse(comparisonPredicate.right()), comparisonPredicate);
|
||||
return Optional.of(Pair.of(left.get(), right.get()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Currently only equivalence derivation is supported
|
||||
* and requires that the left and right sides of an expression must be slot
|
||||
* <p>
|
||||
* TODO: NullSafeEqual
|
||||
*/
|
||||
private static EqualInferInfo getEqualInferInfo(ComparisonPredicate predicate) {
|
||||
if (!(predicate instanceof EqualTo)) {
|
||||
return new EqualInferInfo(InferType.NONE, predicate.left(), predicate.right(), predicate);
|
||||
}
|
||||
EqualInferInfo info = inferInferInfo(predicate);
|
||||
if (info.inferType == InferType.NONE) {
|
||||
return info;
|
||||
}
|
||||
if (info.left instanceof SlotReference && info.right instanceof SlotReference) {
|
||||
return info;
|
||||
}
|
||||
return new EqualInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
|
||||
private static void getEqualSlot(ImmutableEqualSet.Builder<Slot> equalSlots, Set<Pair<Slot, Slot>> equalPairs,
|
||||
EqualTo predicate) {
|
||||
inferInferInfo(predicate)
|
||||
.filter(info -> info.first instanceof Slot && info.second instanceof Slot)
|
||||
.ifPresent(pair -> {
|
||||
Slot left = (Slot) pair.first;
|
||||
Slot right = (Slot) pair.second;
|
||||
equalSlots.addEqualPair(left, right);
|
||||
equalPairs.add(left.getExprId().asInt() <= right.getExprId().asInt()
|
||||
? Pair.of(left, right) : Pair.of(right, left));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -38,7 +38,7 @@ import org.apache.doris.nereids.trees.plans.PlanType;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.Join;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.ImmutableEquivalenceSet;
|
||||
import org.apache.doris.nereids.util.ImmutableEqualSet;
|
||||
import org.apache.doris.nereids.util.JoinUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
@ -467,12 +467,12 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
|
||||
/**
|
||||
* get Equal slot from join
|
||||
*/
|
||||
public ImmutableEquivalenceSet<Slot> getEqualSlots() {
|
||||
public ImmutableEqualSet<Slot> getEqualSlots() {
|
||||
// TODO: Use fd in the future
|
||||
if (!joinType.isInnerJoin() && !joinType.isSemiJoin()) {
|
||||
return ImmutableEquivalenceSet.of();
|
||||
return ImmutableEqualSet.empty();
|
||||
}
|
||||
ImmutableEquivalenceSet.Builder<Slot> builder = new ImmutableEquivalenceSet.Builder<>();
|
||||
ImmutableEqualSet.Builder<Slot> builder = new ImmutableEqualSet.Builder<>();
|
||||
hashJoinConjuncts.stream()
|
||||
.filter(e -> e instanceof EqualPredicate
|
||||
&& e.child(0) instanceof Slot
|
||||
|
||||
@ -17,64 +17,73 @@
|
||||
|
||||
package org.apache.doris.nereids.util;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* EquivalenceSet
|
||||
* A class representing an immutable set of elements with equivalence relations.
|
||||
*/
|
||||
public class ImmutableEquivalenceSet<T> {
|
||||
final Map<T, T> root;
|
||||
public class ImmutableEqualSet<T> {
|
||||
private final Map<T, T> root;
|
||||
|
||||
ImmutableEquivalenceSet(Map<T, T> root) {
|
||||
ImmutableEqualSet(Map<T, T> root) {
|
||||
this.root = ImmutableMap.copyOf(root);
|
||||
}
|
||||
|
||||
public static <T> ImmutableEquivalenceSet<T> of() {
|
||||
return new ImmutableEquivalenceSet<>(ImmutableMap.of());
|
||||
public static <T> ImmutableEqualSet<T> empty() {
|
||||
return new ImmutableEqualSet<>(ImmutableMap.of());
|
||||
}
|
||||
|
||||
/**
|
||||
* Builder of ImmutableEquivalenceSet
|
||||
* Builder for ImmutableEqualSet.
|
||||
*/
|
||||
public static class Builder<T> {
|
||||
final Map<T, T> parent = new HashMap<>();
|
||||
private final Map<T, T> parent = new HashMap<>();
|
||||
private final Map<T, Integer> size = new HashMap<>();
|
||||
|
||||
/**
|
||||
* Add a equal pair
|
||||
*/
|
||||
public void addEqualPair(T a, T b) {
|
||||
parent.computeIfAbsent(b, v -> v);
|
||||
parent.computeIfAbsent(a, v -> v);
|
||||
union(a, b);
|
||||
}
|
||||
|
||||
private void union(T a, T b) {
|
||||
T root1 = findRoot(a);
|
||||
T root2 = findRoot(b);
|
||||
|
||||
if (root1 != root2) {
|
||||
parent.put(b, root1);
|
||||
findRoot(b);
|
||||
// merge by size
|
||||
if (size.get(root1) < size.get(root2)) {
|
||||
parent.put(root1, root2);
|
||||
size.put(root2, size.get(root2) + size.get(root1));
|
||||
} else {
|
||||
parent.put(root2, root1);
|
||||
size.put(root1, size.get(root1) + size.get(root2));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private T findRoot(T a) {
|
||||
parent.putIfAbsent(a, a); // Ensure that the element is added
|
||||
size.putIfAbsent(a, 1); // Initialize size to 1
|
||||
|
||||
if (!parent.get(a).equals(a)) {
|
||||
parent.put(a, findRoot(parent.get(a)));
|
||||
parent.put(a, findRoot(parent.get(a))); // Path compression
|
||||
}
|
||||
return parent.get(a);
|
||||
}
|
||||
|
||||
public ImmutableEquivalenceSet<T> build() {
|
||||
public ImmutableEqualSet<T> build() {
|
||||
parent.keySet().forEach(this::findRoot);
|
||||
return new ImmutableEquivalenceSet<>(parent);
|
||||
return new ImmutableEqualSet<>(parent);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* cal equal set for a except self
|
||||
* Calculate equal set for a except self
|
||||
*/
|
||||
public Set<T> calEqualSet(T a) {
|
||||
T ra = root.get(a);
|
||||
@ -83,6 +92,21 @@ public class ImmutableEquivalenceSet<T> {
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate all equal set
|
||||
*/
|
||||
public List<Set<T>> calEqualSetList() {
|
||||
return root.values()
|
||||
.stream()
|
||||
.distinct()
|
||||
.map(a -> {
|
||||
T ra = root.get(a);
|
||||
return root.keySet().stream()
|
||||
.filter(t -> root.get(t).equals(ra))
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
}).collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
|
||||
public Set<T> getAllItemSet() {
|
||||
return ImmutableSet.copyOf(root.keySet());
|
||||
}
|
||||
Reference in New Issue
Block a user