[Fix](Nereids) fix infer predicate lost cast of source expression (#23692)
Problem: When inferring predicate,we lost cast of source expressions and some datatype derivation. Example: a = b and cast(a as targetType) = constant (cast(a as targetType) = constant ) this expression is define as source expression. we expect getting cast(b as targetType) = constant instead of b = constant Reason: When inferring predicate, we will compare original type of a and b. if they can be cast without precision lost, a new predicate would be created. But created predicate forgot to cast to target type Solved: Add cast to target type, and open make other datatype valid also.
This commit is contained in:
@ -59,12 +59,12 @@ public class PredicatePropagation {
|
||||
}
|
||||
|
||||
/**
|
||||
* Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression`
|
||||
* Use the left or right child of `equalExpr` to replace the left or right child of `expression`
|
||||
* Now only support infer `ComparisonPredicate`.
|
||||
* TODO: We should determine whether `expression` satisfies the condition for replacement
|
||||
* eg: Satisfy `expression` is non-deterministic
|
||||
*/
|
||||
private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression expression) {
|
||||
private Expression doInfer(Expression equalExpr, Expression expression) {
|
||||
return expression.accept(new DefaultExpressionRewriter<Void>() {
|
||||
|
||||
@Override
|
||||
@ -76,36 +76,43 @@ public class PredicatePropagation {
|
||||
public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) {
|
||||
// we need to get expression covered by cast, because we want to infer different datatype
|
||||
if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) {
|
||||
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left()));
|
||||
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left()), equalExpr);
|
||||
} else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) {
|
||||
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right()));
|
||||
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right()), equalExpr);
|
||||
}
|
||||
return super.visit(cp, context);
|
||||
}
|
||||
|
||||
private boolean isDataTypeValid(DataType originDataType, Expression expr) {
|
||||
if ((leftSlotEqualToRightSlot.child(0).getDataType() instanceof IntegralType)
|
||||
&& (leftSlotEqualToRightSlot.child(1).getDataType() instanceof IntegralType)
|
||||
if ((expr.child(0).getDataType() instanceof IntegralType)
|
||||
&& (expr.child(1).getDataType() instanceof IntegralType)
|
||||
&& (originDataType instanceof IntegralType)) {
|
||||
// infer filter can not be lower than original datatype, or dataset would be wrong
|
||||
if (!((IntegralType) originDataType).widerThan(
|
||||
(IntegralType) leftSlotEqualToRightSlot.child(0).getDataType())
|
||||
(IntegralType) expr.child(0).getDataType())
|
||||
&& !((IntegralType) originDataType).widerThan(
|
||||
(IntegralType) leftSlotEqualToRightSlot.child(1).getDataType())) {
|
||||
(IntegralType) expr.child(1).getDataType())) {
|
||||
return true;
|
||||
}
|
||||
} else if (expr.child(0).getDataType().equals(expr.child(1).getDataType())) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private Expression replaceSlot(Expression expr, DataType originDataType) {
|
||||
return expr.rewriteUp(e -> {
|
||||
if (isDataTypeValid(originDataType, leftSlotEqualToRightSlot)) {
|
||||
if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(0))) {
|
||||
return leftSlotEqualToRightSlot.child(1);
|
||||
} else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(1))) {
|
||||
return leftSlotEqualToRightSlot.child(0);
|
||||
}
|
||||
private Expression replaceSlot(Expression sourcePredicate, DataType originDataType, Expression equal) {
|
||||
if (!isDataTypeValid(originDataType, equal)) {
|
||||
return sourcePredicate;
|
||||
}
|
||||
return sourcePredicate.rewriteUp(e -> {
|
||||
// we can not replace Cast expression to slot because when rewrite up, we have replace child of cast
|
||||
if (e instanceof Cast) {
|
||||
return e;
|
||||
}
|
||||
if (ExpressionUtils.isTwoExpressionEqualWithCast(e, equal.child(0))) {
|
||||
return equal.child(1);
|
||||
} else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, equal.child(1))) {
|
||||
return equal.child(0);
|
||||
}
|
||||
return e;
|
||||
});
|
||||
|
||||
@ -17,15 +17,33 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import org.apache.doris.utframe.TestWithFeService;
|
||||
|
||||
import com.google.common.collect.Sets;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
public class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported {
|
||||
|
||||
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
|
||||
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
|
||||
private final PredicatePropagation propagation = new PredicatePropagation();
|
||||
|
||||
@Override
|
||||
protected void runBeforeAll() throws Exception {
|
||||
createDatabase("test");
|
||||
@ -628,4 +646,16 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
|
||||
).when(join -> join.getJoinType() == JoinType.LEFT_OUTER_JOIN)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testInfer() {
|
||||
EqualTo equalTo = new EqualTo(new Cast(scan1.getOutput().get(0), BigIntType.INSTANCE), Literal.of(1));
|
||||
EqualTo equalTo2 = new EqualTo(scan2.getOutput().get(0), scan1.getOutput().get(0));
|
||||
Set<Expression> predicates = Sets.newHashSet();
|
||||
predicates.add(equalTo2);
|
||||
predicates.add(equalTo);
|
||||
Set<Expression> newPredicates = propagation.infer(predicates);
|
||||
Optional<Expression> newPredicate = newPredicates.stream().findFirst();
|
||||
Assertions.assertTrue(newPredicate.get().equals(new EqualTo(new Cast(scan2.getOutput().get(0), BigIntType.INSTANCE), Literal.of(1))));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user