[refactor](Nereids) refactor infer predicate rule to avoid lost cast (#25637)

extract slot and literal in comparison predicate. infer new one by equals predicates.
use TypeCoercion to add cast on new comparison predicate to ensure it is correct.

This reverts "[Fix](Nereids) Add cast comparison with slot reference when inferring predicate (#21171)"
commit 58f2593ba1b65713e7b3c1ed39fc84be8cc3ff2c.
This commit is contained in:
morrySnow
2023-10-25 14:12:22 +08:00
committed by GitHub
parent 3b9ae91910
commit ae66464d6b
3 changed files with 171 additions and 85 deletions

View File

@ -17,19 +17,28 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.types.coercion.IntegralType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.Sets;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@ -40,19 +49,61 @@ import java.util.stream.Collectors;
*/
public class PredicatePropagation {
private enum InferType {
NONE(null),
INTEGRAL(IntegralType.class),
STRING(CharacterType.class),
DATE(DateLikeType.class),
OTHER(DataType.class)
;
private final Class<? extends DataType> superClazz;
InferType(Class<? extends DataType> superClazz) {
this.superClazz = superClazz;
}
}
private class ComparisonInferInfo {
public final InferType inferType;
public final Optional<Expression> left;
public final Optional<Expression> right;
public final ComparisonPredicate comparisonPredicate;
public ComparisonInferInfo(InferType inferType,
Optional<Expression> left, Optional<Expression> right,
ComparisonPredicate comparisonPredicate) {
this.inferType = inferType;
this.left = left;
this.right = right;
this.comparisonPredicate = comparisonPredicate;
}
}
/**
* infer additional predicates.
*/
public Set<Expression> infer(Set<Expression> predicates) {
Set<Expression> inferred = Sets.newHashSet();
for (Expression predicate : predicates) {
if (canEquivalentInfer(predicate)) {
List<Expression> newInferred = predicates.stream()
.filter(p -> !p.equals(predicate))
.map(p -> doInfer(predicate, p))
.collect(Collectors.toList());
inferred.addAll(newInferred);
if (!(predicate instanceof ComparisonPredicate)) {
continue;
}
ComparisonInferInfo equalInfo = getEquivalentInferInfo((ComparisonPredicate) predicate);
if (equalInfo.inferType == InferType.NONE) {
continue;
}
Set<Expression> newInferred = predicates.stream()
.filter(ComparisonPredicate.class::isInstance)
.filter(p -> !p.equals(predicate))
.map(ComparisonPredicate.class::cast)
.map(this::inferInferInfo)
.filter(predicateInfo -> predicateInfo.inferType != InferType.NONE)
.map(predicateInfo -> doInfer(equalInfo, predicateInfo))
.filter(Objects::nonNull)
.collect(Collectors.toSet());
inferred.addAll(newInferred);
}
inferred.removeAll(predicates);
return inferred;
@ -64,64 +115,128 @@ public class PredicatePropagation {
* TODO: We should determine whether `expression` satisfies the condition for replacement
* eg: Satisfy `expression` is non-deterministic
*/
private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression expression) {
return expression.accept(new DefaultExpressionRewriter<Void>() {
private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo predicateInfo) {
Expression predicateLeft = predicateInfo.left.get();
Expression predicateRight = predicateInfo.right.get();
Expression equalLeft = equalInfo.left.get();
Expression equalRight = equalInfo.right.get();
Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight);
Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight);
if (newLeft == null || newRight == null) {
return null;
}
ComparisonPredicate newPredicate = (ComparisonPredicate) predicateInfo
.comparisonPredicate.withChildren(newLeft, newRight);
return SimplifyComparisonPredicate.INSTANCE
.rewrite(TypeCoercionUtils.processComparisonPredicate(newPredicate), null);
}
@Override
public Expression visit(Expression expr, Void context) {
return expr;
private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
if (predicateOneSide instanceof SlotReference) {
if (predicateOneSide.equals(equalLeft)) {
return equalRight;
} else if (predicateOneSide.equals(equalRight)) {
return equalLeft;
}
@Override
public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) {
// we need to get expression covered by cast, because we want to infer different datatype
if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) {
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left()));
} else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) {
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right()));
}
return super.visit(cp, context);
} else if (predicateOneSide.isConstant()) {
if (predicateOneSide instanceof IntegerLikeLiteral) {
return new NereidsParser().parseExpression(((IntegerLikeLiteral) predicateOneSide).toSql());
} else {
return predicateOneSide;
}
}
return null;
}
private boolean isDataTypeValid(DataType originDataType, Expression expr) {
if ((leftSlotEqualToRightSlot.child(0).getDataType() instanceof IntegralType)
&& (leftSlotEqualToRightSlot.child(1).getDataType() instanceof IntegralType)
&& (originDataType instanceof IntegralType)) {
// infer filter can not be lower than original datatype, or dataset would be wrong
if (!((IntegralType) originDataType).widerThan(
(IntegralType) leftSlotEqualToRightSlot.child(0).getDataType())
&& !((IntegralType) originDataType).widerThan(
(IntegralType) leftSlotEqualToRightSlot.child(1).getDataType())) {
return true;
private Optional<Expression> validForInfer(Expression expression, InferType inferType) {
if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
return Optional.empty();
}
if (expression instanceof SlotReference || expression.isConstant()) {
return Optional.of(expression);
}
if (inferType == InferType.INTEGRAL) {
if (expression instanceof Cast) {
// avoid cast from wider type to narrower type, such as cast(int as smallint)
// IntegralType dataType = (IntegralType) expression.getDataType();
// DataType childType = ((Cast) expression).child().getDataType();
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
// return validForInfer(((Cast) expression).child(), inferType);
// }
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (inferType == InferType.DATE) {
if (expression instanceof Cast) {
DataType dataType = expression.getDataType();
DataType childType = ((Cast) expression).child().getDataType();
// avoid lost precision
if (dataType instanceof DateType) {
if (childType instanceof DateV2Type || childType instanceof DateType) {
return validForInfer(((Cast) expression).child(), inferType);
}
}
return false;
}
private Expression replaceSlot(Expression expr, DataType originDataType) {
return expr.rewriteUp(e -> {
if (isDataTypeValid(originDataType, leftSlotEqualToRightSlot)) {
if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(0))) {
return leftSlotEqualToRightSlot.child(1);
} else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(1))) {
return leftSlotEqualToRightSlot.child(0);
}
} else if (dataType instanceof DateV2Type) {
if (childType instanceof DateType || childType instanceof DateV2Type) {
return validForInfer(((Cast) expression).child(), inferType);
}
return e;
});
} else if (dataType instanceof DateTimeType) {
if (!(childType instanceof DateTimeV2Type)) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateTimeV2Type) {
return validForInfer(((Cast) expression).child(), inferType);
}
}
}, null);
} else if (inferType == InferType.STRING) {
if (expression instanceof Cast) {
DataType dataType = expression.getDataType();
DataType childType = ((Cast) expression).child().getDataType();
// avoid substring cast such as cast(char(3) as char(2))
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
return validForInfer(((Cast) expression).child(), inferType);
}
}
} else {
return Optional.empty();
}
return Optional.empty();
}
private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
DataType leftType = comparisonPredicate.left().getDataType();
InferType inferType;
if (leftType instanceof CharacterType) {
inferType = InferType.STRING;
} else if (leftType instanceof IntegralType) {
inferType = InferType.INTEGRAL;
} else if (leftType instanceof DateLikeType) {
inferType = InferType.DATE;
} else {
inferType = InferType.OTHER;
}
Optional<Expression> left = validForInfer(comparisonPredicate.left(), inferType);
Optional<Expression> right = validForInfer(comparisonPredicate.right(), inferType);
if (!left.isPresent() || !right.isPresent()) {
inferType = InferType.NONE;
}
return new ComparisonInferInfo(inferType, left, right, comparisonPredicate);
}
/**
* Currently only equivalence derivation is supported
* and requires that the left and right sides of an expression must be slot
*/
private boolean canEquivalentInfer(Expression predicate) {
return predicate instanceof EqualTo
&& predicate.children().stream().allMatch(e ->
(e instanceof SlotReference) || (e instanceof Cast && e.child(0) instanceof SlotReference))
&& predicate.child(0).getDataType().equals(predicate.child(1).getDataType());
private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate predicate) {
if (!(predicate instanceof EqualTo)) {
return new ComparisonInferInfo(InferType.NONE,
Optional.of(predicate.left()), Optional.of(predicate.right()), predicate);
}
ComparisonInferInfo info = inferInferInfo(predicate);
if (info.inferType == InferType.NONE) {
return info;
}
if (info.left.get() instanceof SlotReference && info.right.get() instanceof SlotReference) {
return info;
}
return new ComparisonInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate);
}
}

View File

@ -39,7 +39,6 @@ import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.types.DataType;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
@ -253,34 +252,6 @@ public class ExpressionUtils {
}
}
/**
* get slot covered by cast
* example: input: cast(cast(table.columnA)) output: columnA.datatype
*
*/
public static DataType getDatatypeCoveredByCast(Expression expr) {
if (expr instanceof Cast) {
return getDatatypeCoveredByCast(((Cast) expr).child());
}
return expr.getDataType();
}
/**
* judge if expression is slot covered by cast
* example: cast(cast(table.columnA))
*/
public static boolean isExpressionSlotCoveredByCast(Expression expr) {
if (expr instanceof Cast) {
return isExpressionSlotCoveredByCast(((Cast) expr).child());
}
return expr instanceof SlotReference;
}
public static boolean isTwoExpressionEqualWithCast(Expression left, Expression right) {
return ExpressionUtils.extractSlotOrCastOnSlot(left)
.equals(ExpressionUtils.extractSlotOrCastOnSlot(right));
}
/**
* Replace expression node in the expression tree by `replaceMap` in top-down manner.
* For example.