[feature](Nereids): Rewrite InPredicate to a disjunction if the InPredicate has fewer than 3 elements (#17646)

* [feature](Nereids): Rewrite InPredicate to a disjunction if the InPredicate has fewer than 3 elements

* fix SimplifyRange
This commit is contained in:
jakevin
2023-03-15 08:23:56 +08:00
committed by GitHub
parent 02220560c5
commit 7872f3626a
5 changed files with 39 additions and 16 deletions

View File

@ -22,14 +22,26 @@ import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContex
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
import java.util.List;
import java.util.stream.Collectors;
/**
* Rewrite InPredicate to an EqualTo Expression, if there exists exactly one element in InPredicate.options
* Paper: Quantifying TPC-H Choke Points and Their Optimizations
* - Figure 14:
* <p>
* Rewrite InPredicate to a disjunction if there are fewer than 3 elements in InPredicate
* Examples:
* where A in (x, y) ==> where A = x or A = y
* Examples:
* where A in (x) ==> where A = x
* where A not in (x) ==> where not A = x (After ExpressionTranslator, "not A = x" will be translated to "A != x")
* <p>
* NOTICE: this rule is related to `SimplifyRange`.
* Both apply the same rewrite, so they must be changed in sync.
*/
public class InPredicateToEqualToRule extends AbstractExpressionRewriteRule {
@ -37,11 +49,16 @@ public class InPredicateToEqualToRule extends AbstractExpressionRewriteRule {
/**
* Rewrites an InPredicate that has one or two options into EqualTo comparisons
* joined by ExpressionUtils.or; an InPredicate with more than two options stays
* an InPredicate.
* <p>
* NOTE(review): this listing appears to interleave the pre-change lines
* (`left` / `right`) with the post-change lines (`cmpExpr` / `options`);
* the comments below describe the `cmpExpr` / `options` path.
*
* @param inPredicate the IN predicate being visited
* @param context the rewrite context, forwarded to every nested accept call
* @return the rewritten expression
*/
@Override
public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) {
Expression left = inPredicate.getCompareExpr();
List<Expression> right = inPredicate.getOptions();
if (right.size() != 1) {
return new InPredicate(left.accept(this, context), right);
Expression cmpExpr = inPredicate.getCompareExpr();
List<Expression> options = inPredicate.getOptions();
// An IN list with no options is treated as a broken invariant: fail fast instead of rewriting.
Preconditions.checkArgument(options.size() > 0, "InPredicate.options should not be empty");
// More than two options: keep the IN form. Only the compare expression is rewritten here;
// the options list is passed through as-is, without visiting its elements.
if (options.size() > 2) {
return new InPredicate(cmpExpr.accept(this, context), options);
}
return new EqualTo(left.accept(this, context), right.get(0).accept(this, context));
// Rewrite the compare expression once and reuse the result in every generated EqualTo.
Expression newCmpExpr = cmpExpr.accept(this, context);
// One EqualTo per option; each option is itself rewritten so nested IN predicates are handled.
List<Expression> disjunction = options.stream()
.map(option -> new EqualTo(newCmpExpr, option.accept(this, context)))
.collect(Collectors.toList());
// Presumably ExpressionUtils.or returns the lone EqualTo when the list has a single element
// (the rule's tests expect `a in (1)` => `a = 1`) — confirm against ExpressionUtils.
return ExpressionUtils.or(disjunction);
}
}

View File

@ -44,6 +44,7 @@ import com.google.common.collect.Sets;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@ -59,7 +60,7 @@ import java.util.stream.Collectors;
* a > 1 or a > 2 => a > 1
* a in (1,2,3) and a > 1 => a in (2,3)
* a in (1,2,3) and a in (3,4,5) => a = 3
* a in(1,2,3) and a in (4,5,6) => false
* a in (1,2,3) and a in (4,5,6) => false
* The logic is as follows:
* 1. for `And` expression.
* 1. extract conjunctions then build `ValueDesc` for each conjunction
@ -402,8 +403,13 @@ public class SimplifyRange extends AbstractExpressionRewriteRule {
@Override
public Expression toExpression() {
// NOTICE: this is related to `InPredicateToEqualToRule`.
// Both apply the same rewrite, so they must be changed in sync.
if (values.size() == 1) {
return new EqualTo(reference, values.iterator().next());
} else if (values.size() == 2) {
Iterator<Literal> iterator = values.iterator();
return new Or(new EqualTo(reference, iterator.next()), new EqualTo(reference, iterator.next()));
} else {
return new InPredicate(reference, Lists.newArrayList(values));
}

View File

@ -55,7 +55,7 @@ public class SSBJoinReorderTest extends SSBTestBase implements MemoPatternMatchS
"(lo_partkey = p_partkey)"
),
ImmutableList.of(
"d_year IN (1997, 1998)",
"((d_year = 1997) OR (d_year = 1998))",
"(c_region = 'AMERICA')",
"(s_region = 'AMERICA')",
"((p_mfgr = 'MFGR#1') OR (p_mfgr = 'MFGR#2'))"
@ -74,7 +74,7 @@ public class SSBJoinReorderTest extends SSBTestBase implements MemoPatternMatchS
"(lo_partkey = p_partkey)"
),
ImmutableList.of(
"d_year IN (1997, 1998)",
"((d_year = 1997) OR (d_year = 1998))",
"(s_nation = 'UNITED STATES')",
"(p_category = 'MFGR#14')"
)

View File

@ -175,20 +175,20 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
executor = new ExpressionRuleExecutor(ImmutableList.of(InPredicateToEqualToRule.INSTANCE));
assertRewrite("a in (1)", "a = 1");
assertRewrite("a in (1, 2)", "a in (1, 2)");
assertRewrite("a in (1, 2)", "((a = 1) OR (a = 2))");
assertRewrite("a not in (1)", "not a = 1");
assertRewrite("a not in (1, 2)", "not a in (1, 2)");
assertRewrite("a not in (1, 2)", "not ((a = 1) OR (a = 2))");
assertRewrite("a in (a in (1))", "a = (a = 1)");
assertRewrite("a in (a in (1, 2))", "a = (a in (1, 2))");
assertRewrite("a in (a in (1, 2))", "a = ((a = 1) OR (a = 2))");
assertRewrite("(a in (1)) in (1)", "(a = 1) = 1");
assertRewrite("(a in (1, 2)) in (1)", "(a in (1, 2)) = 1");
assertRewrite("(a in (1)) in (1, 2)", "(a = 1) in (1, 2)");
assertRewrite("(a in (1, 2)) in (1)", "((a = 1) OR (a = 2)) = 1");
assertRewrite("(a in (1)) in (1, 2)", "((a = 1) = 1) OR ((a = 1) = 2)");
assertRewrite("case a when b in (1) then a else c end in (1)",
"case a when b = 1 then a else c end = 1");
assertRewrite("case a when b not in (1) then a else c end not in (1)",
"not case a when not b = 1 then a else c end = 1");
assertRewrite("case a when b not in (1) then a else c end in (1, 2)",
"case a when not b = 1 then a else c end in (1, 2)");
"(CASE WHEN (a = ( not (b = 1))) THEN a ELSE c END = 1) OR (CASE WHEN (a = ( not (b = 1))) THEN a ELSE c END = 2)");
}

View File

@ -87,7 +87,7 @@ public class SimplifyRangeTest {
assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))");
assertRewrite("TA in (1,2,3) and TA > 10", "FALSE");
assertRewrite("TA in (1,2,3) and TA >= 1", "TA in (1,2,3)");
assertRewrite("TA in (1,2,3) and TA > 1", "TA in (2,3)");
assertRewrite("TA in (1,2,3) and TA > 1", "((TA = 2) OR (TA = 3))");
assertRewrite("TA in (1,2,3) or TA >= 1", "TA >= 1");
assertRewrite("TA in (1)", "TA in (1)");
assertRewrite("TA in (1,2,3) and TA < 10", "TA in (1,2,3)");