[Fix](Nereids) fix infer predicate lost cast of source expression (#23692)

Problem:
When inferring predicate,we lost cast of source expressions and some datatype derivation.

Example:
a = b and cast(a as targetType) = constant
(cast(a as targetType) = constant ) this expression is define as source expression.
we expect getting cast(b as targetType) = constant instead of b = constant

Reason:
When inferring predicate, we will compare original type of a and b. if they can be cast
without precision lost, a new predicate would be created. But created predicate forgot
to cast to target type

Solved:
Add cast to target type, and open make other datatype valid also.
This commit is contained in:
LiBinfeng
2023-09-11 14:30:31 +08:00
committed by GitHub
parent e847091dfe
commit be3618316f
3 changed files with 71 additions and 16 deletions

View File

@ -59,12 +59,12 @@ public class PredicatePropagation {
}
/**
* Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression`
* Use the left or right child of `equalExpr` to replace the left or right child of `expression`
* Now only support infer `ComparisonPredicate`.
* TODO: We should determine whether `expression` satisfies the condition for replacement
* eg: Satisfy `expression` is non-deterministic
*/
private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression expression) {
private Expression doInfer(Expression equalExpr, Expression expression) {
return expression.accept(new DefaultExpressionRewriter<Void>() {
@Override
@ -76,36 +76,43 @@ public class PredicatePropagation {
public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) {
// we need to get expression covered by cast, because we want to infer different datatype
if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) {
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left()));
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left()), equalExpr);
} else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) {
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right()));
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right()), equalExpr);
}
return super.visit(cp, context);
}
private boolean isDataTypeValid(DataType originDataType, Expression expr) {
if ((leftSlotEqualToRightSlot.child(0).getDataType() instanceof IntegralType)
&& (leftSlotEqualToRightSlot.child(1).getDataType() instanceof IntegralType)
if ((expr.child(0).getDataType() instanceof IntegralType)
&& (expr.child(1).getDataType() instanceof IntegralType)
&& (originDataType instanceof IntegralType)) {
// infer filter can not be lower than original datatype, or dataset would be wrong
if (!((IntegralType) originDataType).widerThan(
(IntegralType) leftSlotEqualToRightSlot.child(0).getDataType())
(IntegralType) expr.child(0).getDataType())
&& !((IntegralType) originDataType).widerThan(
(IntegralType) leftSlotEqualToRightSlot.child(1).getDataType())) {
(IntegralType) expr.child(1).getDataType())) {
return true;
}
} else if (expr.child(0).getDataType().equals(expr.child(1).getDataType())) {
return true;
}
return false;
}
private Expression replaceSlot(Expression expr, DataType originDataType) {
return expr.rewriteUp(e -> {
if (isDataTypeValid(originDataType, leftSlotEqualToRightSlot)) {
if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(0))) {
return leftSlotEqualToRightSlot.child(1);
} else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(1))) {
return leftSlotEqualToRightSlot.child(0);
}
private Expression replaceSlot(Expression sourcePredicate, DataType originDataType, Expression equal) {
if (!isDataTypeValid(originDataType, equal)) {
return sourcePredicate;
}
return sourcePredicate.rewriteUp(e -> {
// we can not replace Cast expression to slot because when rewrite up, we have replace child of cast
if (e instanceof Cast) {
return e;
}
if (ExpressionUtils.isTwoExpressionEqualWithCast(e, equal.child(0))) {
return equal.child(1);
} else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, equal.child(1))) {
return equal.child(0);
}
return e;
});

View File

@ -17,15 +17,33 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.Sets;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.Optional;
import java.util.Set;
public class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported {
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
private final PredicatePropagation propagation = new PredicatePropagation();
@Override
protected void runBeforeAll() throws Exception {
createDatabase("test");
@ -628,4 +646,16 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
).when(join -> join.getJoinType() == JoinType.LEFT_OUTER_JOIN)
);
}
@Test
void testInfer() {
EqualTo equalTo = new EqualTo(new Cast(scan1.getOutput().get(0), BigIntType.INSTANCE), Literal.of(1));
EqualTo equalTo2 = new EqualTo(scan2.getOutput().get(0), scan1.getOutput().get(0));
Set<Expression> predicates = Sets.newHashSet();
predicates.add(equalTo2);
predicates.add(equalTo);
Set<Expression> newPredicates = propagation.infer(predicates);
Optional<Expression> newPredicate = newPredicates.stream().findFirst();
Assertions.assertTrue(newPredicate.get().equals(new EqualTo(new Cast(scan2.getOutput().get(0), BigIntType.INSTANCE), Literal.of(1))));
}
}