[feature](Nereids): InferPredicates support In (#29458)
This commit is contained in:
@ -27,7 +27,6 @@ import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.PlanUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
@ -37,6 +36,7 @@ import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* infer additional predicates for `LogicalFilter` and `LogicalJoin`.
|
||||
* <pre>
|
||||
* The logic is as follows:
|
||||
* 1. poll up bottom predicate then infer additional predicates
|
||||
* for example:
|
||||
@ -49,9 +49,9 @@ import java.util.stream.Collectors;
|
||||
* select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t2.id = 1
|
||||
* 2. put these predicates into `otherJoinConjuncts` , these predicates are processed in the next
|
||||
* round of predicate push-down
|
||||
* </pre>
|
||||
*/
|
||||
public class InferPredicates extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
|
||||
private final PredicatePropagation propagation = new PredicatePropagation();
|
||||
private final PullUpPredicates pollUpPredicates = new PullUpPredicates();
|
||||
|
||||
@Override
|
||||
@ -62,6 +62,9 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements
|
||||
@Override
|
||||
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, JobContext context) {
|
||||
join = visitChildren(this, join, context);
|
||||
if (join.isMarkJoin()) {
|
||||
return join;
|
||||
}
|
||||
Plan left = join.left();
|
||||
Plan right = join.right();
|
||||
Set<Expression> expressions = getAllExpressions(left, right, join.getOnClauseCondition());
|
||||
@ -86,7 +89,7 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements
|
||||
break;
|
||||
}
|
||||
if (left != join.left() || right != join.right()) {
|
||||
return join.withChildren(ImmutableList.of(left, right));
|
||||
return join.withChildren(left, right);
|
||||
} else {
|
||||
return join;
|
||||
}
|
||||
@ -109,7 +112,7 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements
|
||||
Set<Expression> baseExpressions = pullUpPredicates(left);
|
||||
baseExpressions.addAll(pullUpPredicates(right));
|
||||
condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on)));
|
||||
baseExpressions.addAll(propagation.infer(baseExpressions));
|
||||
baseExpressions.addAll(PredicatePropagation.infer(baseExpressions));
|
||||
return baseExpressions;
|
||||
}
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.InPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
@ -55,8 +56,7 @@ public class PredicatePropagation {
|
||||
INTEGRAL(IntegralType.class),
|
||||
STRING(CharacterType.class),
|
||||
DATE(DateLikeType.class),
|
||||
OTHER(DataType.class)
|
||||
;
|
||||
OTHER(DataType.class);
|
||||
|
||||
private final Class<? extends DataType> superClazz;
|
||||
|
||||
@ -65,15 +65,15 @@ public class PredicatePropagation {
|
||||
}
|
||||
}
|
||||
|
||||
private class ComparisonInferInfo {
|
||||
private static class EqualInferInfo {
|
||||
|
||||
public final InferType inferType;
|
||||
public final Optional<Expression> left;
|
||||
public final Optional<Expression> right;
|
||||
public final Expression left;
|
||||
public final Expression right;
|
||||
public final ComparisonPredicate comparisonPredicate;
|
||||
|
||||
public ComparisonInferInfo(InferType inferType,
|
||||
Optional<Expression> left, Optional<Expression> right,
|
||||
public EqualInferInfo(InferType inferType,
|
||||
Expression left, Expression right,
|
||||
ComparisonPredicate comparisonPredicate) {
|
||||
this.inferType = inferType;
|
||||
this.left = left;
|
||||
@ -85,26 +85,27 @@ public class PredicatePropagation {
|
||||
/**
|
||||
* infer additional predicates.
|
||||
*/
|
||||
public Set<Expression> infer(Set<Expression> predicates) {
|
||||
public static Set<Expression> infer(Set<Expression> predicates) {
|
||||
Set<Expression> inferred = Sets.newHashSet();
|
||||
for (Expression predicate : predicates) {
|
||||
// if we support more infer predicate expression type, we should impl withInferred() method.
|
||||
// And should add inferred props in withChildren() method just like ComparisonPredicate,
|
||||
// and it's subclass, to mark the predicate is from infer.
|
||||
if (!(predicate instanceof ComparisonPredicate)) {
|
||||
if (!(predicate instanceof ComparisonPredicate
|
||||
|| (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren()))) {
|
||||
continue;
|
||||
}
|
||||
ComparisonInferInfo equalInfo = getEquivalentInferInfo((ComparisonPredicate) predicate);
|
||||
if (predicate instanceof InPredicate) {
|
||||
continue;
|
||||
}
|
||||
EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate) predicate);
|
||||
if (equalInfo.inferType == InferType.NONE) {
|
||||
continue;
|
||||
}
|
||||
Set<Expression> newInferred = predicates.stream()
|
||||
.filter(ComparisonPredicate.class::isInstance)
|
||||
.filter(p -> !p.equals(predicate))
|
||||
.map(ComparisonPredicate.class::cast)
|
||||
.map(this::inferInferInfo)
|
||||
.filter(predicateInfo -> predicateInfo.inferType != InferType.NONE)
|
||||
.map(predicateInfo -> doInfer(equalInfo, predicateInfo))
|
||||
.filter(p -> p instanceof ComparisonPredicate || p instanceof InPredicate)
|
||||
.map(predicateInfo -> doInferPredicate(equalInfo, predicateInfo))
|
||||
.filter(Objects::nonNull)
|
||||
.collect(Collectors.toSet());
|
||||
inferred.addAll(newInferred);
|
||||
@ -113,17 +114,64 @@ public class PredicatePropagation {
|
||||
return inferred;
|
||||
}
|
||||
|
||||
private static Expression doInferPredicate(EqualInferInfo equalInfo, Expression predicate) {
|
||||
Expression equalLeft = equalInfo.left;
|
||||
Expression equalRight = equalInfo.right;
|
||||
|
||||
DataType leftType = predicate.child(0).getDataType();
|
||||
InferType inferType;
|
||||
if (leftType instanceof CharacterType) {
|
||||
inferType = InferType.STRING;
|
||||
} else if (leftType instanceof IntegralType) {
|
||||
inferType = InferType.INTEGRAL;
|
||||
} else if (leftType instanceof DateLikeType) {
|
||||
inferType = InferType.DATE;
|
||||
} else {
|
||||
inferType = InferType.OTHER;
|
||||
}
|
||||
if (predicate instanceof ComparisonPredicate) {
|
||||
ComparisonPredicate comparisonPredicate = (ComparisonPredicate) predicate;
|
||||
Optional<Expression> left = validForInfer(comparisonPredicate.left(), inferType);
|
||||
Optional<Expression> right = validForInfer(comparisonPredicate.right(), inferType);
|
||||
if (!left.isPresent() || !right.isPresent()) {
|
||||
return null;
|
||||
}
|
||||
} else if (predicate instanceof InPredicate) {
|
||||
InPredicate inPredicate = (InPredicate) predicate;
|
||||
Optional<Expression> left = validForInfer(inPredicate.getCompareExpr(), inferType);
|
||||
if (!left.isPresent()) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
Expression newPredicate = predicate.rewriteUp(e -> {
|
||||
if (e.equals(equalLeft)) {
|
||||
return equalRight;
|
||||
} else if (e.equals(equalRight)) {
|
||||
return equalLeft;
|
||||
} else {
|
||||
return e;
|
||||
}
|
||||
});
|
||||
if (predicate instanceof ComparisonPredicate) {
|
||||
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) newPredicate).withInferred(true);
|
||||
} else {
|
||||
return TypeCoercionUtils.processInPredicate((InPredicate) newPredicate).withInferred(true);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression`
|
||||
* Now only support infer `ComparisonPredicate`.
|
||||
* TODO: We should determine whether `expression` satisfies the condition for replacement
|
||||
* eg: Satisfy `expression` is non-deterministic
|
||||
*/
|
||||
private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo predicateInfo) {
|
||||
Expression predicateLeft = predicateInfo.left.get();
|
||||
Expression predicateRight = predicateInfo.right.get();
|
||||
Expression equalLeft = equalInfo.left.get();
|
||||
Expression equalRight = equalInfo.right.get();
|
||||
private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo predicateInfo) {
|
||||
Expression equalLeft = equalInfo.left;
|
||||
Expression equalRight = equalInfo.right;
|
||||
|
||||
Expression predicateLeft = predicateInfo.left;
|
||||
Expression predicateRight = predicateInfo.right;
|
||||
Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight);
|
||||
Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight);
|
||||
if (newLeft == null || newRight == null) {
|
||||
@ -136,7 +184,7 @@ public class PredicatePropagation {
|
||||
return DateFunctionRewrite.INSTANCE.rewrite(expr, null).withInferred(true);
|
||||
}
|
||||
|
||||
private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
|
||||
private static Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
|
||||
if (predicateOneSide instanceof SlotReference) {
|
||||
if (predicateOneSide.equals(equalLeft)) {
|
||||
return equalRight;
|
||||
@ -153,60 +201,55 @@ public class PredicatePropagation {
|
||||
return null;
|
||||
}
|
||||
|
||||
private Optional<Expression> validForInfer(Expression expression, InferType inferType) {
|
||||
private static Optional<Expression> validForInfer(Expression expression, InferType inferType) {
|
||||
if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
|
||||
return Optional.empty();
|
||||
}
|
||||
if (expression instanceof SlotReference || expression.isConstant()) {
|
||||
return Optional.of(expression);
|
||||
}
|
||||
if (!(expression instanceof Cast)) {
|
||||
return Optional.empty();
|
||||
}
|
||||
Cast cast = (Cast) expression;
|
||||
Expression child = cast.child();
|
||||
DataType dataType = cast.getDataType();
|
||||
DataType childType = child.getDataType();
|
||||
if (inferType == InferType.INTEGRAL) {
|
||||
if (expression instanceof Cast) {
|
||||
// avoid cast from wider type to narrower type, such as cast(int as smallint)
|
||||
// IntegralType dataType = (IntegralType) expression.getDataType();
|
||||
// DataType childType = ((Cast) expression).child().getDataType();
|
||||
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
|
||||
// return validForInfer(((Cast) expression).child(), inferType);
|
||||
// }
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
}
|
||||
// avoid cast from wider type to narrower type, such as cast(int as smallint)
|
||||
// IntegralType dataType = (IntegralType) expression.getDataType();
|
||||
// DataType childType = ((Cast) expression).child().getDataType();
|
||||
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
|
||||
// return validForInfer(((Cast) expression).child(), inferType);
|
||||
// }
|
||||
return validForInfer(child, inferType);
|
||||
} else if (inferType == InferType.DATE) {
|
||||
if (expression instanceof Cast) {
|
||||
DataType dataType = expression.getDataType();
|
||||
DataType childType = ((Cast) expression).child().getDataType();
|
||||
// avoid lost precision
|
||||
if (dataType instanceof DateType) {
|
||||
if (childType instanceof DateV2Type || childType instanceof DateType) {
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
}
|
||||
} else if (dataType instanceof DateV2Type) {
|
||||
if (childType instanceof DateType || childType instanceof DateV2Type) {
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
}
|
||||
} else if (dataType instanceof DateTimeType) {
|
||||
if (!(childType instanceof DateTimeV2Type)) {
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
}
|
||||
} else if (dataType instanceof DateTimeV2Type) {
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
// avoid lost precision
|
||||
if (dataType instanceof DateType) {
|
||||
if (childType instanceof DateV2Type || childType instanceof DateType) {
|
||||
return validForInfer(child, inferType);
|
||||
}
|
||||
} else if (dataType instanceof DateV2Type) {
|
||||
if (childType instanceof DateType || childType instanceof DateV2Type) {
|
||||
return validForInfer(child, inferType);
|
||||
}
|
||||
} else if (dataType instanceof DateTimeType) {
|
||||
if (!(childType instanceof DateTimeV2Type)) {
|
||||
return validForInfer(child, inferType);
|
||||
}
|
||||
} else if (dataType instanceof DateTimeV2Type) {
|
||||
return validForInfer(child, inferType);
|
||||
}
|
||||
} else if (inferType == InferType.STRING) {
|
||||
if (expression instanceof Cast) {
|
||||
DataType dataType = expression.getDataType();
|
||||
DataType childType = ((Cast) expression).child().getDataType();
|
||||
// avoid substring cast such as cast(char(3) as char(2))
|
||||
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
}
|
||||
// avoid substring cast such as cast(char(3) as char(2))
|
||||
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
|
||||
return validForInfer(child, inferType);
|
||||
}
|
||||
} else {
|
||||
return Optional.empty();
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
|
||||
private static EqualInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
|
||||
DataType leftType = comparisonPredicate.left().getDataType();
|
||||
InferType inferType;
|
||||
if (leftType instanceof CharacterType) {
|
||||
@ -223,25 +266,27 @@ public class PredicatePropagation {
|
||||
if (!left.isPresent() || !right.isPresent()) {
|
||||
inferType = InferType.NONE;
|
||||
}
|
||||
return new ComparisonInferInfo(inferType, left, right, comparisonPredicate);
|
||||
return new EqualInferInfo(inferType, left.orElse(comparisonPredicate.left()),
|
||||
right.orElse(comparisonPredicate.right()), comparisonPredicate);
|
||||
}
|
||||
|
||||
/**
|
||||
* Currently only equivalence derivation is supported
|
||||
* and requires that the left and right sides of an expression must be slot
|
||||
* <p>
|
||||
* TODO: NullSafeEqual
|
||||
*/
|
||||
private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate predicate) {
|
||||
private static EqualInferInfo getEqualInferInfo(ComparisonPredicate predicate) {
|
||||
if (!(predicate instanceof EqualTo)) {
|
||||
return new ComparisonInferInfo(InferType.NONE,
|
||||
Optional.of(predicate.left()), Optional.of(predicate.right()), predicate);
|
||||
return new EqualInferInfo(InferType.NONE, predicate.left(), predicate.right(), predicate);
|
||||
}
|
||||
ComparisonInferInfo info = inferInferInfo(predicate);
|
||||
EqualInferInfo info = inferInferInfo(predicate);
|
||||
if (info.inferType == InferType.NONE) {
|
||||
return info;
|
||||
}
|
||||
if (info.left.get() instanceof SlotReference && info.right.get() instanceof SlotReference) {
|
||||
if (info.left instanceof SlotReference && info.right instanceof SlotReference) {
|
||||
return info;
|
||||
}
|
||||
return new ComparisonInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
|
||||
return new EqualInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
|
||||
}
|
||||
}
|
||||
|
||||
@ -47,7 +47,6 @@ import java.util.stream.Collectors;
|
||||
*/
|
||||
public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void> {
|
||||
|
||||
PredicatePropagation propagation = new PredicatePropagation();
|
||||
Map<Plan, ImmutableSet<Expression>> cache = new IdentityHashMap<>();
|
||||
|
||||
@Override
|
||||
@ -99,6 +98,7 @@ public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void
|
||||
public ImmutableSet<Expression> visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) {
|
||||
return cacheOrElse(aggregate, () -> {
|
||||
ImmutableSet<Expression> childPredicates = aggregate.child().accept(this, context);
|
||||
// TODO
|
||||
Map<Expression, Slot> expressionSlotMap = aggregate.getOutputExpressions()
|
||||
.stream()
|
||||
.filter(this::hasAgg)
|
||||
@ -130,7 +130,7 @@ public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void
|
||||
|
||||
private ImmutableSet<Expression> getAvailableExpressions(Collection<Expression> predicates, Plan plan) {
|
||||
Set<Expression> expressions = Sets.newHashSet(predicates);
|
||||
expressions.addAll(propagation.infer(expressions));
|
||||
expressions.addAll(PredicatePropagation.infer(expressions));
|
||||
return expressions.stream()
|
||||
.filter(p -> plan.getOutputSet().containsAll(p.getInputSlots()))
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
|
||||
@ -39,10 +39,6 @@ public class EqualTo extends EqualPredicate implements PropagateNullable {
|
||||
super(ImmutableList.of(left, right), "=", inferred);
|
||||
}
|
||||
|
||||
private EqualTo(List<Expression> children) {
|
||||
this(children, false);
|
||||
}
|
||||
|
||||
private EqualTo(List<Expression> children, boolean inferred) {
|
||||
super(children, "=", inferred);
|
||||
}
|
||||
|
||||
@ -48,6 +48,12 @@ public class InPredicate extends Expression {
|
||||
this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null"));
|
||||
}
|
||||
|
||||
public InPredicate(Expression compareExpr, List<Expression> options, boolean inferred) {
|
||||
super(new Builder<Expression>().add(compareExpr).addAll(options).build(), inferred);
|
||||
this.compareExpr = Objects.requireNonNull(compareExpr, "Compare Expr cannot be null");
|
||||
this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null"));
|
||||
}
|
||||
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitInPredicate(this, context);
|
||||
}
|
||||
@ -80,6 +86,11 @@ public class InPredicate extends Expression {
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression withInferred(boolean inferred) {
|
||||
return new InPredicate(children.get(0), ImmutableList.copyOf(children).subList(1, children.size()), true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return compareExpr + " IN " + options.stream()
|
||||
|
||||
@ -25,7 +25,7 @@ import org.apache.doris.utframe.TestWithFeService;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported {
|
||||
class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported {
|
||||
|
||||
@Override
|
||||
protected void runBeforeAll() throws Exception {
|
||||
@ -77,7 +77,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest01() {
|
||||
void inferPredicatesTest01() {
|
||||
String sql = "select * from student join score on student.id = score.sid where student.id > 1";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -100,7 +100,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest02() {
|
||||
void inferPredicatesTest02() {
|
||||
String sql = "select * from student join score on student.id = score.sid";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -117,7 +117,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest03() {
|
||||
void inferPredicatesTest03() {
|
||||
String sql = "select * from student join score on student.id = score.sid where student.id in (1,2,3)";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -126,18 +126,17 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalJoin(
|
||||
logicalFilter(
|
||||
logicalOlapScan()
|
||||
).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
|
||||
logicalFilter(logicalOlapScan()).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
|
||||
& filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
|
||||
logicalOlapScan()
|
||||
logicalFilter(logicalOlapScan()).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
|
||||
& filter.getPredicate().toSql().contains("sid IN (1, 2, 3)"))
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest04() {
|
||||
void inferPredicatesTest04() {
|
||||
String sql = "select * from student join score on student.id = score.sid and student.id in (1,2,3)";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -146,18 +145,17 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalJoin(
|
||||
logicalFilter(
|
||||
logicalOlapScan()
|
||||
).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
|
||||
logicalFilter(logicalOlapScan()).when(filter -> !ExpressionUtils.isInferred(filter.getPredicate())
|
||||
& filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
|
||||
logicalOlapScan()
|
||||
logicalFilter(logicalOlapScan()).when(filter -> ExpressionUtils.isInferred(filter.getPredicate())
|
||||
& filter.getPredicate().toSql().contains("sid IN (1, 2, 3)"))
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest05() {
|
||||
void inferPredicatesTest05() {
|
||||
String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id where student.id > 1";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -185,7 +183,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest06() {
|
||||
void inferPredicatesTest06() {
|
||||
String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id and score.sid > 1";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -213,7 +211,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest07() {
|
||||
void inferPredicatesTest07() {
|
||||
String sql = "select * from student left join score on student.id = score.sid where student.id > 1";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -236,7 +234,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest08() {
|
||||
void inferPredicatesTest08() {
|
||||
String sql = "select * from student left join score on student.id = score.sid and student.id > 1";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -256,7 +254,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest09() {
|
||||
void inferPredicatesTest09() {
|
||||
// convert left join to inner join
|
||||
String sql = "select * from student left join score on student.id = score.sid where score.sid > 1";
|
||||
|
||||
@ -280,7 +278,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest10() {
|
||||
void inferPredicatesTest10() {
|
||||
String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid where t.nid > 1";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -305,7 +303,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest11() {
|
||||
void inferPredicatesTest11() {
|
||||
String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid and t.nid > 1";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -327,7 +325,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest12() {
|
||||
void inferPredicatesTest12() {
|
||||
String sql = "select * from student left join (select sid as nid, sum(grade) from score group by sid) s on s.nid = student.id where student.id > 1";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -356,7 +354,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest13() {
|
||||
void inferPredicatesTest13() {
|
||||
String sql = "select * from (select id, name from student where id = 1) t left join score on t.id = score.sid";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -381,7 +379,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest14() {
|
||||
void inferPredicatesTest14() {
|
||||
String sql = "select * from student left semi join score on student.id = score.sid where student.id > 1";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -406,7 +404,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest15() {
|
||||
void inferPredicatesTest15() {
|
||||
String sql = "select * from student left semi join score on student.id = score.sid and student.id > 1";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -431,7 +429,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest16() {
|
||||
void inferPredicatesTest16() {
|
||||
String sql = "select * from student left anti join score on student.id = score.sid and student.id > 1";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -453,7 +451,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest17() {
|
||||
void inferPredicatesTest17() {
|
||||
String sql = "select * from student left anti join score on student.id = score.sid and score.sid > 1";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -475,7 +473,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest18() {
|
||||
void inferPredicatesTest18() {
|
||||
String sql = "select * from student left anti join score on student.id = score.sid where student.id > 1";
|
||||
|
||||
PlanChecker.from(connectContext)
|
||||
@ -500,7 +498,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest19() {
|
||||
void inferPredicatesTest19() {
|
||||
String sql = "select * from subquery1\n"
|
||||
+ "left semi join (\n"
|
||||
+ " select t1.k3\n"
|
||||
@ -564,7 +562,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest20() {
|
||||
void inferPredicatesTest20() {
|
||||
String sql = "select * from student left join score on student.id = score.sid and score.sid > 1 inner join course on course.id = score.sid";
|
||||
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
|
||||
PlanChecker.from(connectContext)
|
||||
@ -592,7 +590,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
}
|
||||
|
||||
@Test
|
||||
public void inferPredicatesTest21() {
|
||||
void inferPredicatesTest21() {
|
||||
String sql = "select * from student,score,course where student.id = score.sid and score.sid = course.id and score.sid > 1";
|
||||
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
|
||||
PlanChecker.from(connectContext)
|
||||
@ -623,7 +621,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
* test for #15310
|
||||
*/
|
||||
@Test
|
||||
public void inferPredicatesTest22() {
|
||||
void inferPredicatesTest22() {
|
||||
String sql = "select * from student join (select sid as id1, sid as id2, grade from score) s on student.id = s.id1 where s.id1 > 1";
|
||||
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
|
||||
PlanChecker.from(connectContext)
|
||||
@ -651,7 +649,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
* in this case, filter on relation s1 should not contain s1.id = 1.
|
||||
*/
|
||||
@Test
|
||||
public void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
|
||||
void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
|
||||
String sql = "select * from student s1"
|
||||
+ " left join (select sid as id1, sid as id2, grade from score) s2 on s1.id = s2.id1 and s1.id = 1"
|
||||
+ " join (select sid as id1, sid as id2, grade from score) s3 on s1.id = s3.id1 where s1.id = 2";
|
||||
|
||||
@ -0,0 +1,51 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.InPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.SmallIntType;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
class PredicatePropagationTest {
|
||||
private final SlotReference a = new SlotReference("a", SmallIntType.INSTANCE);
|
||||
private final SlotReference b = new SlotReference("b", BigIntType.INSTANCE);
|
||||
|
||||
@Test
|
||||
void equal() {
|
||||
Set<Expression> exprs = ImmutableSet.of(new EqualTo(a, b), new EqualTo(a, Literal.of(1)));
|
||||
Set<Expression> inferExprs = PredicatePropagation.infer(exprs);
|
||||
System.out.println(inferExprs);
|
||||
}
|
||||
|
||||
@Test
|
||||
void in() {
|
||||
Set<Expression> exprs = ImmutableSet.of(new EqualTo(a, b), new InPredicate(a, ImmutableList.of(Literal.of(1))));
|
||||
Set<Expression> inferExprs = PredicatePropagation.infer(exprs);
|
||||
System.out.println(inferExprs);
|
||||
}
|
||||
}
|
||||
@ -9,7 +9,7 @@ PhysicalResultSink
|
||||
----------PhysicalDistribute[DistributionSpecHash]
|
||||
------------PhysicalOlapScan[t2]
|
||||
--------PhysicalDistribute[DistributionSpecHash]
|
||||
----------NestedLoopJoin[CROSS_JOIN]
|
||||
----------NestedLoopJoin[CROSS_JOIN](t4.c4 = t3.c3)(t3.c3 = t4.c4)
|
||||
------------PhysicalOlapScan[t3]
|
||||
------------PhysicalDistribute[DistributionSpecReplicated]
|
||||
--------------PhysicalOlapScan[t4]
|
||||
|
||||
@ -2609,7 +2609,7 @@ PhysicalResultSink
|
||||
------------PhysicalProject
|
||||
--------------PhysicalOlapScan[t2]
|
||||
------------PhysicalDistribute[DistributionSpecReplicated]
|
||||
--------------NestedLoopJoin[CROSS_JOIN]
|
||||
--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
|
||||
----------------PhysicalProject
|
||||
------------------PhysicalOlapScan[t1]
|
||||
----------------PhysicalDistribute[DistributionSpecReplicated]
|
||||
@ -2631,7 +2631,7 @@ PhysicalResultSink
|
||||
------------PhysicalProject
|
||||
--------------PhysicalOlapScan[t2]
|
||||
------------PhysicalDistribute[DistributionSpecReplicated]
|
||||
--------------NestedLoopJoin[CROSS_JOIN]
|
||||
--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
|
||||
----------------PhysicalProject
|
||||
------------------PhysicalOlapScan[t3]
|
||||
----------------PhysicalDistribute[DistributionSpecReplicated]
|
||||
@ -2745,7 +2745,7 @@ PhysicalResultSink
|
||||
------------PhysicalProject
|
||||
--------------PhysicalOlapScan[t2]
|
||||
------------PhysicalDistribute[DistributionSpecReplicated]
|
||||
--------------NestedLoopJoin[CROSS_JOIN]
|
||||
--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
|
||||
----------------PhysicalProject
|
||||
------------------PhysicalOlapScan[t1]
|
||||
----------------PhysicalDistribute[DistributionSpecReplicated]
|
||||
@ -2767,7 +2767,7 @@ PhysicalResultSink
|
||||
------------PhysicalProject
|
||||
--------------PhysicalOlapScan[t2]
|
||||
------------PhysicalDistribute[DistributionSpecReplicated]
|
||||
--------------NestedLoopJoin[CROSS_JOIN]
|
||||
--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
|
||||
----------------PhysicalProject
|
||||
------------------PhysicalOlapScan[t3]
|
||||
----------------PhysicalDistribute[DistributionSpecReplicated]
|
||||
@ -2881,7 +2881,7 @@ PhysicalResultSink
|
||||
------------PhysicalProject
|
||||
--------------PhysicalOlapScan[t2]
|
||||
------------PhysicalDistribute[DistributionSpecHash]
|
||||
--------------NestedLoopJoin[CROSS_JOIN]
|
||||
--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
|
||||
----------------PhysicalProject
|
||||
------------------PhysicalOlapScan[t1]
|
||||
----------------PhysicalDistribute[DistributionSpecReplicated]
|
||||
@ -2903,7 +2903,7 @@ PhysicalResultSink
|
||||
------------PhysicalProject
|
||||
--------------PhysicalOlapScan[t2]
|
||||
------------PhysicalDistribute[DistributionSpecHash]
|
||||
--------------NestedLoopJoin[CROSS_JOIN]
|
||||
--------------NestedLoopJoin[CROSS_JOIN](t1.c1 = t3.c3)
|
||||
----------------PhysicalProject
|
||||
------------------PhysicalOlapScan[t3]
|
||||
----------------PhysicalDistribute[DistributionSpecReplicated]
|
||||
|
||||
@ -41,7 +41,7 @@ suite("test_infer_predicate") {
|
||||
|
||||
explain {
|
||||
sql "select * from infer_tb1 inner join infer_tb2 where cast(infer_tb2.k4 as int) = infer_tb1.k2 and infer_tb2.k4 = 1;"
|
||||
contains "PREDICATES: k2"
|
||||
contains "PREDICATES: CAST(k2"
|
||||
}
|
||||
|
||||
explain {
|
||||
|
||||
Reference in New Issue
Block a user