diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java index 3c4593df54..36236c3db8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java @@ -27,7 +27,6 @@ import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; @@ -37,6 +36,7 @@ import java.util.stream.Collectors; /** * infer additional predicates for `LogicalFilter` and `LogicalJoin`. + *
  * The logic is as follows:
  * 1. poll up bottom predicate then infer additional predicates
  *   for example:
@@ -49,9 +49,9 @@ import java.util.stream.Collectors;
  *      select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t2.id = 1
  * 2. put these predicates into `otherJoinConjuncts` , these predicates are processed in the next
  *   round of predicate push-down
+ * 
*/ public class InferPredicates extends DefaultPlanRewriter implements CustomRewriter { - private final PredicatePropagation propagation = new PredicatePropagation(); private final PullUpPredicates pollUpPredicates = new PullUpPredicates(); @Override @@ -62,6 +62,9 @@ public class InferPredicates extends DefaultPlanRewriter implements @Override public Plan visitLogicalJoin(LogicalJoin join, JobContext context) { join = visitChildren(this, join, context); + if (join.isMarkJoin()) { + return join; + } Plan left = join.left(); Plan right = join.right(); Set expressions = getAllExpressions(left, right, join.getOnClauseCondition()); @@ -86,7 +89,7 @@ public class InferPredicates extends DefaultPlanRewriter implements break; } if (left != join.left() || right != join.right()) { - return join.withChildren(ImmutableList.of(left, right)); + return join.withChildren(left, right); } else { return join; } @@ -109,7 +112,7 @@ public class InferPredicates extends DefaultPlanRewriter implements Set baseExpressions = pullUpPredicates(left); baseExpressions.addAll(pullUpPredicates(right)); condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on))); - baseExpressions.addAll(propagation.infer(baseExpressions)); + baseExpressions.addAll(PredicatePropagation.infer(baseExpressions)); return baseExpressions; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java index 72e9023dc4..7788bbb7f0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.types.DataType; @@ -55,8 +56,7 @@ public class PredicatePropagation { INTEGRAL(IntegralType.class), STRING(CharacterType.class), DATE(DateLikeType.class), - OTHER(DataType.class) - ; + OTHER(DataType.class); private final Class superClazz; @@ -65,15 +65,15 @@ public class PredicatePropagation { } } - private class ComparisonInferInfo { + private static class EqualInferInfo { public final InferType inferType; - public final Optional left; - public final Optional right; + public final Expression left; + public final Expression right; public final ComparisonPredicate comparisonPredicate; - public ComparisonInferInfo(InferType inferType, - Optional left, Optional right, + public EqualInferInfo(InferType inferType, + Expression left, Expression right, ComparisonPredicate comparisonPredicate) { this.inferType = inferType; this.left = left; @@ -85,26 +85,27 @@ public class PredicatePropagation { /** * infer additional predicates. */ - public Set infer(Set predicates) { + public static Set infer(Set predicates) { Set inferred = Sets.newHashSet(); for (Expression predicate : predicates) { // if we support more infer predicate expression type, we should impl withInferred() method. // And should add inferred props in withChildren() method just like ComparisonPredicate, // and it's subclass, to mark the predicate is from infer. - if (!(predicate instanceof ComparisonPredicate)) { + if (!(predicate instanceof ComparisonPredicate + || (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren()))) { continue; } - ComparisonInferInfo equalInfo = getEquivalentInferInfo((ComparisonPredicate) predicate); + if (predicate instanceof InPredicate) { + continue; + } + EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate) predicate); if (equalInfo.inferType == InferType.NONE) { continue; } Set newInferred = predicates.stream() - .filter(ComparisonPredicate.class::isInstance) .filter(p -> !p.equals(predicate)) - .map(ComparisonPredicate.class::cast) - .map(this::inferInferInfo) - .filter(predicateInfo -> predicateInfo.inferType != InferType.NONE) - .map(predicateInfo -> doInfer(equalInfo, predicateInfo)) + .filter(p -> p instanceof ComparisonPredicate || p instanceof InPredicate) + .map(predicateInfo -> doInferPredicate(equalInfo, predicateInfo)) .filter(Objects::nonNull) .collect(Collectors.toSet()); inferred.addAll(newInferred); @@ -113,17 +114,64 @@ public class PredicatePropagation { return inferred; } + private static Expression doInferPredicate(EqualInferInfo equalInfo, Expression predicate) { + Expression equalLeft = equalInfo.left; + Expression equalRight = equalInfo.right; + + DataType leftType = predicate.child(0).getDataType(); + InferType inferType; + if (leftType instanceof CharacterType) { + inferType = InferType.STRING; + } else if (leftType instanceof IntegralType) { + inferType = InferType.INTEGRAL; + } else if (leftType instanceof DateLikeType) { + inferType = InferType.DATE; + } else { + inferType = InferType.OTHER; + } + if (predicate instanceof ComparisonPredicate) { + ComparisonPredicate comparisonPredicate = (ComparisonPredicate) predicate; + Optional left = validForInfer(comparisonPredicate.left(), inferType); + Optional right = validForInfer(comparisonPredicate.right(), inferType); + if (!left.isPresent() || !right.isPresent()) { + return null; + } + } else if (predicate instanceof InPredicate) { + InPredicate inPredicate = (InPredicate) predicate; + Optional left = validForInfer(inPredicate.getCompareExpr(), inferType); + if (!left.isPresent()) { + return null; + } + } + + Expression newPredicate = predicate.rewriteUp(e -> { + if (e.equals(equalLeft)) { + return equalRight; + } else if (e.equals(equalRight)) { + return equalLeft; + } else { + return e; + } + }); + if (predicate instanceof ComparisonPredicate) { + return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) newPredicate).withInferred(true); + } else { + return TypeCoercionUtils.processInPredicate((InPredicate) newPredicate).withInferred(true); + } + } + /** * Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression` * Now only support infer `ComparisonPredicate`. * TODO: We should determine whether `expression` satisfies the condition for replacement * eg: Satisfy `expression` is non-deterministic */ - private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo predicateInfo) { - Expression predicateLeft = predicateInfo.left.get(); - Expression predicateRight = predicateInfo.right.get(); - Expression equalLeft = equalInfo.left.get(); - Expression equalRight = equalInfo.right.get(); + private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo predicateInfo) { + Expression equalLeft = equalInfo.left; + Expression equalRight = equalInfo.right; + + Expression predicateLeft = predicateInfo.left; + Expression predicateRight = predicateInfo.right; Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight); Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight); if (newLeft == null || newRight == null) { @@ -136,7 +184,7 @@ public class PredicatePropagation { return DateFunctionRewrite.INSTANCE.rewrite(expr, null).withInferred(true); } - private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) { + private static Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) { if (predicateOneSide instanceof SlotReference) { if (predicateOneSide.equals(equalLeft)) { return equalRight; @@ -153,60 +201,55 @@ public class PredicatePropagation { return null; } - private Optional validForInfer(Expression expression, InferType inferType) { + private static Optional validForInfer(Expression expression, InferType inferType) { if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) { return Optional.empty(); } if (expression instanceof SlotReference || expression.isConstant()) { return Optional.of(expression); } + if (!(expression instanceof Cast)) { + return Optional.empty(); + } + Cast cast = (Cast) expression; + Expression child = cast.child(); + DataType dataType = cast.getDataType(); + DataType childType = child.getDataType(); if (inferType == InferType.INTEGRAL) { - if (expression instanceof Cast) { - // avoid cast from wider type to narrower type, such as cast(int as smallint) - // IntegralType dataType = (IntegralType) expression.getDataType(); - // DataType childType = ((Cast) expression).child().getDataType(); - // if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) { - // return validForInfer(((Cast) expression).child(), inferType); - // } - return validForInfer(((Cast) expression).child(), inferType); - } + // avoid cast from wider type to narrower type, such as cast(int as smallint) + // IntegralType dataType = (IntegralType) expression.getDataType(); + // DataType childType = ((Cast) expression).child().getDataType(); + // if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) { + // return validForInfer(((Cast) expression).child(), inferType); + // } + return validForInfer(child, inferType); } else if (inferType == InferType.DATE) { - if (expression instanceof Cast) { - DataType dataType = expression.getDataType(); - DataType childType = ((Cast) expression).child().getDataType(); - // avoid lost precision - if (dataType instanceof DateType) { - if (childType instanceof DateV2Type || childType instanceof DateType) { - return validForInfer(((Cast) expression).child(), inferType); - } - } else if (dataType instanceof DateV2Type) { - if (childType instanceof DateType || childType instanceof DateV2Type) { - return validForInfer(((Cast) expression).child(), inferType); - } - } else if (dataType instanceof DateTimeType) { - if (!(childType instanceof DateTimeV2Type)) { - return validForInfer(((Cast) expression).child(), inferType); - } - } else if (dataType instanceof DateTimeV2Type) { - return validForInfer(((Cast) expression).child(), inferType); + // avoid lost precision + if (dataType instanceof DateType) { + if (childType instanceof DateV2Type || childType instanceof DateType) { + return validForInfer(child, inferType); } + } else if (dataType instanceof DateV2Type) { + if (childType instanceof DateType || childType instanceof DateV2Type) { + return validForInfer(child, inferType); + } + } else if (dataType instanceof DateTimeType) { + if (!(childType instanceof DateTimeV2Type)) { + return validForInfer(child, inferType); + } + } else if (dataType instanceof DateTimeV2Type) { + return validForInfer(child, inferType); } } else if (inferType == InferType.STRING) { - if (expression instanceof Cast) { - DataType dataType = expression.getDataType(); - DataType childType = ((Cast) expression).child().getDataType(); - // avoid substring cast such as cast(char(3) as char(2)) - if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) { - return validForInfer(((Cast) expression).child(), inferType); - } + // avoid substring cast such as cast(char(3) as char(2)) + if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) { + return validForInfer(child, inferType); } - } else { - return Optional.empty(); } return Optional.empty(); } - private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) { + private static EqualInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) { DataType leftType = comparisonPredicate.left().getDataType(); InferType inferType; if (leftType instanceof CharacterType) { @@ -223,25 +266,27 @@ public class PredicatePropagation { if (!left.isPresent() || !right.isPresent()) { inferType = InferType.NONE; } - return new ComparisonInferInfo(inferType, left, right, comparisonPredicate); + return new EqualInferInfo(inferType, left.orElse(comparisonPredicate.left()), + right.orElse(comparisonPredicate.right()), comparisonPredicate); } /** * Currently only equivalence derivation is supported * and requires that the left and right sides of an expression must be slot + *

+ * TODO: NullSafeEqual */ - private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate predicate) { + private static EqualInferInfo getEqualInferInfo(ComparisonPredicate predicate) { if (!(predicate instanceof EqualTo)) { - return new ComparisonInferInfo(InferType.NONE, - Optional.of(predicate.left()), Optional.of(predicate.right()), predicate); + return new EqualInferInfo(InferType.NONE, predicate.left(), predicate.right(), predicate); } - ComparisonInferInfo info = inferInferInfo(predicate); + EqualInferInfo info = inferInferInfo(predicate); if (info.inferType == InferType.NONE) { return info; } - if (info.left.get() instanceof SlotReference && info.right.get() instanceof SlotReference) { + if (info.left instanceof SlotReference && info.right instanceof SlotReference) { return info; } - return new ComparisonInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate); + return new EqualInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java index 1a198c76ea..26e1358c2e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java @@ -47,7 +47,6 @@ import java.util.stream.Collectors; */ public class PullUpPredicates extends PlanVisitor, Void> { - PredicatePropagation propagation = new PredicatePropagation(); Map> cache = new IdentityHashMap<>(); @Override @@ -99,6 +98,7 @@ public class PullUpPredicates extends PlanVisitor, Void public ImmutableSet visitLogicalAggregate(LogicalAggregate aggregate, Void context) { return cacheOrElse(aggregate, () -> { ImmutableSet childPredicates = aggregate.child().accept(this, context); + // TODO Map expressionSlotMap = aggregate.getOutputExpressions() .stream() .filter(this::hasAgg) @@ -130,7 +130,7 @@ public class PullUpPredicates extends PlanVisitor, Void private ImmutableSet getAvailableExpressions(Collection predicates, Plan plan) { Set expressions = Sets.newHashSet(predicates); - expressions.addAll(propagation.infer(expressions)); + expressions.addAll(PredicatePropagation.infer(expressions)); return expressions.stream() .filter(p -> plan.getOutputSet().containsAll(p.getInputSlots())) .collect(ImmutableSet.toImmutableSet()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java index 2704d44655..3e71b3b89a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java @@ -39,10 +39,6 @@ public class EqualTo extends EqualPredicate implements PropagateNullable { super(ImmutableList.of(left, right), "=", inferred); } - private EqualTo(List children) { - this(children, false); - } - private EqualTo(List children, boolean inferred) { super(children, "=", inferred); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java index d839a1e906..c86a074dcf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java @@ -48,6 +48,12 @@ public class InPredicate extends Expression { this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null")); } + public InPredicate(Expression compareExpr, List options, boolean inferred) { + super(new Builder().add(compareExpr).addAll(options).build(), inferred); + this.compareExpr = Objects.requireNonNull(compareExpr, "Compare Expr cannot be null"); + this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null")); + } + public R accept(ExpressionVisitor visitor, C context) { return visitor.visitInPredicate(this, context); } @@ -80,6 +86,11 @@ public class InPredicate extends Expression { }); } + @Override + public Expression withInferred(boolean inferred) { + return new InPredicate(children.get(0), ImmutableList.copyOf(children).subList(1, children.size()), true); + } + @Override public String toString() { return compareExpr + " IN " + options.stream() diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java index c910e98fcd..0708ea3f17 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java @@ -25,7 +25,7 @@ import org.apache.doris.utframe.TestWithFeService; import org.junit.jupiter.api.Test; -public class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported { +class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported { @Override protected void runBeforeAll() throws Exception { @@ -77,7 +77,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest01() { + void inferPredicatesTest01() { String sql = "select * from student join score on student.id = score.sid where student.id > 1"; PlanChecker.from(connectContext) @@ -100,7 +100,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest02() { + void inferPredicatesTest02() { String sql = "select * from student join score on student.id = score.sid"; PlanChecker.from(connectContext) @@ -117,7 +117,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest03() { + void inferPredicatesTest03() { String sql = "select * from student join score on student.id = score.sid where student.id in (1,2,3)"; PlanChecker.from(connectContext) @@ -126,18 +126,17 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter .matches( logicalProject( logicalJoin( - logicalFilter( - logicalOlapScan() - ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + logicalFilter(logicalOlapScan()).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) & filter.getPredicate().toSql().contains("id IN (1, 2, 3)")), - logicalOlapScan() + logicalFilter(logicalOlapScan()).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid IN (1, 2, 3)")) ) ) ); } @Test - public void inferPredicatesTest04() { + void inferPredicatesTest04() { String sql = "select * from student join score on student.id = score.sid and student.id in (1,2,3)"; PlanChecker.from(connectContext) @@ -146,18 +145,17 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter .matches( logicalProject( logicalJoin( - logicalFilter( - logicalOlapScan() - ).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) + logicalFilter(logicalOlapScan()).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate()) & filter.getPredicate().toSql().contains("id IN (1, 2, 3)")), - logicalOlapScan() + logicalFilter(logicalOlapScan()).when(filter -> ExpressionUtils.isInferred(filter.getPredicate()) + & filter.getPredicate().toSql().contains("sid IN (1, 2, 3)")) ) ) ); } @Test - public void inferPredicatesTest05() { + void inferPredicatesTest05() { String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id where student.id > 1"; PlanChecker.from(connectContext) @@ -185,7 +183,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest06() { + void inferPredicatesTest06() { String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id and score.sid > 1"; PlanChecker.from(connectContext) @@ -213,7 +211,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest07() { + void inferPredicatesTest07() { String sql = "select * from student left join score on student.id = score.sid where student.id > 1"; PlanChecker.from(connectContext) @@ -236,7 +234,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest08() { + void inferPredicatesTest08() { String sql = "select * from student left join score on student.id = score.sid and student.id > 1"; PlanChecker.from(connectContext) @@ -256,7 +254,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest09() { + void inferPredicatesTest09() { // convert left join to inner join String sql = "select * from student left join score on student.id = score.sid where score.sid > 1"; @@ -280,7 +278,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest10() { + void inferPredicatesTest10() { String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid where t.nid > 1"; PlanChecker.from(connectContext) @@ -305,7 +303,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest11() { + void inferPredicatesTest11() { String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid and t.nid > 1"; PlanChecker.from(connectContext) @@ -327,7 +325,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest12() { + void inferPredicatesTest12() { String sql = "select * from student left join (select sid as nid, sum(grade) from score group by sid) s on s.nid = student.id where student.id > 1"; PlanChecker.from(connectContext) @@ -356,7 +354,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest13() { + void inferPredicatesTest13() { String sql = "select * from (select id, name from student where id = 1) t left join score on t.id = score.sid"; PlanChecker.from(connectContext) @@ -381,7 +379,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest14() { + void inferPredicatesTest14() { String sql = "select * from student left semi join score on student.id = score.sid where student.id > 1"; PlanChecker.from(connectContext) @@ -406,7 +404,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest15() { + void inferPredicatesTest15() { String sql = "select * from student left semi join score on student.id = score.sid and student.id > 1"; PlanChecker.from(connectContext) @@ -431,7 +429,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest16() { + void inferPredicatesTest16() { String sql = "select * from student left anti join score on student.id = score.sid and student.id > 1"; PlanChecker.from(connectContext) @@ -453,7 +451,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest17() { + void inferPredicatesTest17() { String sql = "select * from student left anti join score on student.id = score.sid and score.sid > 1"; PlanChecker.from(connectContext) @@ -475,7 +473,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest18() { + void inferPredicatesTest18() { String sql = "select * from student left anti join score on student.id = score.sid where student.id > 1"; PlanChecker.from(connectContext) @@ -500,7 +498,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest19() { + void inferPredicatesTest19() { String sql = "select * from subquery1\n" + "left semi join (\n" + " select t1.k3\n" @@ -564,7 +562,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest20() { + void inferPredicatesTest20() { String sql = "select * from student left join score on student.id = score.sid and score.sid > 1 inner join course on course.id = score.sid"; PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree(); PlanChecker.from(connectContext) @@ -592,7 +590,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter } @Test - public void inferPredicatesTest21() { + void inferPredicatesTest21() { String sql = "select * from student,score,course where student.id = score.sid and score.sid = course.id and score.sid > 1"; PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree(); PlanChecker.from(connectContext) @@ -623,7 +621,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter * test for #15310 */ @Test - public void inferPredicatesTest22() { + void inferPredicatesTest22() { String sql = "select * from student join (select sid as id1, sid as id2, grade from score) s on student.id = s.id1 where s.id1 > 1"; PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree(); PlanChecker.from(connectContext) @@ -651,7 +649,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter * in this case, filter on relation s1 should not contain s1.id = 1. */ @Test - public void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() { + void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() { String sql = "select * from student s1" + " left join (select sid as id1, sid as id2, grade from score) s2 on s1.id = s2.id1 and s1.id = 1" + " join (select sid as id1, sid as id2, grade from score) s3 on s1.id = s3.id1 where s1.id = 2"; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java new file mode 100644 index 0000000000..b1aa25df1b --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.SmallIntType; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.junit.jupiter.api.Test; + +import java.util.Set; + +class PredicatePropagationTest { + private final SlotReference a = new SlotReference("a", SmallIntType.INSTANCE); + private final SlotReference b = new SlotReference("b", BigIntType.INSTANCE); + + @Test + void equal() { + Set exprs = ImmutableSet.of(new EqualTo(a, b), new EqualTo(a, Literal.of(1))); + Set inferExprs = PredicatePropagation.infer(exprs); + System.out.println(inferExprs); + } + + @Test + void in() { + Set exprs = ImmutableSet.of(new EqualTo(a, b), new InPredicate(a, ImmutableList.of(Literal.of(1)))); + Set inferExprs = PredicatePropagation.infer(exprs); + System.out.println(inferExprs); + } +} diff --git a/regression-test/data/nereids_p0/hint/fix_leading.out b/regression-test/data/nereids_p0/hint/fix_leading.out index a3ca4f5411..58122945bb 100644 --- a/regression-test/data/nereids_p0/hint/fix_leading.out +++ b/regression-test/data/nereids_p0/hint/fix_leading.out @@ -9,7 +9,7 @@ PhysicalResultSink ----------PhysicalDistribute[DistributionSpecHash] ------------PhysicalOlapScan[t2] --------PhysicalDistribute[DistributionSpecHash] -----------NestedLoopJoin[CROSS_JOIN] +----------NestedLoopJoin[CROSS_JOIN](t4.c4 = t3.c3)(t3.c3 = t4.c4) ------------PhysicalOlapScan[t3] ------------PhysicalDistribute[DistributionSpecReplicated] --------------PhysicalOlapScan[t4] diff --git a/regression-test/data/nereids_p0/hint/test_leading.out b/regression-test/data/nereids_p0/hint/test_leading.out index d1bd8f8bd2..fe3831a9fc 100644 --- a/regression-test/data/nereids_p0/hint/test_leading.out +++ b/regression-test/data/nereids_p0/hint/test_leading.out @@ -2609,7 +2609,7 @@ PhysicalResultSink ------------PhysicalProject --------------PhysicalOlapScan[t2] ------------PhysicalDistribute[DistributionSpecReplicated] ---------------NestedLoopJoin[CROSS_JOIN] +--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3) ----------------PhysicalProject ------------------PhysicalOlapScan[t1] ----------------PhysicalDistribute[DistributionSpecReplicated] @@ -2631,7 +2631,7 @@ PhysicalResultSink ------------PhysicalProject --------------PhysicalOlapScan[t2] ------------PhysicalDistribute[DistributionSpecReplicated] ---------------NestedLoopJoin[CROSS_JOIN] +--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3) ----------------PhysicalProject ------------------PhysicalOlapScan[t3] ----------------PhysicalDistribute[DistributionSpecReplicated] @@ -2745,7 +2745,7 @@ PhysicalResultSink ------------PhysicalProject --------------PhysicalOlapScan[t2] ------------PhysicalDistribute[DistributionSpecReplicated] ---------------NestedLoopJoin[CROSS_JOIN] +--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3) ----------------PhysicalProject ------------------PhysicalOlapScan[t1] ----------------PhysicalDistribute[DistributionSpecReplicated] @@ -2767,7 +2767,7 @@ PhysicalResultSink ------------PhysicalProject --------------PhysicalOlapScan[t2] ------------PhysicalDistribute[DistributionSpecReplicated] ---------------NestedLoopJoin[CROSS_JOIN] +--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3) ----------------PhysicalProject ------------------PhysicalOlapScan[t3] ----------------PhysicalDistribute[DistributionSpecReplicated] @@ -2881,7 +2881,7 @@ PhysicalResultSink ------------PhysicalProject --------------PhysicalOlapScan[t2] ------------PhysicalDistribute[DistributionSpecHash] ---------------NestedLoopJoin[CROSS_JOIN] +--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3) ----------------PhysicalProject ------------------PhysicalOlapScan[t1] ----------------PhysicalDistribute[DistributionSpecReplicated] @@ -2903,7 +2903,7 @@ PhysicalResultSink ------------PhysicalProject --------------PhysicalOlapScan[t2] ------------PhysicalDistribute[DistributionSpecHash] ---------------NestedLoopJoin[CROSS_JOIN] +--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3) ----------------PhysicalProject ------------------PhysicalOlapScan[t3] ----------------PhysicalDistribute[DistributionSpecReplicated] diff --git a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy index c5942680ea..55645ed8ea 100644 --- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy +++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy @@ -41,7 +41,7 @@ suite("test_infer_predicate") { explain { sql "select * from infer_tb1 inner join infer_tb2 where cast(infer_tb2.k4 as int) = infer_tb1.k2 and infer_tb2.k4 = 1;" - contains "PREDICATES: k2" + contains "PREDICATES: CAST(k2" } explain {