[feature](Nereids): InferPredicates support In (#29458)

This commit is contained in:
jakevin
2024-01-05 21:25:30 +08:00
committed by GitHub
parent f40cce1406
commit 7a0734dbd6
10 changed files with 220 additions and 116 deletions

View File

@ -27,7 +27,6 @@ import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
@ -37,6 +36,7 @@ import java.util.stream.Collectors;
/**
* infer additional predicates for `LogicalFilter` and `LogicalJoin`.
* <pre>
* The logic is as follows:
* 1. pull up the bottom predicates, then infer additional predicates
* for example:
@ -49,9 +49,9 @@ import java.util.stream.Collectors;
* select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t2.id = 1
* 2. put these predicates into `otherJoinConjuncts` , these predicates are processed in the next
* round of predicate push-down
* </pre>
*/
public class InferPredicates extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
private final PredicatePropagation propagation = new PredicatePropagation();
private final PullUpPredicates pollUpPredicates = new PullUpPredicates();
@Override
@ -62,6 +62,9 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, JobContext context) {
join = visitChildren(this, join, context);
if (join.isMarkJoin()) {
return join;
}
Plan left = join.left();
Plan right = join.right();
Set<Expression> expressions = getAllExpressions(left, right, join.getOnClauseCondition());
@ -86,7 +89,7 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements
break;
}
if (left != join.left() || right != join.right()) {
return join.withChildren(ImmutableList.of(left, right));
return join.withChildren(left, right);
} else {
return join;
}
@ -109,7 +112,7 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements
Set<Expression> baseExpressions = pullUpPredicates(left);
baseExpressions.addAll(pullUpPredicates(right));
condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on)));
baseExpressions.addAll(propagation.infer(baseExpressions));
baseExpressions.addAll(PredicatePropagation.infer(baseExpressions));
return baseExpressions;
}

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.types.DataType;
@ -55,8 +56,7 @@ public class PredicatePropagation {
INTEGRAL(IntegralType.class),
STRING(CharacterType.class),
DATE(DateLikeType.class),
OTHER(DataType.class)
;
OTHER(DataType.class);
private final Class<? extends DataType> superClazz;
@ -65,15 +65,15 @@ public class PredicatePropagation {
}
}
private class ComparisonInferInfo {
private static class EqualInferInfo {
public final InferType inferType;
public final Optional<Expression> left;
public final Optional<Expression> right;
public final Expression left;
public final Expression right;
public final ComparisonPredicate comparisonPredicate;
public ComparisonInferInfo(InferType inferType,
Optional<Expression> left, Optional<Expression> right,
public EqualInferInfo(InferType inferType,
Expression left, Expression right,
ComparisonPredicate comparisonPredicate) {
this.inferType = inferType;
this.left = left;
@ -85,26 +85,27 @@ public class PredicatePropagation {
/**
* infer additional predicates.
*/
public Set<Expression> infer(Set<Expression> predicates) {
public static Set<Expression> infer(Set<Expression> predicates) {
Set<Expression> inferred = Sets.newHashSet();
for (Expression predicate : predicates) {
// If more expression types become inferable, each of them must implement withInferred()
// and propagate the inferred flag in withChildren(), just like ComparisonPredicate
// and its subclasses, so that a predicate produced by inference stays marked as such.
if (!(predicate instanceof ComparisonPredicate)) {
if (!(predicate instanceof ComparisonPredicate
|| (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren()))) {
continue;
}
ComparisonInferInfo equalInfo = getEquivalentInferInfo((ComparisonPredicate) predicate);
if (predicate instanceof InPredicate) {
continue;
}
EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate) predicate);
if (equalInfo.inferType == InferType.NONE) {
continue;
}
Set<Expression> newInferred = predicates.stream()
.filter(ComparisonPredicate.class::isInstance)
.filter(p -> !p.equals(predicate))
.map(ComparisonPredicate.class::cast)
.map(this::inferInferInfo)
.filter(predicateInfo -> predicateInfo.inferType != InferType.NONE)
.map(predicateInfo -> doInfer(equalInfo, predicateInfo))
.filter(p -> p instanceof ComparisonPredicate || p instanceof InPredicate)
.map(predicateInfo -> doInferPredicate(equalInfo, predicateInfo))
.filter(Objects::nonNull)
.collect(Collectors.toSet());
inferred.addAll(newInferred);
@ -113,17 +114,64 @@ public class PredicatePropagation {
return inferred;
}
/**
 * Derives a new predicate from {@code predicate} by exchanging the two sides of the
 * equality described by {@code equalInfo} wherever they occur inside it.
 *
 * <p>Only {@code ComparisonPredicate} and {@code InPredicate} inputs are validated here;
 * the operands must pass {@code validForInfer} for the infer type chosen from the data
 * type of the predicate's first child.
 *
 * <p>NOTE(review): if neither side of the equality occurs in {@code predicate}, the
 * substitution appears to be a no-op and the predicate is still returned marked as
 * inferred -- confirm this is intended.
 *
 * @param equalInfo the equality whose left/right expressions are swapped for each other
 * @param predicate the predicate to rewrite
 * @return the type-coerced rewritten predicate marked as inferred, or {@code null} when an
 *         operand is not valid for inference
 */
private static Expression doInferPredicate(EqualInferInfo equalInfo, Expression predicate) {
    final Expression lhs = equalInfo.left;
    final Expression rhs = equalInfo.right;

    // The infer type is driven solely by the data type of the first child.
    final DataType firstChildType = predicate.child(0).getDataType();
    final InferType inferType = firstChildType instanceof CharacterType ? InferType.STRING
            : firstChildType instanceof IntegralType ? InferType.INTEGRAL
            : firstChildType instanceof DateLikeType ? InferType.DATE
            : InferType.OTHER;

    if (predicate instanceof ComparisonPredicate) {
        ComparisonPredicate comparison = (ComparisonPredicate) predicate;
        // Both operands are checked before deciding, as a comparison needs both to be usable.
        boolean leftUsable = validForInfer(comparison.left(), inferType).isPresent();
        boolean rightUsable = validForInfer(comparison.right(), inferType).isPresent();
        if (!(leftUsable && rightUsable)) {
            return null;
        }
    } else if (predicate instanceof InPredicate) {
        // For IN only the compared expression is validated here.
        Expression compared = ((InPredicate) predicate).getCompareExpr();
        if (!validForInfer(compared, inferType).isPresent()) {
            return null;
        }
    }

    // Swap every occurrence of one side of the equality for the other side.
    Expression substituted = predicate.rewriteUp(
            e -> e.equals(lhs) ? rhs : (e.equals(rhs) ? lhs : e));

    if (!(predicate instanceof ComparisonPredicate)) {
        return TypeCoercionUtils.processInPredicate((InPredicate) substituted).withInferred(true);
    }
    return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) substituted).withInferred(true);
}
/**
* Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression`
* Now only support infer `ComparisonPredicate`.
* TODO: We should determine whether `expression` satisfies the condition for replacement
* eg: Satisfy `expression` is non-deterministic
*/
private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo predicateInfo) {
Expression predicateLeft = predicateInfo.left.get();
Expression predicateRight = predicateInfo.right.get();
Expression equalLeft = equalInfo.left.get();
Expression equalRight = equalInfo.right.get();
private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo predicateInfo) {
Expression equalLeft = equalInfo.left;
Expression equalRight = equalInfo.right;
Expression predicateLeft = predicateInfo.left;
Expression predicateRight = predicateInfo.right;
Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight);
Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight);
if (newLeft == null || newRight == null) {
@ -136,7 +184,7 @@ public class PredicatePropagation {
return DateFunctionRewrite.INSTANCE.rewrite(expr, null).withInferred(true);
}
private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
private static Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
if (predicateOneSide instanceof SlotReference) {
if (predicateOneSide.equals(equalLeft)) {
return equalRight;
@ -153,60 +201,55 @@ public class PredicatePropagation {
return null;
}
/**
* Decides whether {@code expression} may take part in predicate inference for the given
* {@code inferType}, and returns the expression to use when it may.
*
* <p>Visible rules:
* <ul>
* <li>empty when the expression's data type class is not assignable to
* {@code inferType.superClazz};</li>
* <li>a {@code SlotReference} or a constant expression is returned as is;</li>
* <li>a {@code Cast} is unwrapped recursively: always for INTEGRAL (the narrowing check is
* commented out), only for the listed date/datetime type pairs for DATE, and only when
* the target width is non-positive or at least the non-negative child width for STRING;</li>
* <li>anything else yields an empty result.</li>
* </ul>
*
* <p>NOTE: this listing is a diff hunk, so removed and added lines of the method appear
* side by side below.
*
* @param expression candidate operand
* @param inferType type family the operand must belong to
* @return the usable expression, or empty when the operand must not be used for inference
*/
private Optional<Expression> validForInfer(Expression expression, InferType inferType) {
private static Optional<Expression> validForInfer(Expression expression, InferType inferType) {
if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
return Optional.empty();
}
if (expression instanceof SlotReference || expression.isConstant()) {
return Optional.of(expression);
}
if (!(expression instanceof Cast)) {
return Optional.empty();
}
Cast cast = (Cast) expression;
Expression child = cast.child();
DataType dataType = cast.getDataType();
DataType childType = child.getDataType();
if (inferType == InferType.INTEGRAL) {
if (expression instanceof Cast) {
// avoid cast from wider type to narrower type, such as cast(int as smallint)
// IntegralType dataType = (IntegralType) expression.getDataType();
// DataType childType = ((Cast) expression).child().getDataType();
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
// return validForInfer(((Cast) expression).child(), inferType);
// }
return validForInfer(((Cast) expression).child(), inferType);
}
// avoid cast from wider type to narrower type, such as cast(int as smallint)
// IntegralType dataType = (IntegralType) expression.getDataType();
// DataType childType = ((Cast) expression).child().getDataType();
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
// return validForInfer(((Cast) expression).child(), inferType);
// }
return validForInfer(child, inferType);
} else if (inferType == InferType.DATE) {
if (expression instanceof Cast) {
DataType dataType = expression.getDataType();
DataType childType = ((Cast) expression).child().getDataType();
// avoid lost precision
if (dataType instanceof DateType) {
if (childType instanceof DateV2Type || childType instanceof DateType) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateV2Type) {
if (childType instanceof DateType || childType instanceof DateV2Type) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateTimeType) {
if (!(childType instanceof DateTimeV2Type)) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateTimeV2Type) {
return validForInfer(((Cast) expression).child(), inferType);
// avoid lost precision
if (dataType instanceof DateType) {
if (childType instanceof DateV2Type || childType instanceof DateType) {
return validForInfer(child, inferType);
}
} else if (dataType instanceof DateV2Type) {
if (childType instanceof DateType || childType instanceof DateV2Type) {
return validForInfer(child, inferType);
}
} else if (dataType instanceof DateTimeType) {
if (!(childType instanceof DateTimeV2Type)) {
return validForInfer(child, inferType);
}
} else if (dataType instanceof DateTimeV2Type) {
return validForInfer(child, inferType);
}
} else if (inferType == InferType.STRING) {
if (expression instanceof Cast) {
DataType dataType = expression.getDataType();
DataType childType = ((Cast) expression).child().getDataType();
// avoid substring cast such as cast(char(3) as char(2))
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
return validForInfer(((Cast) expression).child(), inferType);
}
// avoid substring cast such as cast(char(3) as char(2))
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
return validForInfer(child, inferType);
}
} else {
return Optional.empty();
}
return Optional.empty();
}
private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
private static EqualInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
DataType leftType = comparisonPredicate.left().getDataType();
InferType inferType;
if (leftType instanceof CharacterType) {
@ -223,25 +266,27 @@ public class PredicatePropagation {
if (!left.isPresent() || !right.isPresent()) {
inferType = InferType.NONE;
}
return new ComparisonInferInfo(inferType, left, right, comparisonPredicate);
return new EqualInferInfo(inferType, left.orElse(comparisonPredicate.left()),
right.orElse(comparisonPredicate.right()), comparisonPredicate);
}
/**
* Builds the inference info for a predicate that may serve as the equality source.
* Only equivalence derivation is supported: the infer type is kept solely for an
* {@code EqualTo} whose two sides, as resolved by {@code inferInferInfo}, are both
* {@code SlotReference}; every other predicate is returned with {@code InferType.NONE}.
* <p>
* NOTE: this listing is a diff hunk, so removed and added lines appear side by side below.
* <p>
* TODO: NullSafeEqual
*/
private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate predicate) {
private static EqualInferInfo getEqualInferInfo(ComparisonPredicate predicate) {
if (!(predicate instanceof EqualTo)) {
return new ComparisonInferInfo(InferType.NONE,
Optional.of(predicate.left()), Optional.of(predicate.right()), predicate);
return new EqualInferInfo(InferType.NONE, predicate.left(), predicate.right(), predicate);
}
ComparisonInferInfo info = inferInferInfo(predicate);
EqualInferInfo info = inferInferInfo(predicate);
if (info.inferType == InferType.NONE) {
return info;
}
if (info.left.get() instanceof SlotReference && info.right.get() instanceof SlotReference) {
if (info.left instanceof SlotReference && info.right instanceof SlotReference) {
return info;
}
return new ComparisonInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
return new EqualInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
}
}

View File

@ -47,7 +47,6 @@ import java.util.stream.Collectors;
*/
public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void> {
PredicatePropagation propagation = new PredicatePropagation();
Map<Plan, ImmutableSet<Expression>> cache = new IdentityHashMap<>();
@Override
@ -99,6 +98,7 @@ public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void
public ImmutableSet<Expression> visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) {
return cacheOrElse(aggregate, () -> {
ImmutableSet<Expression> childPredicates = aggregate.child().accept(this, context);
// TODO
Map<Expression, Slot> expressionSlotMap = aggregate.getOutputExpressions()
.stream()
.filter(this::hasAgg)
@ -130,7 +130,7 @@ public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void
private ImmutableSet<Expression> getAvailableExpressions(Collection<Expression> predicates, Plan plan) {
Set<Expression> expressions = Sets.newHashSet(predicates);
expressions.addAll(propagation.infer(expressions));
expressions.addAll(PredicatePropagation.infer(expressions));
return expressions.stream()
.filter(p -> plan.getOutputSet().containsAll(p.getInputSlots()))
.collect(ImmutableSet.toImmutableSet());

View File

@ -39,10 +39,6 @@ public class EqualTo extends EqualPredicate implements PropagateNullable {
super(ImmutableList.of(left, right), "=", inferred);
}
private EqualTo(List<Expression> children) {
this(children, false);
}
/**
* Creates an equality from an already assembled child list, passing the "=" symbol and
* the inferred flag on to the superclass constructor.
*
* @param children operands of the equality, handed to the superclass unchanged
* @param inferred flag forwarded to the superclass constructor
*/
private EqualTo(List<Expression> children, boolean inferred) {
super(children, "=", inferred);
}

View File

@ -48,6 +48,12 @@ public class InPredicate extends Expression {
this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null"));
}
/**
* Creates an IN predicate with an explicit inferred flag.
*
* @param compareExpr expression on the left-hand side of IN; becomes the first child
* @param options candidate values; become the remaining children, in order
* @param inferred flag forwarded unchanged to the superclass constructor
*/
public InPredicate(Expression compareExpr, List<Expression> options, boolean inferred) {
// NOTE(review): compareExpr and options are handed to the builder in the super(...) call
// before the requireNonNull checks below run -- confirm whether those messages are reachable.
super(new Builder<Expression>().add(compareExpr).addAll(options).build(), inferred);
this.compareExpr = Objects.requireNonNull(compareExpr, "Compare Expr cannot be null");
// Defensive immutable copy so later changes to the caller's list do not leak in.
this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null"));
}
/**
* Dispatches this expression to {@code visitor.visitInPredicate}.
*
* @param visitor visitor to call back
* @param context visitor-specific context passed through unchanged
* @return whatever the visitor returns for this IN predicate
*/
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitInPredicate(this, context);
}
@ -80,6 +86,11 @@ public class InPredicate extends Expression {
});
}
/**
 * Returns a copy of this IN predicate carrying the requested inferred flag.
 *
 * <p>The copy is rebuilt from the current children: the first child is the compared
 * expression and the remaining children are the IN-list options.
 *
 * @param inferred whether the returned predicate should be marked as produced by inference
 * @return a new {@code InPredicate} with the same children and the given flag
 */
@Override
public Expression withInferred(boolean inferred) {
    // Honor the caller's flag; previously `true` was hard-coded, so withInferred(false)
    // still produced a predicate marked as inferred.
    return new InPredicate(children.get(0), ImmutableList.copyOf(children).subList(1, children.size()),
            inferred);
}
@Override
public String toString() {
return compareExpr + " IN " + options.stream()