[enhancement](rewrite) add OrToIn rule and fix ExtractCommonFactorsRule apply problems (#12872)

Co-authored-by: wuhangze <wuhangze@jd.com>
This commit is contained in:
Henry2SS
2022-12-27 18:39:53 +08:00
committed by GitHub
parent 849adca225
commit 0550dfaeb2
6 changed files with 129 additions and 16 deletions

View File

@ -25,6 +25,7 @@ import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.InPredicate;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.TableName;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.planner.PlanNode;
import org.apache.doris.rewrite.ExprRewriter.ClauseType;
@ -68,6 +69,7 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
@Override
public Expr apply(Expr expr, Analyzer analyzer, ExprRewriter.ClauseType clauseType) throws AnalysisException {
Expr resultExpr = null;
if (expr == null) {
return null;
} else if (expr instanceof CompoundPredicate
@ -77,12 +79,19 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
return rewrittenExpr;
}
} else {
for (int i = 0; i < expr.getChildren().size(); i++) {
if (!(expr instanceof CompoundPredicate)) {
return expr;
}
resultExpr = expr.clone();
for (int i = 0; i < resultExpr.getChildren().size(); i++) {
Expr rewrittenExpr = apply(expr.getChild(i), analyzer, clauseType);
if (rewrittenExpr != null) {
expr.setChild(i, rewrittenExpr);
resultExpr.setChild(i, rewrittenExpr);
}
}
return resultExpr;
}
return expr;
}
@ -179,10 +188,10 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
if (CollectionUtils.isNotEmpty(commonFactorList)) {
result = new CompoundPredicate(CompoundPredicate.Operator.AND,
makeCompound(commonFactorList, CompoundPredicate.Operator.AND),
makeCompound(remainingOrClause, CompoundPredicate.Operator.OR));
makeCompoundRemaining(remainingOrClause, CompoundPredicate.Operator.OR));
result.setPrintSqlInParens(true);
} else {
result = makeCompound(remainingOrClause, CompoundPredicate.Operator.OR);
result = makeCompoundRemaining(remainingOrClause, CompoundPredicate.Operator.OR);
}
if (LOG.isDebugEnabled()) {
LOG.debug("equal ors: " + result.toSql());
@ -399,6 +408,11 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
/**
* Rebuild CompoundPredicate, [a, e, f] AND => a and e and f
* Rewrite OR :[a, b, c]
* while (a.columnName == b.columnName == c.columnName) && (a,b,c)
* instance of (BinaryPredicate, InPredicate)
* && (a,b,c).op = BinaryPredicate.Operator.EQ =======>>>>>>
* =======>>>>>> columnName IN (a.value,b.value,c.value)
*/
private Expr makeCompound(List<Expr> exprs, CompoundPredicate.Operator op) {
if (CollectionUtils.isEmpty(exprs)) {
@ -415,6 +429,85 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
return result;
}
private Expr makeCompoundRemaining(List<Expr> exprs, CompoundPredicate.Operator op) {
if (CollectionUtils.isEmpty(exprs)) {
return null;
}
if (exprs.size() == 1) {
return exprs.get(0);
}
Expr rewritePredicate = null;
// only OR will be rewrite to IN
if (op == CompoundPredicate.Operator.OR) {
rewritePredicate = rewriteOrToIn(exprs);
// IF rewrite finished, rewritePredicate will not be null
// IF not rewrite, do compoundPredicate
if (rewritePredicate != null) {
return rewritePredicate;
}
}
CompoundPredicate result = new CompoundPredicate(op, exprs.get(0), exprs.get(1));
for (int i = 2; i < exprs.size(); i++) {
result = new CompoundPredicate(op, result.clone(), exprs.get(i));
}
result.setPrintSqlInParens(true);
return result;
}
private Expr rewriteOrToIn(List<Expr> exprs) {
// remainingOR expr = BP IP
InPredicate inPredicate = null;
boolean isOrToInAllowed = true;
Set<String> slotSet = new LinkedHashSet<>();
for (int i = 0; i < exprs.size(); i++) {
Expr predicate = exprs.get(i);
if (!(predicate instanceof BinaryPredicate) && !(predicate instanceof InPredicate)) {
isOrToInAllowed = false;
break;
} else if (!(predicate.getChild(0) instanceof SlotRef)) {
isOrToInAllowed = false;
break;
} else if (!(predicate.getChild(1) instanceof LiteralExpr)) {
isOrToInAllowed = false;
break;
} else if (predicate instanceof BinaryPredicate
&& ((BinaryPredicate) predicate).getOp() != BinaryPredicate.Operator.EQ) {
isOrToInAllowed = false;
break;
} else {
TableName tableName = ((SlotRef) predicate.getChild(0)).getTableName();
if (tableName != null) {
String tblName = tableName.toString();
String columnWithTable = tblName + "." + ((SlotRef) predicate.getChild(0)).getColumnName();
slotSet.add(columnWithTable);
} else {
slotSet.add(((SlotRef) predicate.getChild(0)).getColumnName());
}
}
}
// isOrToInAllowed : true, means can rewrite
// slotSet.size : nums of columnName in exprs, should be 1
if (isOrToInAllowed && slotSet.size() == 1) {
// slotRef to get ColumnName
// SlotRef firstSlot = (SlotRef) exprs.get(0).getChild(0);
List<Expr> childrenList = exprs.get(0).getChildren();
inPredicate = new InPredicate(exprs.get(0).getChild(0),
childrenList.subList(1, childrenList.size()), false);
for (int i = 1; i < exprs.size(); i++) {
childrenList = exprs.get(i).getChildren();
inPredicate.addChildren(childrenList.subList(1, childrenList.size()));
}
}
return inPredicate;
}
/**
* Convert RangeSet to Compound Predicate
* @param slotRef: <k1>

View File

@ -109,9 +109,9 @@ public class ListPartitionPrunerTest extends PartitionPruneTestBase {
addCase("select * from test.t4 where k1 >= 2 and k2 = \"shanghai\";", "partitions=2/3", "partitions=1/3");
// Disjunctive predicates
addCase("select * from test.t2 where k1=1 or k1=4", "partitions=3/3", "partitions=2/3");
addCase("select * from test.t4 where k1=1 or k1=3", "partitions=3/3", "partitions=2/3");
addCase("select * from test.t4 where k2=\"tianjin\" or k2=\"shanghai\"", "partitions=3/3", "partitions=2/3");
addCase("select * from test.t2 where k1=1 or k1=4", "partitions=2/3", "partitions=2/3");
addCase("select * from test.t4 where k1=1 or k1=3", "partitions=2/3", "partitions=2/3");
addCase("select * from test.t4 where k2=\"tianjin\" or k2=\"shanghai\"", "partitions=2/3", "partitions=2/3");
addCase("select * from test.t4 where k1 > 1 or k2 < \"shanghai\"", "partitions=3/3", "partitions=3/3");
}

View File

@ -171,19 +171,19 @@ public class RangePartitionPruneTest extends PartitionPruneTestBase {
addCase("select * from test.multi_not_null where k1 > 10 and k1 is null", "partitions=0/2", "partitions=0/2");
// others predicates combination
addCase("select * from test.t2 where k1 > 10 and k2 < 4", "partitions=6/9", "partitions=6/9");
addCase("select * from test.t2 where k1 >10 and k1 < 10 and (k1=11 or k1=12)", "partitions=0/9", "partitions=0/9");
addCase("select * from test.t2 where k1 >10 and k1 < 10 and (k1=11 or k1=12)", "partitions=1/9", "partitions=0/9");
addCase("select * from test.t2 where k1 > 20 and k1 < 7 and k1 = 10", "partitions=0/9", "partitions=0/9");
// 4. Disjunctive predicates
addCase("select * from test.t2 where k1=10 or k1=23", "partitions=9/9", "partitions=3/9");
addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=4 or k2=5)", "partitions=9/9", "partitions=1/9");
addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=4 or k2=11)", "partitions=9/9", "partitions=2/9");
addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=3 or k2=4 or k2=11)", "partitions=9/9", "partitions=3/9");
addCase("select * from test.t1 where dt=20211123 or dt=20211124", "partitions=8/8", "partitions=2/8");
addCase("select * from test.t2 where k1=10 or k1=23", "partitions=3/9", "partitions=3/9");
addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=4 or k2=5)", "partitions=1/9", "partitions=1/9");
addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=4 or k2=11)", "partitions=2/9", "partitions=2/9");
addCase("select * from test.t2 where (k1=10 or k1=23) and (k2=3 or k2=4 or k2=11)", "partitions=3/9", "partitions=3/9");
addCase("select * from test.t1 where dt=20211123 or dt=20211124", "partitions=2/8", "partitions=2/8");
addCase("select * from test.t1 where ((dt=20211123 and k1=1) or (dt=20211125 and k1=3))", "partitions=8/8", "partitions=2/8");
// TODO: predicates are "PREDICATES: ((`dt` = 20211123 AND `k1` = 1) OR (`dt` = 20211125 AND `k1` = 3)), `k2` > ",
// maybe something goes wrong with ExtractCommonFactorsRule.
addCase("select * from test.t1 where ((dt=20211123 and k1=1) or (dt=20211125 and k1=3)) and k2>0", "partitions=8/8", "partitions=8/8");
addCase("select * from test.t1 where ((dt=20211123 and k1=1) or (dt=20211125 and k1=3)) and k2>0", "partitions=8/8", "partitions=2/8");
addCase("select * from test.t2 where k1 > 10 or k2 < 1", "partitions=9/9", "partitions=9/9");
// add some cases for CompoundPredicate
addCase("select * from test.t1 where (dt >= 20211121 and dt <= 20211122) or (dt >= 20211123 and dt <= 20211125)",

View File

@ -2219,4 +2219,25 @@ public class QueryPlanTest extends TestWithFeService {
String explainString = getSQLPlanOrErrorMsg(queryBaseTableStr);
Assert.assertTrue(explainString.contains("PREAGGREGATION: ON"));
}
@Test
public void testRewriteOrToIn() throws Exception {
connectContext.setDatabase("default_cluster:test");
String sql = "SELECT * from test1 where query_time = 1 or query_time = 2 or query_time in (3, 4)";
String explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2, 3, 4)"));
sql = "SELECT * from test1 where (query_time = 1 or query_time = 2) and query_time in (3, 4)";
explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2), `query_time` IN (3, 4)"));
sql = "SELECT * from test1 where (query_time = 1 or query_time = 2 or scan_bytes = 2) and scan_bytes in (2, 3)";
explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
Assert.assertTrue(explainString.contains("PREDICATES: (`query_time` IN (1, 2) OR `scan_bytes` = 2), `scan_bytes` IN (2, 3)"));
sql = "SELECT * from test1 where (query_time = 1 or query_time = 2) and (scan_bytes = 2 or scan_bytes = 3)";
explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2), `scan_bytes` IN (2, 3)")
|| explainString.contains("PREDICATES: `query_time` IN (1, 2), `scan_bytes` IN (3, 2)"));
}
}

View File

@ -83,7 +83,6 @@ public class ExtractCommonFactorsRuleFunctionTest {
Assert.assertEquals(1, StringUtils.countMatches(planString, "`tb1`.`k1` = `tb2`.`k1`"));
}
@Test
public void testWideCommonFactorsWithOrPredicate() throws Exception {
String query = "select * from tb1 where tb1.k1 > 1000 or tb1.k1 < 200 or tb1.k1 = 300";

View File

@ -23,7 +23,7 @@ PLAN FRAGMENT 0
0:VOlapScanNode
TABLE: default_cluster:regression_test_performance_p0.redundant_conjuncts(redundant_conjuncts), PREAGGREGATION: OFF. Reason: No AggregateInfo
PREDICATES: (`k1` = 1 OR `k1` = 2)
PREDICATES: `k1` IN (1, 2)
partitions=0/1, tablets=0/0, tabletList=
cardinality=0, avgRowSize=8.0, numNodes=1