[feature](Nereids): InferPredicates support In (#29458)
This commit is contained in:
@ -27,7 +27,6 @@ import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.PlanUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
@ -37,6 +36,7 @@ import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* infer additional predicates for `LogicalFilter` and `LogicalJoin`.
|
||||
* <pre>
|
||||
* The logic is as follows:
|
||||
* 1. poll up bottom predicate then infer additional predicates
|
||||
* for example:
|
||||
@ -49,9 +49,9 @@ import java.util.stream.Collectors;
|
||||
* select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t2.id = 1
|
||||
* 2. put these predicates into `otherJoinConjuncts` , these predicates are processed in the next
|
||||
* round of predicate push-down
|
||||
* </pre>
|
||||
*/
|
||||
public class InferPredicates extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
|
||||
private final PredicatePropagation propagation = new PredicatePropagation();
|
||||
private final PullUpPredicates pollUpPredicates = new PullUpPredicates();
|
||||
|
||||
@Override
|
||||
@ -62,6 +62,9 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements
|
||||
@Override
|
||||
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, JobContext context) {
|
||||
join = visitChildren(this, join, context);
|
||||
if (join.isMarkJoin()) {
|
||||
return join;
|
||||
}
|
||||
Plan left = join.left();
|
||||
Plan right = join.right();
|
||||
Set<Expression> expressions = getAllExpressions(left, right, join.getOnClauseCondition());
|
||||
@ -86,7 +89,7 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements
|
||||
break;
|
||||
}
|
||||
if (left != join.left() || right != join.right()) {
|
||||
return join.withChildren(ImmutableList.of(left, right));
|
||||
return join.withChildren(left, right);
|
||||
} else {
|
||||
return join;
|
||||
}
|
||||
@ -109,7 +112,7 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements
|
||||
Set<Expression> baseExpressions = pullUpPredicates(left);
|
||||
baseExpressions.addAll(pullUpPredicates(right));
|
||||
condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on)));
|
||||
baseExpressions.addAll(propagation.infer(baseExpressions));
|
||||
baseExpressions.addAll(PredicatePropagation.infer(baseExpressions));
|
||||
return baseExpressions;
|
||||
}
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.InPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
@ -55,8 +56,7 @@ public class PredicatePropagation {
|
||||
INTEGRAL(IntegralType.class),
|
||||
STRING(CharacterType.class),
|
||||
DATE(DateLikeType.class),
|
||||
OTHER(DataType.class)
|
||||
;
|
||||
OTHER(DataType.class);
|
||||
|
||||
private final Class<? extends DataType> superClazz;
|
||||
|
||||
@ -65,15 +65,15 @@ public class PredicatePropagation {
|
||||
}
|
||||
}
|
||||
|
||||
private class ComparisonInferInfo {
|
||||
private static class EqualInferInfo {
|
||||
|
||||
public final InferType inferType;
|
||||
public final Optional<Expression> left;
|
||||
public final Optional<Expression> right;
|
||||
public final Expression left;
|
||||
public final Expression right;
|
||||
public final ComparisonPredicate comparisonPredicate;
|
||||
|
||||
public ComparisonInferInfo(InferType inferType,
|
||||
Optional<Expression> left, Optional<Expression> right,
|
||||
public EqualInferInfo(InferType inferType,
|
||||
Expression left, Expression right,
|
||||
ComparisonPredicate comparisonPredicate) {
|
||||
this.inferType = inferType;
|
||||
this.left = left;
|
||||
@ -85,26 +85,27 @@ public class PredicatePropagation {
|
||||
/**
|
||||
* infer additional predicates.
|
||||
*/
|
||||
public Set<Expression> infer(Set<Expression> predicates) {
|
||||
public static Set<Expression> infer(Set<Expression> predicates) {
|
||||
Set<Expression> inferred = Sets.newHashSet();
|
||||
for (Expression predicate : predicates) {
|
||||
// if we support more infer predicate expression type, we should impl withInferred() method.
|
||||
// And should add inferred props in withChildren() method just like ComparisonPredicate,
|
||||
// and it's subclass, to mark the predicate is from infer.
|
||||
if (!(predicate instanceof ComparisonPredicate)) {
|
||||
if (!(predicate instanceof ComparisonPredicate
|
||||
|| (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren()))) {
|
||||
continue;
|
||||
}
|
||||
ComparisonInferInfo equalInfo = getEquivalentInferInfo((ComparisonPredicate) predicate);
|
||||
if (predicate instanceof InPredicate) {
|
||||
continue;
|
||||
}
|
||||
EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate) predicate);
|
||||
if (equalInfo.inferType == InferType.NONE) {
|
||||
continue;
|
||||
}
|
||||
Set<Expression> newInferred = predicates.stream()
|
||||
.filter(ComparisonPredicate.class::isInstance)
|
||||
.filter(p -> !p.equals(predicate))
|
||||
.map(ComparisonPredicate.class::cast)
|
||||
.map(this::inferInferInfo)
|
||||
.filter(predicateInfo -> predicateInfo.inferType != InferType.NONE)
|
||||
.map(predicateInfo -> doInfer(equalInfo, predicateInfo))
|
||||
.filter(p -> p instanceof ComparisonPredicate || p instanceof InPredicate)
|
||||
.map(predicateInfo -> doInferPredicate(equalInfo, predicateInfo))
|
||||
.filter(Objects::nonNull)
|
||||
.collect(Collectors.toSet());
|
||||
inferred.addAll(newInferred);
|
||||
@ -113,17 +114,64 @@ public class PredicatePropagation {
|
||||
return inferred;
|
||||
}
|
||||
|
||||
private static Expression doInferPredicate(EqualInferInfo equalInfo, Expression predicate) {
|
||||
Expression equalLeft = equalInfo.left;
|
||||
Expression equalRight = equalInfo.right;
|
||||
|
||||
DataType leftType = predicate.child(0).getDataType();
|
||||
InferType inferType;
|
||||
if (leftType instanceof CharacterType) {
|
||||
inferType = InferType.STRING;
|
||||
} else if (leftType instanceof IntegralType) {
|
||||
inferType = InferType.INTEGRAL;
|
||||
} else if (leftType instanceof DateLikeType) {
|
||||
inferType = InferType.DATE;
|
||||
} else {
|
||||
inferType = InferType.OTHER;
|
||||
}
|
||||
if (predicate instanceof ComparisonPredicate) {
|
||||
ComparisonPredicate comparisonPredicate = (ComparisonPredicate) predicate;
|
||||
Optional<Expression> left = validForInfer(comparisonPredicate.left(), inferType);
|
||||
Optional<Expression> right = validForInfer(comparisonPredicate.right(), inferType);
|
||||
if (!left.isPresent() || !right.isPresent()) {
|
||||
return null;
|
||||
}
|
||||
} else if (predicate instanceof InPredicate) {
|
||||
InPredicate inPredicate = (InPredicate) predicate;
|
||||
Optional<Expression> left = validForInfer(inPredicate.getCompareExpr(), inferType);
|
||||
if (!left.isPresent()) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
Expression newPredicate = predicate.rewriteUp(e -> {
|
||||
if (e.equals(equalLeft)) {
|
||||
return equalRight;
|
||||
} else if (e.equals(equalRight)) {
|
||||
return equalLeft;
|
||||
} else {
|
||||
return e;
|
||||
}
|
||||
});
|
||||
if (predicate instanceof ComparisonPredicate) {
|
||||
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) newPredicate).withInferred(true);
|
||||
} else {
|
||||
return TypeCoercionUtils.processInPredicate((InPredicate) newPredicate).withInferred(true);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression`
|
||||
* Now only support infer `ComparisonPredicate`.
|
||||
* TODO: We should determine whether `expression` satisfies the condition for replacement
|
||||
* eg: Satisfy `expression` is non-deterministic
|
||||
*/
|
||||
private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo predicateInfo) {
|
||||
Expression predicateLeft = predicateInfo.left.get();
|
||||
Expression predicateRight = predicateInfo.right.get();
|
||||
Expression equalLeft = equalInfo.left.get();
|
||||
Expression equalRight = equalInfo.right.get();
|
||||
private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo predicateInfo) {
|
||||
Expression equalLeft = equalInfo.left;
|
||||
Expression equalRight = equalInfo.right;
|
||||
|
||||
Expression predicateLeft = predicateInfo.left;
|
||||
Expression predicateRight = predicateInfo.right;
|
||||
Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight);
|
||||
Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight);
|
||||
if (newLeft == null || newRight == null) {
|
||||
@ -136,7 +184,7 @@ public class PredicatePropagation {
|
||||
return DateFunctionRewrite.INSTANCE.rewrite(expr, null).withInferred(true);
|
||||
}
|
||||
|
||||
private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
|
||||
private static Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
|
||||
if (predicateOneSide instanceof SlotReference) {
|
||||
if (predicateOneSide.equals(equalLeft)) {
|
||||
return equalRight;
|
||||
@ -153,60 +201,55 @@ public class PredicatePropagation {
|
||||
return null;
|
||||
}
|
||||
|
||||
private Optional<Expression> validForInfer(Expression expression, InferType inferType) {
|
||||
private static Optional<Expression> validForInfer(Expression expression, InferType inferType) {
|
||||
if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
|
||||
return Optional.empty();
|
||||
}
|
||||
if (expression instanceof SlotReference || expression.isConstant()) {
|
||||
return Optional.of(expression);
|
||||
}
|
||||
if (!(expression instanceof Cast)) {
|
||||
return Optional.empty();
|
||||
}
|
||||
Cast cast = (Cast) expression;
|
||||
Expression child = cast.child();
|
||||
DataType dataType = cast.getDataType();
|
||||
DataType childType = child.getDataType();
|
||||
if (inferType == InferType.INTEGRAL) {
|
||||
if (expression instanceof Cast) {
|
||||
// avoid cast from wider type to narrower type, such as cast(int as smallint)
|
||||
// IntegralType dataType = (IntegralType) expression.getDataType();
|
||||
// DataType childType = ((Cast) expression).child().getDataType();
|
||||
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
|
||||
// return validForInfer(((Cast) expression).child(), inferType);
|
||||
// }
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
}
|
||||
// avoid cast from wider type to narrower type, such as cast(int as smallint)
|
||||
// IntegralType dataType = (IntegralType) expression.getDataType();
|
||||
// DataType childType = ((Cast) expression).child().getDataType();
|
||||
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
|
||||
// return validForInfer(((Cast) expression).child(), inferType);
|
||||
// }
|
||||
return validForInfer(child, inferType);
|
||||
} else if (inferType == InferType.DATE) {
|
||||
if (expression instanceof Cast) {
|
||||
DataType dataType = expression.getDataType();
|
||||
DataType childType = ((Cast) expression).child().getDataType();
|
||||
// avoid lost precision
|
||||
if (dataType instanceof DateType) {
|
||||
if (childType instanceof DateV2Type || childType instanceof DateType) {
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
}
|
||||
} else if (dataType instanceof DateV2Type) {
|
||||
if (childType instanceof DateType || childType instanceof DateV2Type) {
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
}
|
||||
} else if (dataType instanceof DateTimeType) {
|
||||
if (!(childType instanceof DateTimeV2Type)) {
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
}
|
||||
} else if (dataType instanceof DateTimeV2Type) {
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
// avoid lost precision
|
||||
if (dataType instanceof DateType) {
|
||||
if (childType instanceof DateV2Type || childType instanceof DateType) {
|
||||
return validForInfer(child, inferType);
|
||||
}
|
||||
} else if (dataType instanceof DateV2Type) {
|
||||
if (childType instanceof DateType || childType instanceof DateV2Type) {
|
||||
return validForInfer(child, inferType);
|
||||
}
|
||||
} else if (dataType instanceof DateTimeType) {
|
||||
if (!(childType instanceof DateTimeV2Type)) {
|
||||
return validForInfer(child, inferType);
|
||||
}
|
||||
} else if (dataType instanceof DateTimeV2Type) {
|
||||
return validForInfer(child, inferType);
|
||||
}
|
||||
} else if (inferType == InferType.STRING) {
|
||||
if (expression instanceof Cast) {
|
||||
DataType dataType = expression.getDataType();
|
||||
DataType childType = ((Cast) expression).child().getDataType();
|
||||
// avoid substring cast such as cast(char(3) as char(2))
|
||||
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
}
|
||||
// avoid substring cast such as cast(char(3) as char(2))
|
||||
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
|
||||
return validForInfer(child, inferType);
|
||||
}
|
||||
} else {
|
||||
return Optional.empty();
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
|
||||
private static EqualInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
|
||||
DataType leftType = comparisonPredicate.left().getDataType();
|
||||
InferType inferType;
|
||||
if (leftType instanceof CharacterType) {
|
||||
@ -223,25 +266,27 @@ public class PredicatePropagation {
|
||||
if (!left.isPresent() || !right.isPresent()) {
|
||||
inferType = InferType.NONE;
|
||||
}
|
||||
return new ComparisonInferInfo(inferType, left, right, comparisonPredicate);
|
||||
return new EqualInferInfo(inferType, left.orElse(comparisonPredicate.left()),
|
||||
right.orElse(comparisonPredicate.right()), comparisonPredicate);
|
||||
}
|
||||
|
||||
/**
|
||||
* Currently only equivalence derivation is supported
|
||||
* and requires that the left and right sides of an expression must be slot
|
||||
* <p>
|
||||
* TODO: NullSafeEqual
|
||||
*/
|
||||
private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate predicate) {
|
||||
private static EqualInferInfo getEqualInferInfo(ComparisonPredicate predicate) {
|
||||
if (!(predicate instanceof EqualTo)) {
|
||||
return new ComparisonInferInfo(InferType.NONE,
|
||||
Optional.of(predicate.left()), Optional.of(predicate.right()), predicate);
|
||||
return new EqualInferInfo(InferType.NONE, predicate.left(), predicate.right(), predicate);
|
||||
}
|
||||
ComparisonInferInfo info = inferInferInfo(predicate);
|
||||
EqualInferInfo info = inferInferInfo(predicate);
|
||||
if (info.inferType == InferType.NONE) {
|
||||
return info;
|
||||
}
|
||||
if (info.left.get() instanceof SlotReference && info.right.get() instanceof SlotReference) {
|
||||
if (info.left instanceof SlotReference && info.right instanceof SlotReference) {
|
||||
return info;
|
||||
}
|
||||
return new ComparisonInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
|
||||
return new EqualInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
|
||||
}
|
||||
}
|
||||
|
||||
@ -47,7 +47,6 @@ import java.util.stream.Collectors;
|
||||
*/
|
||||
public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void> {
|
||||
|
||||
PredicatePropagation propagation = new PredicatePropagation();
|
||||
Map<Plan, ImmutableSet<Expression>> cache = new IdentityHashMap<>();
|
||||
|
||||
@Override
|
||||
@ -99,6 +98,7 @@ public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void
|
||||
public ImmutableSet<Expression> visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) {
|
||||
return cacheOrElse(aggregate, () -> {
|
||||
ImmutableSet<Expression> childPredicates = aggregate.child().accept(this, context);
|
||||
// TODO
|
||||
Map<Expression, Slot> expressionSlotMap = aggregate.getOutputExpressions()
|
||||
.stream()
|
||||
.filter(this::hasAgg)
|
||||
@ -130,7 +130,7 @@ public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void
|
||||
|
||||
private ImmutableSet<Expression> getAvailableExpressions(Collection<Expression> predicates, Plan plan) {
|
||||
Set<Expression> expressions = Sets.newHashSet(predicates);
|
||||
expressions.addAll(propagation.infer(expressions));
|
||||
expressions.addAll(PredicatePropagation.infer(expressions));
|
||||
return expressions.stream()
|
||||
.filter(p -> plan.getOutputSet().containsAll(p.getInputSlots()))
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
|
||||
@ -39,10 +39,6 @@ public class EqualTo extends EqualPredicate implements PropagateNullable {
|
||||
super(ImmutableList.of(left, right), "=", inferred);
|
||||
}
|
||||
|
||||
private EqualTo(List<Expression> children) {
|
||||
this(children, false);
|
||||
}
|
||||
|
||||
private EqualTo(List<Expression> children, boolean inferred) {
|
||||
super(children, "=", inferred);
|
||||
}
|
||||
|
||||
@ -48,6 +48,12 @@ public class InPredicate extends Expression {
|
||||
this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null"));
|
||||
}
|
||||
|
||||
public InPredicate(Expression compareExpr, List<Expression> options, boolean inferred) {
|
||||
super(new Builder<Expression>().add(compareExpr).addAll(options).build(), inferred);
|
||||
this.compareExpr = Objects.requireNonNull(compareExpr, "Compare Expr cannot be null");
|
||||
this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null"));
|
||||
}
|
||||
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitInPredicate(this, context);
|
||||
}
|
||||
@ -80,6 +86,11 @@ public class InPredicate extends Expression {
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression withInferred(boolean inferred) {
|
||||
return new InPredicate(children.get(0), ImmutableList.copyOf(children).subList(1, children.size()), true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return compareExpr + " IN " + options.stream()
|
||||
|
||||
Reference in New Issue
Block a user