[feature](Nereids): InferPredicates support In (#29458)

This commit is contained in:
jakevin
2024-01-05 21:25:30 +08:00
committed by GitHub
parent f40cce1406
commit 7a0734dbd6
10 changed files with 220 additions and 116 deletions

View File

@ -27,7 +27,6 @@ import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
@ -37,6 +36,7 @@ import java.util.stream.Collectors;
/**
* infer additional predicates for `LogicalFilter` and `LogicalJoin`.
* <pre>
* The logic is as follows:
* 1. poll up bottom predicate then infer additional predicates
* for example:
@ -49,9 +49,9 @@ import java.util.stream.Collectors;
* select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t2.id = 1
* 2. put these predicates into `otherJoinConjuncts` , these predicates are processed in the next
* round of predicate push-down
* </pre>
*/
public class InferPredicates extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
private final PredicatePropagation propagation = new PredicatePropagation();
private final PullUpPredicates pollUpPredicates = new PullUpPredicates();
@Override
@ -62,6 +62,9 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, JobContext context) {
join = visitChildren(this, join, context);
if (join.isMarkJoin()) {
return join;
}
Plan left = join.left();
Plan right = join.right();
Set<Expression> expressions = getAllExpressions(left, right, join.getOnClauseCondition());
@ -86,7 +89,7 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements
break;
}
if (left != join.left() || right != join.right()) {
return join.withChildren(ImmutableList.of(left, right));
return join.withChildren(left, right);
} else {
return join;
}
@ -109,7 +112,7 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements
Set<Expression> baseExpressions = pullUpPredicates(left);
baseExpressions.addAll(pullUpPredicates(right));
condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on)));
baseExpressions.addAll(propagation.infer(baseExpressions));
baseExpressions.addAll(PredicatePropagation.infer(baseExpressions));
return baseExpressions;
}

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.types.DataType;
@ -55,8 +56,7 @@ public class PredicatePropagation {
INTEGRAL(IntegralType.class),
STRING(CharacterType.class),
DATE(DateLikeType.class),
OTHER(DataType.class)
;
OTHER(DataType.class);
private final Class<? extends DataType> superClazz;
@ -65,15 +65,15 @@ public class PredicatePropagation {
}
}
private class ComparisonInferInfo {
private static class EqualInferInfo {
public final InferType inferType;
public final Optional<Expression> left;
public final Optional<Expression> right;
public final Expression left;
public final Expression right;
public final ComparisonPredicate comparisonPredicate;
public ComparisonInferInfo(InferType inferType,
Optional<Expression> left, Optional<Expression> right,
public EqualInferInfo(InferType inferType,
Expression left, Expression right,
ComparisonPredicate comparisonPredicate) {
this.inferType = inferType;
this.left = left;
@ -85,26 +85,27 @@ public class PredicatePropagation {
/**
* infer additional predicates.
*/
public Set<Expression> infer(Set<Expression> predicates) {
public static Set<Expression> infer(Set<Expression> predicates) {
Set<Expression> inferred = Sets.newHashSet();
for (Expression predicate : predicates) {
// if we support more infer predicate expression type, we should impl withInferred() method.
// And should add inferred props in withChildren() method just like ComparisonPredicate,
// and it's subclass, to mark the predicate is from infer.
if (!(predicate instanceof ComparisonPredicate)) {
if (!(predicate instanceof ComparisonPredicate
|| (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren()))) {
continue;
}
ComparisonInferInfo equalInfo = getEquivalentInferInfo((ComparisonPredicate) predicate);
if (predicate instanceof InPredicate) {
continue;
}
EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate) predicate);
if (equalInfo.inferType == InferType.NONE) {
continue;
}
Set<Expression> newInferred = predicates.stream()
.filter(ComparisonPredicate.class::isInstance)
.filter(p -> !p.equals(predicate))
.map(ComparisonPredicate.class::cast)
.map(this::inferInferInfo)
.filter(predicateInfo -> predicateInfo.inferType != InferType.NONE)
.map(predicateInfo -> doInfer(equalInfo, predicateInfo))
.filter(p -> p instanceof ComparisonPredicate || p instanceof InPredicate)
.map(predicateInfo -> doInferPredicate(equalInfo, predicateInfo))
.filter(Objects::nonNull)
.collect(Collectors.toSet());
inferred.addAll(newInferred);
@ -113,17 +114,64 @@ public class PredicatePropagation {
return inferred;
}
private static Expression doInferPredicate(EqualInferInfo equalInfo, Expression predicate) {
Expression equalLeft = equalInfo.left;
Expression equalRight = equalInfo.right;
DataType leftType = predicate.child(0).getDataType();
InferType inferType;
if (leftType instanceof CharacterType) {
inferType = InferType.STRING;
} else if (leftType instanceof IntegralType) {
inferType = InferType.INTEGRAL;
} else if (leftType instanceof DateLikeType) {
inferType = InferType.DATE;
} else {
inferType = InferType.OTHER;
}
if (predicate instanceof ComparisonPredicate) {
ComparisonPredicate comparisonPredicate = (ComparisonPredicate) predicate;
Optional<Expression> left = validForInfer(comparisonPredicate.left(), inferType);
Optional<Expression> right = validForInfer(comparisonPredicate.right(), inferType);
if (!left.isPresent() || !right.isPresent()) {
return null;
}
} else if (predicate instanceof InPredicate) {
InPredicate inPredicate = (InPredicate) predicate;
Optional<Expression> left = validForInfer(inPredicate.getCompareExpr(), inferType);
if (!left.isPresent()) {
return null;
}
}
Expression newPredicate = predicate.rewriteUp(e -> {
if (e.equals(equalLeft)) {
return equalRight;
} else if (e.equals(equalRight)) {
return equalLeft;
} else {
return e;
}
});
if (predicate instanceof ComparisonPredicate) {
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) newPredicate).withInferred(true);
} else {
return TypeCoercionUtils.processInPredicate((InPredicate) newPredicate).withInferred(true);
}
}
/**
* Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression`
* Now only support infer `ComparisonPredicate`.
* TODO: We should determine whether `expression` satisfies the condition for replacement
* eg: Satisfy `expression` is non-deterministic
*/
private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo predicateInfo) {
Expression predicateLeft = predicateInfo.left.get();
Expression predicateRight = predicateInfo.right.get();
Expression equalLeft = equalInfo.left.get();
Expression equalRight = equalInfo.right.get();
private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo predicateInfo) {
Expression equalLeft = equalInfo.left;
Expression equalRight = equalInfo.right;
Expression predicateLeft = predicateInfo.left;
Expression predicateRight = predicateInfo.right;
Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight);
Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight);
if (newLeft == null || newRight == null) {
@ -136,7 +184,7 @@ public class PredicatePropagation {
return DateFunctionRewrite.INSTANCE.rewrite(expr, null).withInferred(true);
}
private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
private static Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
if (predicateOneSide instanceof SlotReference) {
if (predicateOneSide.equals(equalLeft)) {
return equalRight;
@ -153,60 +201,55 @@ public class PredicatePropagation {
return null;
}
private Optional<Expression> validForInfer(Expression expression, InferType inferType) {
private static Optional<Expression> validForInfer(Expression expression, InferType inferType) {
if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
return Optional.empty();
}
if (expression instanceof SlotReference || expression.isConstant()) {
return Optional.of(expression);
}
if (!(expression instanceof Cast)) {
return Optional.empty();
}
Cast cast = (Cast) expression;
Expression child = cast.child();
DataType dataType = cast.getDataType();
DataType childType = child.getDataType();
if (inferType == InferType.INTEGRAL) {
if (expression instanceof Cast) {
// avoid cast from wider type to narrower type, such as cast(int as smallint)
// IntegralType dataType = (IntegralType) expression.getDataType();
// DataType childType = ((Cast) expression).child().getDataType();
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
// return validForInfer(((Cast) expression).child(), inferType);
// }
return validForInfer(((Cast) expression).child(), inferType);
}
// avoid cast from wider type to narrower type, such as cast(int as smallint)
// IntegralType dataType = (IntegralType) expression.getDataType();
// DataType childType = ((Cast) expression).child().getDataType();
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
// return validForInfer(((Cast) expression).child(), inferType);
// }
return validForInfer(child, inferType);
} else if (inferType == InferType.DATE) {
if (expression instanceof Cast) {
DataType dataType = expression.getDataType();
DataType childType = ((Cast) expression).child().getDataType();
// avoid lost precision
if (dataType instanceof DateType) {
if (childType instanceof DateV2Type || childType instanceof DateType) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateV2Type) {
if (childType instanceof DateType || childType instanceof DateV2Type) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateTimeType) {
if (!(childType instanceof DateTimeV2Type)) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateTimeV2Type) {
return validForInfer(((Cast) expression).child(), inferType);
// avoid lost precision
if (dataType instanceof DateType) {
if (childType instanceof DateV2Type || childType instanceof DateType) {
return validForInfer(child, inferType);
}
} else if (dataType instanceof DateV2Type) {
if (childType instanceof DateType || childType instanceof DateV2Type) {
return validForInfer(child, inferType);
}
} else if (dataType instanceof DateTimeType) {
if (!(childType instanceof DateTimeV2Type)) {
return validForInfer(child, inferType);
}
} else if (dataType instanceof DateTimeV2Type) {
return validForInfer(child, inferType);
}
} else if (inferType == InferType.STRING) {
if (expression instanceof Cast) {
DataType dataType = expression.getDataType();
DataType childType = ((Cast) expression).child().getDataType();
// avoid substring cast such as cast(char(3) as char(2))
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
return validForInfer(((Cast) expression).child(), inferType);
}
// avoid substring cast such as cast(char(3) as char(2))
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
return validForInfer(child, inferType);
}
} else {
return Optional.empty();
}
return Optional.empty();
}
private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
private static EqualInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
DataType leftType = comparisonPredicate.left().getDataType();
InferType inferType;
if (leftType instanceof CharacterType) {
@ -223,25 +266,27 @@ public class PredicatePropagation {
if (!left.isPresent() || !right.isPresent()) {
inferType = InferType.NONE;
}
return new ComparisonInferInfo(inferType, left, right, comparisonPredicate);
return new EqualInferInfo(inferType, left.orElse(comparisonPredicate.left()),
right.orElse(comparisonPredicate.right()), comparisonPredicate);
}
/**
* Currently only equivalence derivation is supported
* and requires that the left and right sides of an expression must be slot
* <p>
* TODO: NullSafeEqual
*/
private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate predicate) {
private static EqualInferInfo getEqualInferInfo(ComparisonPredicate predicate) {
if (!(predicate instanceof EqualTo)) {
return new ComparisonInferInfo(InferType.NONE,
Optional.of(predicate.left()), Optional.of(predicate.right()), predicate);
return new EqualInferInfo(InferType.NONE, predicate.left(), predicate.right(), predicate);
}
ComparisonInferInfo info = inferInferInfo(predicate);
EqualInferInfo info = inferInferInfo(predicate);
if (info.inferType == InferType.NONE) {
return info;
}
if (info.left.get() instanceof SlotReference && info.right.get() instanceof SlotReference) {
if (info.left instanceof SlotReference && info.right instanceof SlotReference) {
return info;
}
return new ComparisonInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
return new EqualInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
}
}

View File

@ -47,7 +47,6 @@ import java.util.stream.Collectors;
*/
public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void> {
PredicatePropagation propagation = new PredicatePropagation();
Map<Plan, ImmutableSet<Expression>> cache = new IdentityHashMap<>();
@Override
@ -99,6 +98,7 @@ public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void
public ImmutableSet<Expression> visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) {
return cacheOrElse(aggregate, () -> {
ImmutableSet<Expression> childPredicates = aggregate.child().accept(this, context);
// TODO
Map<Expression, Slot> expressionSlotMap = aggregate.getOutputExpressions()
.stream()
.filter(this::hasAgg)
@ -130,7 +130,7 @@ public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void
private ImmutableSet<Expression> getAvailableExpressions(Collection<Expression> predicates, Plan plan) {
Set<Expression> expressions = Sets.newHashSet(predicates);
expressions.addAll(propagation.infer(expressions));
expressions.addAll(PredicatePropagation.infer(expressions));
return expressions.stream()
.filter(p -> plan.getOutputSet().containsAll(p.getInputSlots()))
.collect(ImmutableSet.toImmutableSet());

View File

@ -39,10 +39,6 @@ public class EqualTo extends EqualPredicate implements PropagateNullable {
super(ImmutableList.of(left, right), "=", inferred);
}
private EqualTo(List<Expression> children) {
this(children, false);
}
private EqualTo(List<Expression> children, boolean inferred) {
super(children, "=", inferred);
}

View File

@ -48,6 +48,12 @@ public class InPredicate extends Expression {
this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null"));
}
public InPredicate(Expression compareExpr, List<Expression> options, boolean inferred) {
super(new Builder<Expression>().add(compareExpr).addAll(options).build(), inferred);
this.compareExpr = Objects.requireNonNull(compareExpr, "Compare Expr cannot be null");
this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null"));
}
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitInPredicate(this, context);
}
@ -80,6 +86,11 @@ public class InPredicate extends Expression {
});
}
@Override
public Expression withInferred(boolean inferred) {
return new InPredicate(children.get(0), ImmutableList.copyOf(children).subList(1, children.size()), true);
}
@Override
public String toString() {
return compareExpr + " IN " + options.stream()

View File

@ -25,7 +25,7 @@ import org.apache.doris.utframe.TestWithFeService;
import org.junit.jupiter.api.Test;
public class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported {
class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported {
@Override
protected void runBeforeAll() throws Exception {
@ -77,7 +77,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest01() {
void inferPredicatesTest01() {
String sql = "select * from student join score on student.id = score.sid where student.id > 1";
PlanChecker.from(connectContext)
@ -100,7 +100,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest02() {
void inferPredicatesTest02() {
String sql = "select * from student join score on student.id = score.sid";
PlanChecker.from(connectContext)
@ -117,7 +117,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest03() {
void inferPredicatesTest03() {
String sql = "select * from student join score on student.id = score.sid where student.id in (1,2,3)";
PlanChecker.from(connectContext)
@ -126,18 +126,17 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.matches(
logicalProject(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
logicalFilter(logicalOlapScan()).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
logicalOlapScan()
logicalFilter(logicalOlapScan()).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("sid IN (1, 2, 3)"))
)
)
);
}
@Test
public void inferPredicatesTest04() {
void inferPredicatesTest04() {
String sql = "select * from student join score on student.id = score.sid and student.id in (1,2,3)";
PlanChecker.from(connectContext)
@ -146,18 +145,17 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.matches(
logicalProject(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
logicalFilter(logicalOlapScan()).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
logicalOlapScan()
logicalFilter(logicalOlapScan()).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
& filter.getPredicate().toSql().contains("sid IN (1, 2, 3)"))
)
)
);
}
@Test
public void inferPredicatesTest05() {
void inferPredicatesTest05() {
String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id where student.id > 1";
PlanChecker.from(connectContext)
@ -185,7 +183,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest06() {
void inferPredicatesTest06() {
String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id and score.sid > 1";
PlanChecker.from(connectContext)
@ -213,7 +211,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest07() {
void inferPredicatesTest07() {
String sql = "select * from student left join score on student.id = score.sid where student.id > 1";
PlanChecker.from(connectContext)
@ -236,7 +234,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest08() {
void inferPredicatesTest08() {
String sql = "select * from student left join score on student.id = score.sid and student.id > 1";
PlanChecker.from(connectContext)
@ -256,7 +254,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest09() {
void inferPredicatesTest09() {
// convert left join to inner join
String sql = "select * from student left join score on student.id = score.sid where score.sid > 1";
@ -280,7 +278,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest10() {
void inferPredicatesTest10() {
String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid where t.nid > 1";
PlanChecker.from(connectContext)
@ -305,7 +303,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest11() {
void inferPredicatesTest11() {
String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid and t.nid > 1";
PlanChecker.from(connectContext)
@ -327,7 +325,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest12() {
void inferPredicatesTest12() {
String sql = "select * from student left join (select sid as nid, sum(grade) from score group by sid) s on s.nid = student.id where student.id > 1";
PlanChecker.from(connectContext)
@ -356,7 +354,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest13() {
void inferPredicatesTest13() {
String sql = "select * from (select id, name from student where id = 1) t left join score on t.id = score.sid";
PlanChecker.from(connectContext)
@ -381,7 +379,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest14() {
void inferPredicatesTest14() {
String sql = "select * from student left semi join score on student.id = score.sid where student.id > 1";
PlanChecker.from(connectContext)
@ -406,7 +404,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest15() {
void inferPredicatesTest15() {
String sql = "select * from student left semi join score on student.id = score.sid and student.id > 1";
PlanChecker.from(connectContext)
@ -431,7 +429,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest16() {
void inferPredicatesTest16() {
String sql = "select * from student left anti join score on student.id = score.sid and student.id > 1";
PlanChecker.from(connectContext)
@ -453,7 +451,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest17() {
void inferPredicatesTest17() {
String sql = "select * from student left anti join score on student.id = score.sid and score.sid > 1";
PlanChecker.from(connectContext)
@ -475,7 +473,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest18() {
void inferPredicatesTest18() {
String sql = "select * from student left anti join score on student.id = score.sid where student.id > 1";
PlanChecker.from(connectContext)
@ -500,7 +498,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest19() {
void inferPredicatesTest19() {
String sql = "select * from subquery1\n"
+ "left semi join (\n"
+ " select t1.k3\n"
@ -564,7 +562,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest20() {
void inferPredicatesTest20() {
String sql = "select * from student left join score on student.id = score.sid and score.sid > 1 inner join course on course.id = score.sid";
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
PlanChecker.from(connectContext)
@ -592,7 +590,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
}
@Test
public void inferPredicatesTest21() {
void inferPredicatesTest21() {
String sql = "select * from student,score,course where student.id = score.sid and score.sid = course.id and score.sid > 1";
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
PlanChecker.from(connectContext)
@ -623,7 +621,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
* test for #15310
*/
@Test
public void inferPredicatesTest22() {
void inferPredicatesTest22() {
String sql = "select * from student join (select sid as id1, sid as id2, grade from score) s on student.id = s.id1 where s.id1 > 1";
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
PlanChecker.from(connectContext)
@ -651,7 +649,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
* in this case, filter on relation s1 should not contain s1.id = 1.
*/
@Test
public void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
String sql = "select * from student s1"
+ " left join (select sid as id1, sid as id2, grade from score) s2 on s1.id = s2.id1 and s1.id = 1"
+ " join (select sid as id1, sid as id2, grade from score) s3 on s1.id = s3.id1 where s1.id = 2";

View File

@ -0,0 +1,51 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.SmallIntType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.junit.jupiter.api.Test;
import java.util.Set;
class PredicatePropagationTest {
private final SlotReference a = new SlotReference("a", SmallIntType.INSTANCE);
private final SlotReference b = new SlotReference("b", BigIntType.INSTANCE);
@Test
void equal() {
Set<Expression> exprs = ImmutableSet.of(new EqualTo(a, b), new EqualTo(a, Literal.of(1)));
Set<Expression> inferExprs = PredicatePropagation.infer(exprs);
System.out.println(inferExprs);
}
@Test
void in() {
Set<Expression> exprs = ImmutableSet.of(new EqualTo(a, b), new InPredicate(a, ImmutableList.of(Literal.of(1))));
Set<Expression> inferExprs = PredicatePropagation.infer(exprs);
System.out.println(inferExprs);
}
}

View File

@ -9,7 +9,7 @@ PhysicalResultSink
----------PhysicalDistribute[DistributionSpecHash]
------------PhysicalOlapScan[t2]
--------PhysicalDistribute[DistributionSpecHash]
----------NestedLoopJoin[CROSS_JOIN]
----------NestedLoopJoin[CROSS_JOIN](t4.c4 = t3.c3)(t3.c3 = t4.c4)
------------PhysicalOlapScan[t3]
------------PhysicalDistribute[DistributionSpecReplicated]
--------------PhysicalOlapScan[t4]

View File

@ -2609,7 +2609,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecReplicated]
--------------NestedLoopJoin[CROSS_JOIN]
--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t1]
----------------PhysicalDistribute[DistributionSpecReplicated]
@ -2631,7 +2631,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecReplicated]
--------------NestedLoopJoin[CROSS_JOIN]
--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t3]
----------------PhysicalDistribute[DistributionSpecReplicated]
@ -2745,7 +2745,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecReplicated]
--------------NestedLoopJoin[CROSS_JOIN]
--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t1]
----------------PhysicalDistribute[DistributionSpecReplicated]
@ -2767,7 +2767,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecReplicated]
--------------NestedLoopJoin[CROSS_JOIN]
--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t3]
----------------PhysicalDistribute[DistributionSpecReplicated]
@ -2881,7 +2881,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecHash]
--------------NestedLoopJoin[CROSS_JOIN]
--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t1]
----------------PhysicalDistribute[DistributionSpecReplicated]
@ -2903,7 +2903,7 @@ PhysicalResultSink
------------PhysicalProject
--------------PhysicalOlapScan[t2]
------------PhysicalDistribute[DistributionSpecHash]
--------------NestedLoopJoin[CROSS_JOIN]
--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
----------------PhysicalProject
------------------PhysicalOlapScan[t3]
----------------PhysicalDistribute[DistributionSpecReplicated]

View File

@ -41,7 +41,7 @@ suite("test_infer_predicate") {
explain {
sql "select * from infer_tb1 inner join infer_tb2 where cast(infer_tb2.k4 as int) = infer_tb1.k2 and infer_tb2.k4 = 1;"
contains "PREDICATES: k2"
contains "PREDICATES: CAST(k2"
}
explain {