[feature](Nereids): Rewrite InPredicate to a disjunction if the InPredicate has fewer than 3 elements (#17646)
* [feature](Nereids): Rewrite InPredicate to disjunction if there exists < 3 elements in InPredicate * fix SimplifyRange
This commit is contained in:
@ -22,14 +22,26 @@ import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContex
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.InPredicate;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Rewrite InPredicate to an EqualTo Expression, if there exists exactly one element in InPredicate.options
|
||||
* Paper: Quantifying TPC-H Choke Points and Their Optimizations
|
||||
* - Figure 14:
|
||||
* <p>
|
||||
* Rewrite InPredicate to a disjunction of EqualTo expressions if it has fewer than 3 options.
|
||||
* Examples:
|
||||
* where A in (x, y) ==> where A = x or A = y
|
||||
* Examples:
|
||||
* where A in (x) ==> where A = x
|
||||
* where A not in (x) ==> where not A = x (After ExpressionTranslator, "not A = x" will be translated to "A != x")
|
||||
* <p>
|
||||
* NOTICE: this rule is related to `SimplifyRange`.
|
||||
* Both perform the same rewrite, so they must be changed together.
|
||||
*/
|
||||
public class InPredicateToEqualToRule extends AbstractExpressionRewriteRule {
|
||||
|
||||
@ -37,11 +49,16 @@ public class InPredicateToEqualToRule extends AbstractExpressionRewriteRule {
|
||||
|
||||
@Override
|
||||
public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) {
|
||||
Expression left = inPredicate.getCompareExpr();
|
||||
List<Expression> right = inPredicate.getOptions();
|
||||
if (right.size() != 1) {
|
||||
return new InPredicate(left.accept(this, context), right);
|
||||
Expression cmpExpr = inPredicate.getCompareExpr();
|
||||
List<Expression> options = inPredicate.getOptions();
|
||||
Preconditions.checkArgument(options.size() > 0, "InPredicate.options should not be empty");
|
||||
if (options.size() > 2) {
|
||||
return new InPredicate(cmpExpr.accept(this, context), options);
|
||||
}
|
||||
return new EqualTo(left.accept(this, context), right.get(0).accept(this, context));
|
||||
Expression newCmpExpr = cmpExpr.accept(this, context);
|
||||
List<Expression> disjunction = options.stream()
|
||||
.map(option -> new EqualTo(newCmpExpr, option.accept(this, context)))
|
||||
.collect(Collectors.toList());
|
||||
return ExpressionUtils.or(disjunction);
|
||||
}
|
||||
}
|
||||
|
||||
@ -44,6 +44,7 @@ import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Iterator;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@ -59,7 +60,7 @@ import java.util.stream.Collectors;
|
||||
* a > 1 or a > 2 => a > 1
|
||||
* a in (1,2,3) and a > 1 => a in (2,3)
|
||||
* a in (1,2,3) and a in (3,4,5) => a = 3
|
||||
* a in(1,2,3) and a in (4,5,6) => false
|
||||
* a in (1,2,3) and a in (4,5,6) => false
|
||||
* The logic is as follows:
|
||||
* 1. for `And` expression.
|
||||
* 1. extract conjunctions then build `ValueDesc` for each conjunction
|
||||
@ -402,8 +403,13 @@ public class SimplifyRange extends AbstractExpressionRewriteRule {
|
||||
|
||||
@Override
|
||||
public Expression toExpression() {
|
||||
// NOTICE: this logic is related to `InPredicateToEqualToRule`.
|
||||
// Both perform the same rewrite, so they must be changed together.
|
||||
if (values.size() == 1) {
|
||||
return new EqualTo(reference, values.iterator().next());
|
||||
} else if (values.size() == 2) {
|
||||
Iterator<Literal> iterator = values.iterator();
|
||||
return new Or(new EqualTo(reference, iterator.next()), new EqualTo(reference, iterator.next()));
|
||||
} else {
|
||||
return new InPredicate(reference, Lists.newArrayList(values));
|
||||
}
|
||||
|
||||
@ -55,7 +55,7 @@ public class SSBJoinReorderTest extends SSBTestBase implements MemoPatternMatchS
|
||||
"(lo_partkey = p_partkey)"
|
||||
),
|
||||
ImmutableList.of(
|
||||
"d_year IN (1997, 1998)",
|
||||
"((d_year = 1997) OR (d_year = 1998))",
|
||||
"(c_region = 'AMERICA')",
|
||||
"(s_region = 'AMERICA')",
|
||||
"((p_mfgr = 'MFGR#1') OR (p_mfgr = 'MFGR#2'))"
|
||||
@ -74,7 +74,7 @@ public class SSBJoinReorderTest extends SSBTestBase implements MemoPatternMatchS
|
||||
"(lo_partkey = p_partkey)"
|
||||
),
|
||||
ImmutableList.of(
|
||||
"d_year IN (1997, 1998)",
|
||||
"((d_year = 1997) OR (d_year = 1998))",
|
||||
"(s_nation = 'UNITED STATES')",
|
||||
"(p_category = 'MFGR#14')"
|
||||
)
|
||||
|
||||
@ -175,20 +175,20 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
executor = new ExpressionRuleExecutor(ImmutableList.of(InPredicateToEqualToRule.INSTANCE));
|
||||
|
||||
assertRewrite("a in (1)", "a = 1");
|
||||
assertRewrite("a in (1, 2)", "a in (1, 2)");
|
||||
assertRewrite("a in (1, 2)", "((a = 1) OR (a = 2))");
|
||||
assertRewrite("a not in (1)", "not a = 1");
|
||||
assertRewrite("a not in (1, 2)", "not a in (1, 2)");
|
||||
assertRewrite("a not in (1, 2)", "not ((a = 1) OR (a = 2))");
|
||||
assertRewrite("a in (a in (1))", "a = (a = 1)");
|
||||
assertRewrite("a in (a in (1, 2))", "a = (a in (1, 2))");
|
||||
assertRewrite("a in (a in (1, 2))", "a = ((a = 1) OR (a = 2))");
|
||||
assertRewrite("(a in (1)) in (1)", "(a = 1) = 1");
|
||||
assertRewrite("(a in (1, 2)) in (1)", "(a in (1, 2)) = 1");
|
||||
assertRewrite("(a in (1)) in (1, 2)", "(a = 1) in (1, 2)");
|
||||
assertRewrite("(a in (1, 2)) in (1)", "((a = 1) OR (a = 2)) = 1");
|
||||
assertRewrite("(a in (1)) in (1, 2)", "((a = 1) = 1) OR ((a = 1) = 2)");
|
||||
assertRewrite("case a when b in (1) then a else c end in (1)",
|
||||
"case a when b = 1 then a else c end = 1");
|
||||
assertRewrite("case a when b not in (1) then a else c end not in (1)",
|
||||
"not case a when not b = 1 then a else c end = 1");
|
||||
assertRewrite("case a when b not in (1) then a else c end in (1, 2)",
|
||||
"case a when not b = 1 then a else c end in (1, 2)");
|
||||
"(CASE WHEN (a = ( not (b = 1))) THEN a ELSE c END = 1) OR (CASE WHEN (a = ( not (b = 1))) THEN a ELSE c END = 2)");
|
||||
|
||||
}
|
||||
|
||||
|
||||
@ -87,7 +87,7 @@ public class SimplifyRangeTest {
|
||||
assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))");
|
||||
assertRewrite("TA in (1,2,3) and TA > 10", "FALSE");
|
||||
assertRewrite("TA in (1,2,3) and TA >= 1", "TA in (1,2,3)");
|
||||
assertRewrite("TA in (1,2,3) and TA > 1", "TA in (2,3)");
|
||||
assertRewrite("TA in (1,2,3) and TA > 1", "((TA = 2) OR (TA = 3))");
|
||||
assertRewrite("TA in (1,2,3) or TA >= 1", "TA >= 1");
|
||||
assertRewrite("TA in (1)", "TA in (1)");
|
||||
assertRewrite("TA in (1,2,3) and TA < 10", "TA in (1,2,3)");
|
||||
|
||||
Reference in New Issue
Block a user