From 0550dfaeb2a419c8f5730f3c2e7c84a1fe67f2c4 Mon Sep 17 00:00:00 2001 From: Henry2SS <45096548+Henry2SS@users.noreply.github.com> Date: Tue, 27 Dec 2022 18:39:53 +0800 Subject: [PATCH] [enhancement](rewrite) add OrToIn rule and fix ExtractCommonFactorsRule apply problems (#12872) Co-authored-by: wuhangze --- .../rewrite/ExtractCommonFactorsRule.java | 101 +++++++++++++++++- .../analysis/ListPartitionPrunerTest.java | 6 +- .../analysis/RangePartitionPruneTest.java | 14 +-- .../apache/doris/planner/QueryPlanTest.java | 21 ++++ .../ExtractCommonFactorsRuleFunctionTest.java | 1 - .../performance_p0/redundant_conjuncts.out | 2 +- 6 files changed, 129 insertions(+), 16 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java b/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java index 6ff72d858b..5a3bc34c8c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java @@ -25,6 +25,7 @@ import org.apache.doris.analysis.Expr; import org.apache.doris.analysis.InPredicate; import org.apache.doris.analysis.LiteralExpr; import org.apache.doris.analysis.SlotRef; +import org.apache.doris.analysis.TableName; import org.apache.doris.common.AnalysisException; import org.apache.doris.planner.PlanNode; import org.apache.doris.rewrite.ExprRewriter.ClauseType; @@ -68,6 +69,7 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule { @Override public Expr apply(Expr expr, Analyzer analyzer, ExprRewriter.ClauseType clauseType) throws AnalysisException { + Expr resultExpr = null; if (expr == null) { return null; } else if (expr instanceof CompoundPredicate @@ -77,12 +79,19 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule { return rewrittenExpr; } } else { - for (int i = 0; i < expr.getChildren().size(); i++) { + if (!(expr instanceof CompoundPredicate)) { + return expr; + } + + resultExpr = expr.clone(); + + for (int i = 0; i < resultExpr.getChildren().size(); i++) { Expr rewrittenExpr = apply(expr.getChild(i), analyzer, clauseType); if (rewrittenExpr != null) { - expr.setChild(i, rewrittenExpr); + resultExpr.setChild(i, rewrittenExpr); } } + return resultExpr; } return expr; } @@ -179,10 +188,10 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule { if (CollectionUtils.isNotEmpty(commonFactorList)) { result = new CompoundPredicate(CompoundPredicate.Operator.AND, makeCompound(commonFactorList, CompoundPredicate.Operator.AND), - makeCompound(remainingOrClause, CompoundPredicate.Operator.OR)); + makeCompoundRemaining(remainingOrClause, CompoundPredicate.Operator.OR)); result.setPrintSqlInParens(true); } else { - result = makeCompound(remainingOrClause, CompoundPredicate.Operator.OR); + result = makeCompoundRemaining(remainingOrClause, CompoundPredicate.Operator.OR); } if (LOG.isDebugEnabled()) { LOG.debug("equal ors: " + result.toSql()); @@ -399,6 +408,11 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule { /** * Rebuild CompoundPredicate, [a, e, f] AND => a and e and f + * Rewrite OR :[a, b, c] + * while (a.columnName == b.columnName == c.columnName) && (a,b,c) + * instance of (BinaryPredicate, InPredicate) + * && (a,b,c).op = BinaryPredicate.Operator.EQ =======>>>>>> + * =======>>>>>> columnName IN (a.value,b.value,c.value) */ private Expr makeCompound(List exprs, CompoundPredicate.Operator op) { if (CollectionUtils.isEmpty(exprs)) { @@ -415,6 +429,85 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule { return result; } + private Expr makeCompoundRemaining(List exprs, CompoundPredicate.Operator op) { + if (CollectionUtils.isEmpty(exprs)) { + return null; + } + if (exprs.size() == 1) { + return exprs.get(0); + } + + Expr rewritePredicate = null; + // only OR will be rewrite to IN + if (op == CompoundPredicate.Operator.OR) { + rewritePredicate = rewriteOrToIn(exprs); + // IF rewrite finished, rewritePredicate will not be null + // IF not rewrite, do compoundPredicate + if (rewritePredicate != null) { + return rewritePredicate; + } + } + + CompoundPredicate result = new CompoundPredicate(op, exprs.get(0), exprs.get(1)); + for (int i = 2; i < exprs.size(); i++) { + result = new CompoundPredicate(op, result.clone(), exprs.get(i)); + } + result.setPrintSqlInParens(true); + return result; + } + + private Expr rewriteOrToIn(List exprs) { + // remainingOR expr = BP IP + InPredicate inPredicate = null; + boolean isOrToInAllowed = true; + Set slotSet = new LinkedHashSet<>(); + + for (int i = 0; i < exprs.size(); i++) { + Expr predicate = exprs.get(i); + if (!(predicate instanceof BinaryPredicate) && !(predicate instanceof InPredicate)) { + isOrToInAllowed = false; + break; + } else if (!(predicate.getChild(0) instanceof SlotRef)) { + isOrToInAllowed = false; + break; + } else if (!(predicate.getChild(1) instanceof LiteralExpr)) { + isOrToInAllowed = false; + break; + } else if (predicate instanceof BinaryPredicate + && ((BinaryPredicate) predicate).getOp() != BinaryPredicate.Operator.EQ) { + isOrToInAllowed = false; + break; + } else { + TableName tableName = ((SlotRef) predicate.getChild(0)).getTableName(); + if (tableName != null) { + String tblName = tableName.toString(); + String columnWithTable = tblName + "." + ((SlotRef) predicate.getChild(0)).getColumnName(); + slotSet.add(columnWithTable); + } else { + slotSet.add(((SlotRef) predicate.getChild(0)).getColumnName()); + } + } + } + + // isOrToInAllowed : true, means can rewrite + // slotSet.size : nums of columnName in exprs, should be 1 + if (isOrToInAllowed && slotSet.size() == 1) { + // slotRef to get ColumnName + + // SlotRef firstSlot = (SlotRef) exprs.get(0).getChild(0); + List childrenList = exprs.get(0).getChildren(); + inPredicate = new InPredicate(exprs.get(0).getChild(0), + childrenList.subList(1, childrenList.size()), false); + + for (int i = 1; i < exprs.size(); i++) { + childrenList = exprs.get(i).getChildren(); + inPredicate.addChildren(childrenList.subList(1, childrenList.size())); + } + } + + return inPredicate; + } + /** * Convert RangeSet to Compound Predicate * @param slotRef: diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/ListPartitionPrunerTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/ListPartitionPrunerTest.java index d7bae60c8f..a377ef0d68 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/analysis/ListPartitionPrunerTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/ListPartitionPrunerTest.java @@ -109,9 +109,9 @@ public class ListPartitionPrunerTest extends PartitionPruneTestBase { addCase("select * from test.t4 where k1 >= 2 and k2 = \"shanghai\";", "partitions=2/3", "partitions=1/3"); // Disjunctive predicates - addCase("select * from test.t2 where k1=1 or k1=4", "partitions=3/3", "partitions=2/3"); - addCase("select * from test.t4 where k1=1 or k1=3", "partitions=3/3", "partitions=2/3"); - addCase("select * from test.t4 where k2=\"tianjin\" or k2=\"shanghai\"", "partitions=3/3", "partitions=2/3"); + addCase("select * from test.t2 where k1=1 or k1=4", "partitions=2/3", "partitions=2/3"); + addCase("select * from test.t4 where k1=1 or k1=3", "partitions=2/3", "partitions=2/3"); + addCase("select * from test.t4 where k2=\"tianjin\" or k2=\"shanghai\"", "partitions=2/3", "partitions=2/3"); addCase("select * from test.t4 where k1 > 1 or k2 < \"shanghai\"", "partitions=3/3", "partitions=3/3"); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/RangePartitionPruneTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/RangePartitionPruneTest.java index 8c60f543f4..f1bbc3ba91 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/analysis/RangePartitionPruneTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/RangePartitionPruneTest.java @@ -171,19 +171,19 @@ public class RangePartitionPruneTest extends PartitionPruneTestBase { addCase("select * from test.multi_not_null where k1 > 10 and k1 is null", "partitions=0/2", "partitions=0/2"); // others predicates combination addCase("select * from test.t2 where k1 > 10 and k2 < 4", "partitions=6/9", "partitions=6/9"); - addCase("select * from test.t2 where k1 >10 and k1 < 10 and (k1=11 or k1=12)", "partitions=0/9", "partitions=0/9"); + addCase("select * from test.t2 where k1 >10 and k1 < 10 and (k1=11 or k1=12)", "partitions=1/9", "partitions=0/9"); addCase("select * from test.t2 where k1 > 20 and k1 < 7 and k1 = 10", "partitions=0/9", "partitions=0/9"); // 4. Disjunctive predicates - addCase("select * from test.t2 where k1=10 or k1=23", "partitions=9/9", "partitions=3/9"); - addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=4 or k2=5)", "partitions=9/9", "partitions=1/9"); - addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=4 or k2=11)", "partitions=9/9", "partitions=2/9"); - addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=3 or k2=4 or k2=11)", "partitions=9/9", "partitions=3/9"); - addCase("select * from test.t1 where dt=20211123 or dt=20211124", "partitions=8/8", "partitions=2/8"); + addCase("select * from test.t2 where k1=10 or k1=23", "partitions=3/9", "partitions=3/9"); + addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=4 or k2=5)", "partitions=1/9", "partitions=1/9"); + addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=4 or k2=11)", "partitions=2/9", "partitions=2/9"); + addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=3 or k2=4 or k2=11)", "partitions=3/9", "partitions=3/9"); + addCase("select * from test.t1 where dt=20211123 or dt=20211124", "partitions=2/8", "partitions=2/8"); addCase("select * from test.t1 where ((dt=20211123 and k1=1) or (dt=20211125 and k1=3))", "partitions=8/8", "partitions=2/8"); // TODO: predicates are "PREDICATES: ((`dt` = 20211123 AND `k1` = 1) OR (`dt` = 20211125 AND `k1` = 3)), `k2` > ", // maybe something goes wrong with ExtractCommonFactorsRule. - addCase("select * from test.t1 where ((dt=20211123 and k1=1) or (dt=20211125 and k1=3)) and k2>0", "partitions=8/8", "partitions=8/8"); + addCase("select * from test.t1 where ((dt=20211123 and k1=1) or (dt=20211125 and k1=3)) and k2>0", "partitions=8/8", "partitions=2/8"); addCase("select * from test.t2 where k1 > 10 or k2 < 1", "partitions=9/9", "partitions=9/9"); // add some cases for CompoundPredicate addCase("select * from test.t1 where (dt >= 20211121 and dt <= 20211122) or (dt >= 20211123 and dt <= 20211125)", diff --git a/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java b/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java index 56b68df235..d196d19195 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java @@ -2219,4 +2219,25 @@ public class QueryPlanTest extends TestWithFeService { String explainString = getSQLPlanOrErrorMsg(queryBaseTableStr); Assert.assertTrue(explainString.contains("PREAGGREGATION: ON")); } + + @Test + public void testRewriteOrToIn() throws Exception { + connectContext.setDatabase("default_cluster:test"); + String sql = "SELECT * from test1 where query_time = 1 or query_time = 2 or query_time in (3, 4)"; + String explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql); + Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2, 3, 4)")); + + sql = "SELECT * from test1 where (query_time = 1 or query_time = 2) and query_time in (3, 4)"; + explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql); + Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2), `query_time` IN (3, 4)")); + + sql = "SELECT * from test1 where (query_time = 1 or query_time = 2 or scan_bytes = 2) and scan_bytes in (2, 3)"; + explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql); + Assert.assertTrue(explainString.contains("PREDICATES: (`query_time` IN (1, 2) OR `scan_bytes` = 2), `scan_bytes` IN (2, 3)")); + + sql = "SELECT * from test1 where (query_time = 1 or query_time = 2) and (scan_bytes = 2 or scan_bytes = 3)"; + explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql); + Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2), `scan_bytes` IN (2, 3)") + || explainString.contains("PREDICATES: `query_time` IN (1, 2), `scan_bytes` IN (3, 2)")); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleFunctionTest.java index bb6807abee..588df5c6a6 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleFunctionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/rewrite/ExtractCommonFactorsRuleFunctionTest.java @@ -83,7 +83,6 @@ public class ExtractCommonFactorsRuleFunctionTest { Assert.assertEquals(1, StringUtils.countMatches(planString, "`tb1`.`k1` = `tb2`.`k1`")); } - @Test public void testWideCommonFactorsWithOrPredicate() throws Exception { String query = "select * from tb1 where tb1.k1 > 1000 or tb1.k1 < 200 or tb1.k1 = 300"; diff --git a/regression-test/data/performance_p0/redundant_conjuncts.out b/regression-test/data/performance_p0/redundant_conjuncts.out index 98178f31aa..7dbabccf37 100644 --- a/regression-test/data/performance_p0/redundant_conjuncts.out +++ b/regression-test/data/performance_p0/redundant_conjuncts.out @@ -23,7 +23,7 @@ PLAN FRAGMENT 0 0:VOlapScanNode TABLE: default_cluster:regression_test_performance_p0.redundant_conjuncts(redundant_conjuncts), PREAGGREGATION: OFF. Reason: No AggregateInfo - PREDICATES: (`k1` = 1 OR `k1` = 2) + PREDICATES: `k1` IN (1, 2) partitions=0/1, tablets=0/0, tabletList= cardinality=0, avgRowSize=8.0, numNodes=1