From a2da434e3b78b3ba7a823c173a150235ed96c47a Mon Sep 17 00:00:00 2001 From: jakevin Date: Mon, 8 Jan 2024 19:42:12 +0800 Subject: [PATCH] [refactor](Nereids): refactor PredicatePropagation & support to infer Equal Condition (#29644) --- .../rules/rewrite/EliminateJoinByFK.java | 6 +- .../rules/rewrite/PredicatePropagation.java | 178 +++++++----------- .../trees/plans/logical/LogicalJoin.java | 8 +- ...valenceSet.java => ImmutableEqualSet.java} | 64 +++++-- .../rewrite/PredicatePropagationTest.java | 16 ++ .../data/nereids_p0/hint/fix_leading.out | 2 +- 6 files changed, 135 insertions(+), 139 deletions(-) rename fe/fe-core/src/main/java/org/apache/doris/nereids/util/{ImmutableEquivalenceSet.java => ImmutableEqualSet.java} (50%) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java index 594dee5085..b4a6eac207 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java @@ -36,7 +36,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; -import org.apache.doris.nereids.util.ImmutableEquivalenceSet; +import org.apache.doris.nereids.util.ImmutableEqualSet; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -145,7 +145,7 @@ public class EliminateJoinByFK extends DefaultPlanRewriter implement return project; } - private @Nullable Map mapPrimaryToForeign(ImmutableEquivalenceSet equivalenceSet, + private @Nullable Map mapPrimaryToForeign(ImmutableEqualSet equivalenceSet, Set foreignKeys) { ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); for (Slot foreignSlot : foreignKeys) { @@ -164,7 +164,7 @@ public class EliminateJoinByFK extends DefaultPlanRewriter implement // 4. if foreign key is null, add a isNotNull predicate for null-reject join condition private Plan eliminateJoin(LogicalProject> project, ForeignKeyContext context) { LogicalJoin join = project.child(); - ImmutableEquivalenceSet equalSet = join.getEqualSlots(); + ImmutableEqualSet equalSet = join.getEqualSlots(); Set leftSlots = Sets.intersection(join.left().getOutputSet(), equalSet.getAllItemSet()); Set rightSlots = Sets.intersection(join.right().getOutputSet(), equalSet.getAllItemSet()); if (context.isForeignKey(leftSlots) && context.isPrimaryKey(rightSlots)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java index 7788bbb7f0..5d11a1fa54 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java @@ -17,16 +17,14 @@ package org.apache.doris.nereids.rules.rewrite; -import org.apache.doris.nereids.parser.NereidsParser; -import org.apache.doris.nereids.rules.expression.rules.DateFunctionRewrite; -import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DateTimeType; import org.apache.doris.nereids.types.DateTimeV2Type; @@ -35,11 +33,15 @@ import org.apache.doris.nereids.types.DateV2Type; import org.apache.doris.nereids.types.coercion.CharacterType; import org.apache.doris.nereids.types.coercion.DateLikeType; import org.apache.doris.nereids.types.coercion.IntegralType; +import org.apache.doris.nereids.util.ImmutableEqualSet; import org.apache.doris.nereids.util.TypeCoercionUtils; -import com.google.common.collect.Sets; - -import java.util.Objects; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -65,59 +67,62 @@ public class PredicatePropagation { } } - private static class EqualInferInfo { - - public final InferType inferType; - public final Expression left; - public final Expression right; - public final ComparisonPredicate comparisonPredicate; - - public EqualInferInfo(InferType inferType, - Expression left, Expression right, - ComparisonPredicate comparisonPredicate) { - this.inferType = inferType; - this.left = left; - this.right = right; - this.comparisonPredicate = comparisonPredicate; - } - } - /** * infer additional predicates. */ public static Set infer(Set predicates) { - Set inferred = Sets.newHashSet(); + ImmutableEqualSet.Builder equalSetBuilder = new ImmutableEqualSet.Builder<>(); + Map> slotPredicates = new HashMap<>(); + Set> equalPairs = new HashSet<>(); for (Expression predicate : predicates) { - // if we support more infer predicate expression type, we should impl withInferred() method. - // And should add inferred props in withChildren() method just like ComparisonPredicate, - // and it's subclass, to mark the predicate is from infer. - if (!(predicate instanceof ComparisonPredicate - || (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren()))) { + Set inputSlots = predicate.getInputSlots(); + if (inputSlots.size() == 1) { + if (predicate instanceof ComparisonPredicate + || (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren())) { + slotPredicates.computeIfAbsent(inputSlots.iterator().next(), k -> new ArrayList<>()).add(predicate); + } continue; } - if (predicate instanceof InPredicate) { - continue; + + if (predicate instanceof EqualTo) { + getEqualSlot(equalSetBuilder, equalPairs, (EqualTo) predicate); } - EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate) predicate); - if (equalInfo.inferType == InferType.NONE) { - continue; - } - Set newInferred = predicates.stream() - .filter(p -> !p.equals(predicate)) - .filter(p -> p instanceof ComparisonPredicate || p instanceof InPredicate) - .map(predicateInfo -> doInferPredicate(equalInfo, predicateInfo)) - .filter(Objects::nonNull) - .collect(Collectors.toSet()); - inferred.addAll(newInferred); } - inferred.removeAll(predicates); + + ImmutableEqualSet equalSet = equalSetBuilder.build(); + + Set inferred = new HashSet<>(); + slotPredicates.forEach((left, exprs) -> { + for (Slot right : equalSet.calEqualSet(left)) { + for (Expression expr : exprs) { + inferred.add(doInferPredicate(left, right, expr)); + } + } + }); + + // infer equal to equal like a = b & b = c -> a = c + // a b c | e f g + // get (a b) (a c) (b c) | (e f) (e g) (f g) + List> equalSetList = equalSet.calEqualSetList(); + for (Set es : equalSetList) { + List el = es.stream().sorted(Comparator.comparingInt(s -> s.getExprId().asInt())) + .collect(Collectors.toList()); + for (int i = 0; i < el.size(); i++) { + Slot left = el.get(i); + for (int j = i + 1; j < el.size(); j++) { + Slot right = el.get(j); + if (!equalPairs.contains(Pair.of(left, right))) { + inferred.add(TypeCoercionUtils.processComparisonPredicate(new EqualTo(left, right)) + .withInferred(true)); + } + } + } + } + return inferred; } - private static Expression doInferPredicate(EqualInferInfo equalInfo, Expression predicate) { - Expression equalLeft = equalInfo.left; - Expression equalRight = equalInfo.right; - + private static Expression doInferPredicate(Expression equalLeft, Expression equalRight, Expression predicate) { DataType leftType = predicate.child(0).getDataType(); InferType inferType; if (leftType instanceof CharacterType) { @@ -160,47 +165,6 @@ public class PredicatePropagation { } } - /** - * Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression` - * Now only support infer `ComparisonPredicate`. - * TODO: We should determine whether `expression` satisfies the condition for replacement - * eg: Satisfy `expression` is non-deterministic - */ - private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo predicateInfo) { - Expression equalLeft = equalInfo.left; - Expression equalRight = equalInfo.right; - - Expression predicateLeft = predicateInfo.left; - Expression predicateRight = predicateInfo.right; - Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight); - Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight); - if (newLeft == null || newRight == null) { - return null; - } - ComparisonPredicate newPredicate = (ComparisonPredicate) predicateInfo - .comparisonPredicate.withChildren(newLeft, newRight); - Expression expr = SimplifyComparisonPredicate.INSTANCE - .rewrite(TypeCoercionUtils.processComparisonPredicate(newPredicate), null); - return DateFunctionRewrite.INSTANCE.rewrite(expr, null).withInferred(true); - } - - private static Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) { - if (predicateOneSide instanceof SlotReference) { - if (predicateOneSide.equals(equalLeft)) { - return equalRight; - } else if (predicateOneSide.equals(equalRight)) { - return equalLeft; - } - } else if (predicateOneSide.isConstant()) { - if (predicateOneSide instanceof IntegerLikeLiteral) { - return new NereidsParser().parseExpression(((IntegerLikeLiteral) predicateOneSide).toSql()); - } else { - return predicateOneSide; - } - } - return null; - } - private static Optional validForInfer(Expression expression, InferType inferType) { if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) { return Optional.empty(); @@ -249,7 +213,7 @@ public class PredicatePropagation { return Optional.empty(); } - private static EqualInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) { + private static Optional> inferInferInfo(ComparisonPredicate comparisonPredicate) { DataType leftType = comparisonPredicate.left().getDataType(); InferType inferType; if (leftType instanceof CharacterType) { @@ -264,29 +228,21 @@ public class PredicatePropagation { Optional left = validForInfer(comparisonPredicate.left(), inferType); Optional right = validForInfer(comparisonPredicate.right(), inferType); if (!left.isPresent() || !right.isPresent()) { - inferType = InferType.NONE; + return Optional.empty(); } - return new EqualInferInfo(inferType, left.orElse(comparisonPredicate.left()), - right.orElse(comparisonPredicate.right()), comparisonPredicate); + return Optional.of(Pair.of(left.get(), right.get())); } - /** - * Currently only equivalence derivation is supported - * and requires that the left and right sides of an expression must be slot - *

- * TODO: NullSafeEqual - */ - private static EqualInferInfo getEqualInferInfo(ComparisonPredicate predicate) { - if (!(predicate instanceof EqualTo)) { - return new EqualInferInfo(InferType.NONE, predicate.left(), predicate.right(), predicate); - } - EqualInferInfo info = inferInferInfo(predicate); - if (info.inferType == InferType.NONE) { - return info; - } - if (info.left instanceof SlotReference && info.right instanceof SlotReference) { - return info; - } - return new EqualInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate); + private static void getEqualSlot(ImmutableEqualSet.Builder equalSlots, Set> equalPairs, + EqualTo predicate) { + inferInferInfo(predicate) + .filter(info -> info.first instanceof Slot && info.second instanceof Slot) + .ifPresent(pair -> { + Slot left = (Slot) pair.first; + Slot right = (Slot) pair.second; + equalSlots.addEqualPair(left, right); + equalPairs.add(left.getExprId().asInt() <= right.getExprId().asInt() + ? Pair.of(left, right) : Pair.of(right, left)); + }); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java index 76fda759ef..6c78193abf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java @@ -38,7 +38,7 @@ import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.algebra.Join; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.nereids.util.ExpressionUtils; -import org.apache.doris.nereids.util.ImmutableEquivalenceSet; +import org.apache.doris.nereids.util.ImmutableEqualSet; import org.apache.doris.nereids.util.JoinUtils; import org.apache.doris.nereids.util.Utils; @@ -467,12 +467,12 @@ public class LogicalJoin getEqualSlots() { + public ImmutableEqualSet getEqualSlots() { // TODO: Use fd in the future if (!joinType.isInnerJoin() && !joinType.isSemiJoin()) { - return ImmutableEquivalenceSet.of(); + return ImmutableEqualSet.empty(); } - ImmutableEquivalenceSet.Builder builder = new ImmutableEquivalenceSet.Builder<>(); + ImmutableEqualSet.Builder builder = new ImmutableEqualSet.Builder<>(); hashJoinConjuncts.stream() .filter(e -> e instanceof EqualPredicate && e.child(0) instanceof Slot diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEquivalenceSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEqualSet.java similarity index 50% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEquivalenceSet.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEqualSet.java index 66d20597fc..724414e2e1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEquivalenceSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEqualSet.java @@ -17,64 +17,73 @@ package org.apache.doris.nereids.util; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; /** - * EquivalenceSet + * A class representing an immutable set of elements with equivalence relations. */ -public class ImmutableEquivalenceSet { - final Map root; +public class ImmutableEqualSet { + private final Map root; - ImmutableEquivalenceSet(Map root) { + ImmutableEqualSet(Map root) { this.root = ImmutableMap.copyOf(root); } - public static ImmutableEquivalenceSet of() { - return new ImmutableEquivalenceSet<>(ImmutableMap.of()); + public static ImmutableEqualSet empty() { + return new ImmutableEqualSet<>(ImmutableMap.of()); } /** - * Builder of ImmutableEquivalenceSet + * Builder for ImmutableEqualSet. */ public static class Builder { - final Map parent = new HashMap<>(); + private final Map parent = new HashMap<>(); + private final Map size = new HashMap<>(); + /** + * Add a equal pair + */ public void addEqualPair(T a, T b) { - parent.computeIfAbsent(b, v -> v); - parent.computeIfAbsent(a, v -> v); - union(a, b); - } - - private void union(T a, T b) { T root1 = findRoot(a); T root2 = findRoot(b); if (root1 != root2) { - parent.put(b, root1); - findRoot(b); + // merge by size + if (size.get(root1) < size.get(root2)) { + parent.put(root1, root2); + size.put(root2, size.get(root2) + size.get(root1)); + } else { + parent.put(root2, root1); + size.put(root1, size.get(root1) + size.get(root2)); + } } } private T findRoot(T a) { + parent.putIfAbsent(a, a); // Ensure that the element is added + size.putIfAbsent(a, 1); // Initialize size to 1 + if (!parent.get(a).equals(a)) { - parent.put(a, findRoot(parent.get(a))); + parent.put(a, findRoot(parent.get(a))); // Path compression } return parent.get(a); } - public ImmutableEquivalenceSet build() { + public ImmutableEqualSet build() { parent.keySet().forEach(this::findRoot); - return new ImmutableEquivalenceSet<>(parent); + return new ImmutableEqualSet<>(parent); } } /** - * cal equal set for a except self + * Calculate equal set for a except self */ public Set calEqualSet(T a) { T ra = root.get(a); @@ -83,6 +92,21 @@ public class ImmutableEquivalenceSet { .collect(ImmutableSet.toImmutableSet()); } + /** + * Calculate all equal set + */ + public List> calEqualSetList() { + return root.values() + .stream() + .distinct() + .map(a -> { + T ra = root.get(a); + return root.keySet().stream() + .filter(t -> root.get(t).equals(ra)) + .collect(ImmutableSet.toImmutableSet()); + }).collect(ImmutableList.toImmutableList()); + } + public Set getAllItemSet() { return ImmutableSet.copyOf(root.keySet()); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java index b1aa25df1b..1efa94451a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.Literal; @@ -34,6 +35,7 @@ import java.util.Set; class PredicatePropagationTest { private final SlotReference a = new SlotReference("a", SmallIntType.INSTANCE); private final SlotReference b = new SlotReference("b", BigIntType.INSTANCE); + private final SlotReference c = new SlotReference("c", BigIntType.INSTANCE); @Test void equal() { @@ -48,4 +50,18 @@ class PredicatePropagationTest { Set inferExprs = PredicatePropagation.infer(exprs); System.out.println(inferExprs); } + + @Test + void inferSlotEqual() { + Set exprs = ImmutableSet.of(new EqualTo(a, b), new EqualTo(a, c)); + Set inferExprs = PredicatePropagation.infer(exprs); + System.out.println(inferExprs); + } + + @Test + void inferComplex0() { + Set exprs = ImmutableSet.of(new EqualTo(a, b), new EqualTo(a, c), new GreaterThan(a, Literal.of(1))); + Set inferExprs = PredicatePropagation.infer(exprs); + System.out.println(inferExprs); + } } diff --git a/regression-test/data/nereids_p0/hint/fix_leading.out b/regression-test/data/nereids_p0/hint/fix_leading.out index 58122945bb..898fe5882b 100644 --- a/regression-test/data/nereids_p0/hint/fix_leading.out +++ b/regression-test/data/nereids_p0/hint/fix_leading.out @@ -9,7 +9,7 @@ PhysicalResultSink ----------PhysicalDistribute[DistributionSpecHash] ------------PhysicalOlapScan[t2] --------PhysicalDistribute[DistributionSpecHash] -----------NestedLoopJoin[CROSS_JOIN](t4.c4 = t3.c3)(t3.c3 = t4.c4) +----------NestedLoopJoin[CROSS_JOIN](t3.c3 = t4.c4) ------------PhysicalOlapScan[t3] ------------PhysicalDistribute[DistributionSpecReplicated] --------------PhysicalOlapScan[t4]