[refactor](Nereids) refactor infer predicate rule to avoid lost cast (#25637)
Extract the slot and literal from each comparison predicate, infer new predicates through the equality predicates, and use TypeCoercion to add casts to each new comparison predicate so that it stays correct. This reverts "[Fix](Nereids) Add cast comparison with slot reference when inferring predicate (#21171)", commit 58f2593ba1b65713e7b3c1ed39fc84be8cc3ff2c.
This commit is contained in:
@ -17,19 +17,28 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.parser.NereidsParser;
|
||||
import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.DateTimeType;
|
||||
import org.apache.doris.nereids.types.DateTimeV2Type;
|
||||
import org.apache.doris.nereids.types.DateType;
|
||||
import org.apache.doris.nereids.types.DateV2Type;
|
||||
import org.apache.doris.nereids.types.coercion.CharacterType;
|
||||
import org.apache.doris.nereids.types.coercion.DateLikeType;
|
||||
import org.apache.doris.nereids.types.coercion.IntegralType;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.TypeCoercionUtils;
|
||||
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@ -40,19 +49,61 @@ import java.util.stream.Collectors;
|
||||
*/
|
||||
public class PredicatePropagation {
|
||||
|
||||
private enum InferType {
|
||||
NONE(null),
|
||||
INTEGRAL(IntegralType.class),
|
||||
STRING(CharacterType.class),
|
||||
DATE(DateLikeType.class),
|
||||
OTHER(DataType.class)
|
||||
;
|
||||
|
||||
private final Class<? extends DataType> superClazz;
|
||||
|
||||
InferType(Class<? extends DataType> superClazz) {
|
||||
this.superClazz = superClazz;
|
||||
}
|
||||
}
|
||||
|
||||
private class ComparisonInferInfo {
|
||||
|
||||
public final InferType inferType;
|
||||
public final Optional<Expression> left;
|
||||
public final Optional<Expression> right;
|
||||
public final ComparisonPredicate comparisonPredicate;
|
||||
|
||||
public ComparisonInferInfo(InferType inferType,
|
||||
Optional<Expression> left, Optional<Expression> right,
|
||||
ComparisonPredicate comparisonPredicate) {
|
||||
this.inferType = inferType;
|
||||
this.left = left;
|
||||
this.right = right;
|
||||
this.comparisonPredicate = comparisonPredicate;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* infer additional predicates.
|
||||
*/
|
||||
public Set<Expression> infer(Set<Expression> predicates) {
|
||||
Set<Expression> inferred = Sets.newHashSet();
|
||||
for (Expression predicate : predicates) {
|
||||
if (canEquivalentInfer(predicate)) {
|
||||
List<Expression> newInferred = predicates.stream()
|
||||
.filter(p -> !p.equals(predicate))
|
||||
.map(p -> doInfer(predicate, p))
|
||||
.collect(Collectors.toList());
|
||||
inferred.addAll(newInferred);
|
||||
if (!(predicate instanceof ComparisonPredicate)) {
|
||||
continue;
|
||||
}
|
||||
ComparisonInferInfo equalInfo = getEquivalentInferInfo((ComparisonPredicate) predicate);
|
||||
if (equalInfo.inferType == InferType.NONE) {
|
||||
continue;
|
||||
}
|
||||
Set<Expression> newInferred = predicates.stream()
|
||||
.filter(ComparisonPredicate.class::isInstance)
|
||||
.filter(p -> !p.equals(predicate))
|
||||
.map(ComparisonPredicate.class::cast)
|
||||
.map(this::inferInferInfo)
|
||||
.filter(predicateInfo -> predicateInfo.inferType != InferType.NONE)
|
||||
.map(predicateInfo -> doInfer(equalInfo, predicateInfo))
|
||||
.filter(Objects::nonNull)
|
||||
.collect(Collectors.toSet());
|
||||
inferred.addAll(newInferred);
|
||||
}
|
||||
inferred.removeAll(predicates);
|
||||
return inferred;
|
||||
@ -64,64 +115,128 @@ public class PredicatePropagation {
|
||||
* TODO: We should determine whether `expression` satisfies the condition for replacement
|
||||
* eg: Satisfy `expression` is non-deterministic
|
||||
*/
|
||||
private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression expression) {
|
||||
return expression.accept(new DefaultExpressionRewriter<Void>() {
|
||||
private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo predicateInfo) {
|
||||
Expression predicateLeft = predicateInfo.left.get();
|
||||
Expression predicateRight = predicateInfo.right.get();
|
||||
Expression equalLeft = equalInfo.left.get();
|
||||
Expression equalRight = equalInfo.right.get();
|
||||
Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight);
|
||||
Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight);
|
||||
if (newLeft == null || newRight == null) {
|
||||
return null;
|
||||
}
|
||||
ComparisonPredicate newPredicate = (ComparisonPredicate) predicateInfo
|
||||
.comparisonPredicate.withChildren(newLeft, newRight);
|
||||
return SimplifyComparisonPredicate.INSTANCE
|
||||
.rewrite(TypeCoercionUtils.processComparisonPredicate(newPredicate), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visit(Expression expr, Void context) {
|
||||
return expr;
|
||||
private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
|
||||
if (predicateOneSide instanceof SlotReference) {
|
||||
if (predicateOneSide.equals(equalLeft)) {
|
||||
return equalRight;
|
||||
} else if (predicateOneSide.equals(equalRight)) {
|
||||
return equalLeft;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) {
|
||||
// we need to get expression covered by cast, because we want to infer different datatype
|
||||
if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) {
|
||||
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left()));
|
||||
} else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) {
|
||||
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right()));
|
||||
}
|
||||
return super.visit(cp, context);
|
||||
} else if (predicateOneSide.isConstant()) {
|
||||
if (predicateOneSide instanceof IntegerLikeLiteral) {
|
||||
return new NereidsParser().parseExpression(((IntegerLikeLiteral) predicateOneSide).toSql());
|
||||
} else {
|
||||
return predicateOneSide;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private boolean isDataTypeValid(DataType originDataType, Expression expr) {
|
||||
if ((leftSlotEqualToRightSlot.child(0).getDataType() instanceof IntegralType)
|
||||
&& (leftSlotEqualToRightSlot.child(1).getDataType() instanceof IntegralType)
|
||||
&& (originDataType instanceof IntegralType)) {
|
||||
// infer filter can not be lower than original datatype, or dataset would be wrong
|
||||
if (!((IntegralType) originDataType).widerThan(
|
||||
(IntegralType) leftSlotEqualToRightSlot.child(0).getDataType())
|
||||
&& !((IntegralType) originDataType).widerThan(
|
||||
(IntegralType) leftSlotEqualToRightSlot.child(1).getDataType())) {
|
||||
return true;
|
||||
private Optional<Expression> validForInfer(Expression expression, InferType inferType) {
|
||||
if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
|
||||
return Optional.empty();
|
||||
}
|
||||
if (expression instanceof SlotReference || expression.isConstant()) {
|
||||
return Optional.of(expression);
|
||||
}
|
||||
if (inferType == InferType.INTEGRAL) {
|
||||
if (expression instanceof Cast) {
|
||||
// avoid cast from wider type to narrower type, such as cast(int as smallint)
|
||||
// IntegralType dataType = (IntegralType) expression.getDataType();
|
||||
// DataType childType = ((Cast) expression).child().getDataType();
|
||||
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
|
||||
// return validForInfer(((Cast) expression).child(), inferType);
|
||||
// }
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
}
|
||||
} else if (inferType == InferType.DATE) {
|
||||
if (expression instanceof Cast) {
|
||||
DataType dataType = expression.getDataType();
|
||||
DataType childType = ((Cast) expression).child().getDataType();
|
||||
// avoid lost precision
|
||||
if (dataType instanceof DateType) {
|
||||
if (childType instanceof DateV2Type || childType instanceof DateType) {
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private Expression replaceSlot(Expression expr, DataType originDataType) {
|
||||
return expr.rewriteUp(e -> {
|
||||
if (isDataTypeValid(originDataType, leftSlotEqualToRightSlot)) {
|
||||
if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(0))) {
|
||||
return leftSlotEqualToRightSlot.child(1);
|
||||
} else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(1))) {
|
||||
return leftSlotEqualToRightSlot.child(0);
|
||||
}
|
||||
} else if (dataType instanceof DateV2Type) {
|
||||
if (childType instanceof DateType || childType instanceof DateV2Type) {
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
}
|
||||
return e;
|
||||
});
|
||||
} else if (dataType instanceof DateTimeType) {
|
||||
if (!(childType instanceof DateTimeV2Type)) {
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
}
|
||||
} else if (dataType instanceof DateTimeV2Type) {
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
}
|
||||
}
|
||||
}, null);
|
||||
} else if (inferType == InferType.STRING) {
|
||||
if (expression instanceof Cast) {
|
||||
DataType dataType = expression.getDataType();
|
||||
DataType childType = ((Cast) expression).child().getDataType();
|
||||
// avoid substring cast such as cast(char(3) as char(2))
|
||||
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
|
||||
return validForInfer(((Cast) expression).child(), inferType);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Optional.empty();
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
|
||||
DataType leftType = comparisonPredicate.left().getDataType();
|
||||
InferType inferType;
|
||||
if (leftType instanceof CharacterType) {
|
||||
inferType = InferType.STRING;
|
||||
} else if (leftType instanceof IntegralType) {
|
||||
inferType = InferType.INTEGRAL;
|
||||
} else if (leftType instanceof DateLikeType) {
|
||||
inferType = InferType.DATE;
|
||||
} else {
|
||||
inferType = InferType.OTHER;
|
||||
}
|
||||
Optional<Expression> left = validForInfer(comparisonPredicate.left(), inferType);
|
||||
Optional<Expression> right = validForInfer(comparisonPredicate.right(), inferType);
|
||||
if (!left.isPresent() || !right.isPresent()) {
|
||||
inferType = InferType.NONE;
|
||||
}
|
||||
return new ComparisonInferInfo(inferType, left, right, comparisonPredicate);
|
||||
}
|
||||
|
||||
/**
|
||||
* Currently only equivalence derivation is supported
|
||||
* and requires that the left and right sides of an expression must be slot
|
||||
*/
|
||||
private boolean canEquivalentInfer(Expression predicate) {
|
||||
return predicate instanceof EqualTo
|
||||
&& predicate.children().stream().allMatch(e ->
|
||||
(e instanceof SlotReference) || (e instanceof Cast && e.child(0) instanceof SlotReference))
|
||||
&& predicate.child(0).getDataType().equals(predicate.child(1).getDataType());
|
||||
private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate predicate) {
|
||||
if (!(predicate instanceof EqualTo)) {
|
||||
return new ComparisonInferInfo(InferType.NONE,
|
||||
Optional.of(predicate.left()), Optional.of(predicate.right()), predicate);
|
||||
}
|
||||
ComparisonInferInfo info = inferInferInfo(predicate);
|
||||
if (info.inferType == InferType.NONE) {
|
||||
return info;
|
||||
}
|
||||
if (info.left.get() instanceof SlotReference && info.right.get() instanceof SlotReference) {
|
||||
return info;
|
||||
}
|
||||
return new ComparisonInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -39,7 +39,6 @@ import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Predicate;
|
||||
@ -253,34 +252,6 @@ public class ExpressionUtils {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* get slot covered by cast
|
||||
* example: input: cast(cast(table.columnA)) output: columnA.datatype
|
||||
*
|
||||
*/
|
||||
public static DataType getDatatypeCoveredByCast(Expression expr) {
|
||||
if (expr instanceof Cast) {
|
||||
return getDatatypeCoveredByCast(((Cast) expr).child());
|
||||
}
|
||||
return expr.getDataType();
|
||||
}
|
||||
|
||||
/**
|
||||
* judge if expression is slot covered by cast
|
||||
* example: cast(cast(table.columnA))
|
||||
*/
|
||||
public static boolean isExpressionSlotCoveredByCast(Expression expr) {
|
||||
if (expr instanceof Cast) {
|
||||
return isExpressionSlotCoveredByCast(((Cast) expr).child());
|
||||
}
|
||||
return expr instanceof SlotReference;
|
||||
}
|
||||
|
||||
public static boolean isTwoExpressionEqualWithCast(Expression left, Expression right) {
|
||||
return ExpressionUtils.extractSlotOrCastOnSlot(left)
|
||||
.equals(ExpressionUtils.extractSlotOrCastOnSlot(right));
|
||||
}
|
||||
|
||||
/**
|
||||
* Replace expression node in the expression tree by `replaceMap` in top-down manner.
|
||||
* For example.
|
||||
|
||||
Reference in New Issue
Block a user