[feature](Nereids) covert predicate to SARGABLE (#25180)
covert predicate to SARGABLE 1. support format like `1 - a` 2. support rearrange `year/month/week/day/minutes/seconds_sub/add` function
This commit is contained in:
@ -21,16 +21,34 @@ import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Divide;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThan;
|
||||
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
|
||||
import org.apache.doris.nereids.trees.expressions.LessThan;
|
||||
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
|
||||
import org.apache.doris.nereids.trees.expressions.Multiply;
|
||||
import org.apache.doris.nereids.trees.expressions.Subtract;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.DaysAdd;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.DaysSub;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursAdd;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursSub;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.MinutesAdd;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.MinutesSub;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.MonthsAdd;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.MonthsSub;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.SecondsAdd;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.SecondsSub;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.WeeksAdd;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.WeeksSub;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.YearsAdd;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.YearsSub;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.util.TypeCoercionUtils;
|
||||
import org.apache.doris.nereids.util.TypeUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
* Simplify arithmetic comparison rule.
|
||||
@ -40,68 +58,90 @@ import org.apache.doris.nereids.util.TypeUtils;
|
||||
public class SimplifyArithmeticComparisonRule extends AbstractExpressionRewriteRule {
|
||||
public static final SimplifyArithmeticComparisonRule INSTANCE = new SimplifyArithmeticComparisonRule();
|
||||
|
||||
@Override
|
||||
public Expression visit(Expression expr, ExpressionRewriteContext context) {
|
||||
return expr;
|
||||
}
|
||||
// don't rearrange multiplication because divide may loss precision
|
||||
final Map<Class<? extends Expression>, Class<? extends Expression>> rearrangementMap = ImmutableMap
|
||||
.<Class<? extends Expression>, Class<? extends Expression>>builder()
|
||||
.put(Add.class, Subtract.class)
|
||||
.put(Subtract.class, Add.class)
|
||||
.put(Divide.class, Multiply.class)
|
||||
.put(YearsSub.class, YearsAdd.class)
|
||||
.put(YearsAdd.class, YearsSub.class)
|
||||
.put(MonthsSub.class, MonthsAdd.class)
|
||||
.put(MonthsAdd.class, MonthsSub.class)
|
||||
.put(WeeksSub.class, WeeksAdd.class)
|
||||
.put(WeeksAdd.class, WeeksSub.class)
|
||||
.put(DaysSub.class, DaysAdd.class)
|
||||
.put(DaysAdd.class, DaysSub.class)
|
||||
.put(HoursSub.class, HoursAdd.class)
|
||||
.put(HoursAdd.class, HoursSub.class)
|
||||
.put(MinutesSub.class, MinutesAdd.class)
|
||||
.put(MinutesAdd.class, MinutesSub.class)
|
||||
.put(SecondsSub.class, SecondsAdd.class)
|
||||
.put(SecondsAdd.class, SecondsSub.class)
|
||||
.build();
|
||||
|
||||
private Expression process(ComparisonPredicate predicate) {
|
||||
Expression left = predicate.left();
|
||||
Expression right = predicate.right();
|
||||
if (TypeUtils.isAddOrSubtract(left)) {
|
||||
Expression p = left.child(1);
|
||||
if (p.isConstant()) {
|
||||
if (TypeUtils.isAdd(left)) {
|
||||
right = new Subtract(right, p);
|
||||
}
|
||||
if (TypeUtils.isSubtract(left)) {
|
||||
right = new Add(right, p);
|
||||
}
|
||||
left = left.child(0);
|
||||
@Override
|
||||
public Expression visitComparisonPredicate(ComparisonPredicate comparison, ExpressionRewriteContext context) {
|
||||
ComparisonPredicate newComparison = comparison;
|
||||
if (couldRearrange(comparison)) {
|
||||
newComparison = normalize(comparison);
|
||||
if (newComparison == null) {
|
||||
return comparison;
|
||||
}
|
||||
try {
|
||||
List<Expression> children = tryRearrangeChildren(newComparison.left(), newComparison.right());
|
||||
newComparison = (ComparisonPredicate) newComparison.withChildren(children);
|
||||
} catch (Exception e) {
|
||||
return comparison;
|
||||
}
|
||||
}
|
||||
if (TypeUtils.isDivide(left)) {
|
||||
Expression p = left.child(1);
|
||||
if (p.isLiteral()) {
|
||||
right = new Multiply(right, p);
|
||||
left = left.child(0);
|
||||
if (p.toString().startsWith("-")) {
|
||||
Expression tmp = right;
|
||||
right = left;
|
||||
left = tmp;
|
||||
}
|
||||
return TypeCoercionUtils.processComparisonPredicate(newComparison);
|
||||
}
|
||||
|
||||
private boolean couldRearrange(ComparisonPredicate cmp) {
|
||||
return rearrangementMap.containsKey(cmp.left().getClass())
|
||||
&& !cmp.left().isConstant()
|
||||
&& cmp.left().children().stream().anyMatch(Expression::isConstant);
|
||||
}
|
||||
|
||||
private List<Expression> tryRearrangeChildren(Expression left, Expression right) throws Exception {
|
||||
if (!left.child(1).isLiteral()) {
|
||||
throw new RuntimeException(String.format("Expected literal when arranging children for Expr %s", left));
|
||||
}
|
||||
Literal leftLiteral = (Literal) left.child(1);
|
||||
Expression leftExpr = left.child(0);
|
||||
|
||||
Class<? extends Expression> oppositeOperator = rearrangementMap.get(left.getClass());
|
||||
Expression newChild = oppositeOperator.getConstructor(Expression.class, Expression.class)
|
||||
.newInstance(right, leftLiteral);
|
||||
|
||||
if (left instanceof Divide && leftLiteral.compareTo(new IntegerLiteral(0)) < 0) {
|
||||
// Multiplying by a negative number will change the operator.
|
||||
return Arrays.asList(newChild, leftExpr);
|
||||
}
|
||||
return Arrays.asList(leftExpr, newChild);
|
||||
}
|
||||
|
||||
// Ensure that the second child must be Literal, such as
|
||||
private @Nullable ComparisonPredicate normalize(ComparisonPredicate comparison) {
|
||||
if (!(comparison.left().child(1) instanceof Literal)) {
|
||||
Expression left = comparison.left();
|
||||
if (comparison.left() instanceof Add) {
|
||||
// 1 + a > 1 => a + 1 > 1
|
||||
Expression newLeft = left.withChildren(left.child(1), left.child(0));
|
||||
comparison = (ComparisonPredicate) comparison.withChildren(newLeft, comparison.right());
|
||||
} else if (comparison.left() instanceof Subtract) {
|
||||
// 1 - a > 1 => a + 1 < 1
|
||||
Expression newLeft = left.child(0);
|
||||
Expression newRight = new Add(left.child(1), comparison.right());
|
||||
comparison = (ComparisonPredicate) comparison.withChildren(newLeft, newRight);
|
||||
comparison = comparison.commute();
|
||||
} else {
|
||||
// Don't normalize division/multiplication because the slot sign is undecided.
|
||||
return null;
|
||||
}
|
||||
}
|
||||
if (left != predicate.left() || right != predicate.right()) {
|
||||
predicate = (ComparisonPredicate) predicate.withChildren(left, right);
|
||||
return TypeCoercionUtils.processComparisonPredicate(predicate);
|
||||
} else {
|
||||
return predicate;
|
||||
}
|
||||
return comparison;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteContext context) {
|
||||
return process(greaterThan);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, ExpressionRewriteContext context) {
|
||||
return process(greaterThanEqual);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context) {
|
||||
return process(equalTo);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitLessThan(LessThan lessThan, ExpressionRewriteContext context) {
|
||||
return process(lessThan);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitLessThanEqual(LessThanEqual lessThanEqual, ExpressionRewriteContext context) {
|
||||
return process(lessThanEqual);
|
||||
}
|
||||
}
|
||||
|
||||
@ -25,9 +25,9 @@ import org.apache.doris.nereids.rules.expression.rules.SimplifyArithmeticRule;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class SimplifyArithmeticRuleTest extends ExpressionRewriteTestHelper {
|
||||
class SimplifyArithmeticRuleTest extends ExpressionRewriteTestHelper {
|
||||
@Test
|
||||
public void testSimplifyArithmetic() {
|
||||
void testSimplifyArithmetic() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
SimplifyArithmeticRule.INSTANCE,
|
||||
FunctionBinder.INSTANCE,
|
||||
@ -53,7 +53,7 @@ public class SimplifyArithmeticRuleTest extends ExpressionRewriteTestHelper {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSimplifyArithmeticComparison() {
|
||||
void testSimplifyArithmeticComparison() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
SimplifyArithmeticRule.INSTANCE,
|
||||
FoldConstantRule.INSTANCE,
|
||||
@ -88,7 +88,35 @@ public class SimplifyArithmeticRuleTest extends ExpressionRewriteTestHelper {
|
||||
assertRewriteAfterTypeCoercion("IA * ID > IB * IC", "IA * ID > IB * IC");
|
||||
assertRewriteAfterTypeCoercion("IA * ID / 2 > IB * IC", "cast((IA * ID) as DOUBLE) > cast((IB * IC) as DOUBLE) * 2");
|
||||
assertRewriteAfterTypeCoercion("IA * ID / -2 > IB * IC", "cast((IB * IC) as DOUBLE) * -2 > cast((IA * ID) as DOUBLE)");
|
||||
assertRewriteAfterTypeCoercion("1 - IA > 1", "(cast(IA as BIGINT) < 0)");
|
||||
assertRewriteAfterTypeCoercion("1 - IA + 1 * 3 - 5 > 1", "(cast(IA as BIGINT) < -2)");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSimplifyDateTimeComparison() {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(
|
||||
SimplifyArithmeticRule.INSTANCE,
|
||||
FoldConstantRule.INSTANCE,
|
||||
SimplifyArithmeticComparisonRule.INSTANCE,
|
||||
SimplifyArithmeticRule.INSTANCE,
|
||||
FunctionBinder.INSTANCE,
|
||||
FoldConstantRule.INSTANCE
|
||||
));
|
||||
assertRewriteAfterTypeCoercion("years_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-01-01 00:00:00')");
|
||||
assertRewriteAfterTypeCoercion("years_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2022-01-01 00:00:00')");
|
||||
assertRewriteAfterTypeCoercion("months_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-12-01 00:00:00')");
|
||||
assertRewriteAfterTypeCoercion("months_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2021-02-01 00:00:00')");
|
||||
assertRewriteAfterTypeCoercion("weeks_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-12-25 00:00:00')");
|
||||
assertRewriteAfterTypeCoercion("weeks_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2021-01-08 00:00:00')");
|
||||
assertRewriteAfterTypeCoercion("days_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-12-31 00:00:00')");
|
||||
assertRewriteAfterTypeCoercion("days_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2021-01-02 00:00:00')");
|
||||
assertRewriteAfterTypeCoercion("hours_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-12-31 23:00:00')");
|
||||
assertRewriteAfterTypeCoercion("hours_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2021-01-01 01:00:00')");
|
||||
assertRewriteAfterTypeCoercion("minutes_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-12-31 23:59:00')");
|
||||
assertRewriteAfterTypeCoercion("minutes_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2021-01-01 00:01:00')");
|
||||
assertRewriteAfterTypeCoercion("seconds_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-12-31 23:59:59')");
|
||||
assertRewriteAfterTypeCoercion("seconds_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2021-01-01 00:00:01')");
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user