[Query] Optimize where clause by extracting the common predicate in the OR compound predicate. (#3278)

Queries like below cannot finish in a acceptable time, `store_sales` has 2800w rows, `customer_address` has 5w rows, for now Doris will create only one cross join node to execute this sql, 
the time of eval the where clause is about 200-300 ns, the total count of eval will be  2800w * 5w, this is extremely large, and this will cost 2800w * 5w * 250 ns = 4 billion seconds;

```
select avg(ss_quantity)
       ,avg(ss_ext_sales_price)
       ,avg(ss_ext_wholesale_cost)
       ,sum(ss_ext_wholesale_cost)
 from store_sales, customer_address 
 where  ((ss_addr_sk = ca_address_sk
  and ca_country = 'United States'
  and ca_state in ('CO', 'IL', 'MN')
  and ss_net_profit between 100 and 200  
     ) or
     (ss_addr_sk = ca_address_sk
  and ca_country = 'United States'
  and ca_state in ('OH', 'MT', 'NM')
  and ss_net_profit between 150 and 300  
     ) or
     (ss_addr_sk = ca_address_sk
  and ca_country = 'United States'
  and ca_state in ('TX', 'MO', 'MI')
  and ss_net_profit between 50 and 250  
     ))
```

but this  sql can be rewrite to 
```
select avg(ss_quantity)
       ,avg(ss_ext_sales_price)
       ,avg(ss_ext_wholesale_cost)
       ,sum(ss_ext_wholesale_cost)
 from store_sales, customer_address 
 where ss_addr_sk = ca_address_sk
  and ca_country = 'United States' and (((ca_state in ('CO', 'IL', 'MN')
  and ss_net_profit between 100 and 200  
     ) or
     (ca_state in ('OH', 'MT', 'NM')
  and ss_net_profit between 150 and 300  
     ) or
     (ca_state in ('TX', 'MO', 'MI')
  and ss_net_profit between 50 and 250  
     ))
 )
```
there for  we can do a hash join first and then use 
```
(((ca_state in ('CO', 'IL', 'MN')
  and ss_net_profit between 100 and 200  
     ) or
     (ca_state in ('OH', 'MT', 'NM')
  and ss_net_profit between 150 and 300  
     ) or
     (ca_state in ('TX', 'MO', 'MI')
  and ss_net_profit between 50 and 250  
     ))
 )
```
to filter the value,

in TPCDS 10g dataset,  the rewritten sql only cost about 1 seconds.
This commit is contained in:
yangzhg
2020-04-09 21:57:45 +08:00
committed by GitHub
parent 3dc7ef634b
commit 8699bb7bd4
2 changed files with 324 additions and 5 deletions

View File

@ -54,6 +54,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -495,6 +496,10 @@ public class SelectStmt extends QueryStmt {
}
private void whereClauseRewrite() {
Expr deDuplicatedWhere = deduplicateOrs(whereClause);
if (deDuplicatedWhere != null) {
whereClause = deDuplicatedWhere;
}
if (whereClause instanceof IntLiteral) {
if (((IntLiteral) whereClause).getLongValue() == 0) {
whereClause = new BoolLiteral(false);
@ -504,6 +509,132 @@ public class SelectStmt extends QueryStmt {
}
}
/**
* this function only process (a and b and c) or (d and e and f) like clause,
* this function will extract this to [[a, b, c], [d, e, f]]
*/
private List<List<Expr>> extractDuplicateOrs(CompoundPredicate expr) {
List<List<Expr>> orExprs = new ArrayList<>();
for (Expr child : expr.getChildren()) {
if (child instanceof CompoundPredicate) {
CompoundPredicate childCp = (CompoundPredicate) child;
if (childCp.getOp() == CompoundPredicate.Operator.OR) {
orExprs.addAll(extractDuplicateOrs(childCp));
continue;
} else if (childCp.getOp() == CompoundPredicate.Operator.AND) {
orExprs.add(flatAndExpr(child));
continue;
}
}
orExprs.add(Arrays.asList(child));
}
return orExprs;
}
/**
* This function attempts to apply the inverse OR distributive law:
* ((A AND B) OR (A AND C)) => (A AND (B OR C))
* That is, locate OR clauses in which every subclause contains an
* identical term, and pull out the duplicated terms.
*/
private Expr deduplicateOrs(Expr expr) {
if (expr instanceof CompoundPredicate && ((CompoundPredicate) expr).getOp() == CompoundPredicate.Operator.OR) {
Expr rewritedExpr = processDuplicateOrs(extractDuplicateOrs((CompoundPredicate) expr));
if (rewritedExpr != null) {
return rewritedExpr;
}
} else {
for (int i = 0; i < expr.getChildren().size(); i++) {
Expr rewritedExpr = deduplicateOrs(expr.getChild(i));
if (rewritedExpr != null) {
expr.setChild(i, rewritedExpr);
}
}
}
return expr;
}
/**
* try to flat and , a and b and c => [a, b, c]
*/
private List<Expr> flatAndExpr(Expr expr) {
List<Expr> andExprs = new ArrayList<>();
if (expr instanceof CompoundPredicate && ((CompoundPredicate) expr).getOp() == CompoundPredicate.Operator.AND) {
andExprs.addAll(flatAndExpr(expr.getChild(0)));
andExprs.addAll(flatAndExpr(expr.getChild(1)));
} else {
andExprs.add(expr);
}
return andExprs;
}
/**
* the input is a list of list, the inner list is and connected exprs, the outer list is or connected
* for example clause (a and b and c) or (a and e and f) after extractDuplicateOrs will be [[a, b, c], [a, e, f]]
* this is the input of this function, first step is deduplicate [[a, b, c], [a, e, f]] => [[a], [b, c], [e, f]]
* then rebuild the expr to a and ((b and c) or (e and f))
*/
private Expr processDuplicateOrs(List<List<Expr>> exprs) {
if (exprs.size() < 2) {
return null;
}
// 1. remove duplicated elements [[a,a], [a, b], [a,b]] => [[a], [a,b]]
Set<Set<Expr>> set = new LinkedHashSet<>();
for (List<Expr> ex : exprs) {
Set<Expr> es = new LinkedHashSet<>();
es.addAll(ex);
set.add(es);
}
List<List<Expr>> clearExprs = new ArrayList<>();
for (Set<Expr> es : set) {
List<Expr> el = new ArrayList<>();
el.addAll(es);
clearExprs.add(el);
}
if (clearExprs.size() == 1) {
return makeCompound(clearExprs.get(0), CompoundPredicate.Operator.AND);
}
// 2. find duplcate cross the clause
List<Expr> cloneExprs = new ArrayList<>(clearExprs.get(0));
for (int i = 1; i < clearExprs.size(); ++i) {
cloneExprs.retainAll(clearExprs.get(i));
}
List<Expr> temp = new ArrayList<>();
if (CollectionUtils.isNotEmpty(cloneExprs)) {
temp.add(makeCompound(cloneExprs, CompoundPredicate.Operator.AND));
}
for (List<Expr> exprList : clearExprs) {
exprList.removeAll(cloneExprs);
temp.add(makeCompound(exprList, CompoundPredicate.Operator.AND));
}
// rebuild CompoundPredicate if found duplicate predicate will build (predcate) and (.. or ..) predicate in
// step 1: will build (.. or ..)
Expr result = CollectionUtils.isNotEmpty(cloneExprs) ? new CompoundPredicate(CompoundPredicate.Operator.AND,
temp.get(0), makeCompound(temp.subList(1, temp.size()), CompoundPredicate.Operator.OR))
: makeCompound(temp, CompoundPredicate.Operator.OR);
LOG.debug("rewrite ors: " + result.toSql());
return result;
}
/**
* Rebuild CompoundPredicate, [a, e, f] AND => a and e and f
*/
private Expr makeCompound(List<Expr> exprs, CompoundPredicate.Operator op) {
if (CollectionUtils.isEmpty(exprs)) {
return null;
}
if (exprs.size() == 1) {
return exprs.get(0);
}
CompoundPredicate result = new CompoundPredicate(op, exprs.get(0), exprs.get(1));
for (int i = 2; i < exprs.size(); ++i) {
result = new CompoundPredicate(op, result.clone(), exprs.get(i));
}
return result;
}
/**
* Generates and registers !empty() predicates to filter out empty collections directly
* in the parent scan of collection table refs. This is a performance optimization to

View File

@ -20,6 +20,7 @@ package org.apache.doris.analysis;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.rewrite.ExprRewriter;
import org.apache.doris.thrift.TPrimitiveType;
import org.apache.doris.utframe.DorisAssert;
import org.apache.doris.utframe.UtFrameUtils;
import org.junit.AfterClass;
@ -29,6 +30,7 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import java.io.IOException;
import java.util.UUID;
public class SelectStmtTest {
@ -89,10 +91,196 @@ public class SelectStmtTest {
"FROM db1.tbl1;";
SelectStmt stmt = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql1, ctx);
stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
Assert.assertEquals("SELECT CASE WHEN `$a$1`.`$c$1` > `k4` THEN `$a$2`.`$c$2` ELSE `$a$3`.`$c$3` END" +
" AS `kk4` FROM `default_cluster:db1`.`tbl1` (SELECT count(*) / 2.0 AS `count(*) / 2.0` FROM " +
"`default_cluster:db1`.`tbl1`) $a$1 (SELECT avg(`k4`) AS `avg(``k4``)` FROM" +
" `default_cluster:db1`.`tbl1`) $a$2 (SELECT sum(`k4`) AS `sum(``k4``)` " +
"FROM `default_cluster:db1`.`tbl1`) $a$3", stmt.toSql());
Assert.assertTrue(stmt.toSql().contains("`$a$1`.`$c$1` > `k4` THEN `$a$2`.`$c$2` ELSE `$a$3`.`$c$3`"));
}
@Test
public void testDeduplicateOrs() throws Exception {
ConnectContext ctx = UtFrameUtils.createDefaultCtx();
String sql = "select\n" +
" avg(t1.k4)\n" +
"from\n" +
" db1.tbl1 t1,\n" +
" db1.tbl1 t2,\n" +
" db1.tbl1 t3,\n" +
" db1.tbl1 t4,\n" +
" db1.tbl1 t5,\n" +
" db1.tbl1 t6\n" +
"where\n" +
" t2.k1 = t1.k1\n" +
" and t1.k2 = t6.k2\n" +
" and t6.k4 = 2001\n" +
" and(\n" +
" (\n" +
" t1.k2 = t4.k2\n" +
" and t3.k3 = t1.k3\n" +
" and t3.k1 = 'D'\n" +
" and t4.k3 = '2 yr Degree'\n" +
" and t1.k4 between 100.00\n" +
" and 150.00\n" +
" and t4.k4 = 3\n" +
" )\n" +
" or (\n" +
" t1.k2 = t4.k2\n" +
" and t3.k3 = t1.k3\n" +
" and t3.k1 = 'S'\n" +
" and t4.k3 = 'Secondary'\n" +
" and t1.k4 between 50.00\n" +
" and 100.00\n" +
" and t4.k4 = 1\n" +
" )\n" +
" or (\n" +
" t1.k2 = t4.k2\n" +
" and t3.k3 = t1.k3\n" +
" and t3.k1 = 'W'\n" +
" and t4.k3 = 'Advanced Degree'\n" +
" and t1.k4 between 150.00\n" +
" and 200.00\n" +
" and t4.k4 = 1\n" +
" )\n" +
" )\n" +
" and(\n" +
" (\n" +
" t1.k1 = t5.k1\n" +
" and t5.k2 = 'United States'\n" +
" and t5.k3 in ('CO', 'IL', 'MN')\n" +
" and t1.k4 between 100\n" +
" and 200\n" +
" )\n" +
" or (\n" +
" t1.k1 = t5.k1\n" +
" and t5.k2 = 'United States'\n" +
" and t5.k3 in ('OH', 'MT', 'NM')\n" +
" and t1.k4 between 150\n" +
" and 300\n" +
" )\n" +
" or (\n" +
" t1.k1 = t5.k1\n" +
" and t5.k2 = 'United States'\n" +
" and t5.k3 in ('TX', 'MO', 'MI')\n" +
" and t1.k4 between 50 and 250\n" +
" )\n" +
" );";
SelectStmt stmt = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql, ctx);
stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
String rewritedFragment1 = "(((`t1`.`k2` = `t4`.`k2`) AND (`t3`.`k3` = `t1`.`k3`)) AND ((((((`t3`.`k1` = 'D')" +
" AND (`t4`.`k3` = '2 yr Degree')) AND ((`t1`.`k4` >= 100.00) AND (`t1`.`k4` <= 150.00))) AND" +
" (`t4`.`k4` = 3)) OR ((((`t3`.`k1` = 'S') AND (`t4`.`k3` = 'Secondary')) AND ((`t1`.`k4` >= 50.00)" +
" AND (`t1`.`k4` <= 100.00))) AND (`t4`.`k4` = 1))) OR ((((`t3`.`k1` = 'W') AND " +
"(`t4`.`k3` = 'Advanced Degree')) AND ((`t1`.`k4` >= 150.00) AND (`t1`.`k4` <= 200.00)))" +
" AND (`t4`.`k4` = 1))))";
String rewritedFragment2 = "(((`t1`.`k1` = `t5`.`k1`) AND (`t5`.`k2` = 'United States')) AND" +
" ((((`t5`.`k3` IN ('CO', 'IL', 'MN')) AND ((`t1`.`k4` >= 100) AND (`t1`.`k4` <= 200)))" +
" OR ((`t5`.`k3` IN ('OH', 'MT', 'NM')) AND ((`t1`.`k4` >= 150) AND (`t1`.`k4` <= 300))))" +
" OR ((`t5`.`k3` IN ('TX', 'MO', 'MI')) AND ((`t1`.`k4` >= 50) AND (`t1`.`k4` <= 250)))))";
Assert.assertTrue(stmt.toSql().contains(rewritedFragment1));
Assert.assertTrue(stmt.toSql().contains(rewritedFragment2));
String sql2 = "select\n" +
" avg(t1.k4)\n" +
"from\n" +
" db1.tbl1 t1,\n" +
" db1.tbl1 t2\n" +
"where\n" +
"(\n" +
" t1.k1 = t2.k3\n" +
" and t2.k2 = 'United States'\n" +
" and t2.k3 in ('CO', 'IL', 'MN')\n" +
" and t1.k4 between 100\n" +
" and 200\n" +
")\n" +
"or (\n" +
" t1.k1 = t2.k1\n" +
" and t2.k2 = 'United States1'\n" +
" and t2.k3 in ('OH', 'MT', 'NM')\n" +
" and t1.k4 between 150\n" +
" and 300\n" +
")\n" +
"or (\n" +
" t1.k1 = t2.k1\n" +
" and t2.k2 = 'United States'\n" +
" and t2.k3 in ('TX', 'MO', 'MI')\n" +
" and t1.k4 between 50 and 250\n" +
")";
SelectStmt stmt2 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql2, ctx);
stmt2.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
String fragment3 = "(((((`t1`.`k1` = `t2`.`k3`) AND (`t2`.`k2` = 'United States')) AND " +
"(`t2`.`k3` IN ('CO', 'IL', 'MN'))) AND ((`t1`.`k4` >= 100) AND (`t1`.`k4` <= 200))) OR" +
" ((((`t1`.`k1` = `t2`.`k1`) AND (`t2`.`k2` = 'United States1')) AND (`t2`.`k3` IN ('OH', 'MT', 'NM')))" +
" AND ((`t1`.`k4` >= 150) AND (`t1`.`k4` <= 300)))) OR ((((`t1`.`k1` = `t2`.`k1`) AND " +
"(`t2`.`k2` = 'United States')) AND (`t2`.`k3` IN ('TX', 'MO', 'MI'))) AND ((`t1`.`k4` >= 50)" +
" AND (`t1`.`k4` <= 250)))";
Assert.assertTrue(stmt2.toSql().contains(fragment3));
String sql3 = "select\n" +
" avg(t1.k4)\n" +
"from\n" +
" db1.tbl1 t1,\n" +
" db1.tbl1 t2\n" +
"where\n" +
" t1.k1 = t2.k3 or t1.k1 = t2.k3 or t1.k1 = t2.k3";
SelectStmt stmt3 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql3, ctx);
stmt3.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
Assert.assertFalse(stmt3.toSql().contains("((`t1`.`k1` = `t2`.`k3`) OR (`t1`.`k1` = `t2`.`k3`)) OR" +
" (`t1`.`k1` = `t2`.`k3`)"));
String sql4 = "select\n" +
" avg(t1.k4)\n" +
"from\n" +
" db1.tbl1 t1,\n" +
" db1.tbl1 t2\n" +
"where\n" +
" t1.k1 = t2.k2 or t1.k1 = t2.k3 or t1.k1 = t2.k3";
SelectStmt stmt4 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql4, ctx);
stmt4.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
Assert.assertTrue(stmt4.toSql().contains("(`t1`.`k1` = `t2`.`k2`) OR (`t1`.`k1` = `t2`.`k3`)"));
String sql5 = "select\n" +
" avg(t1.k4)\n" +
"from\n" +
" db1.tbl1 t1,\n" +
" db1.tbl1 t2\n" +
"where\n" +
" t2.k1 is not null or t1.k1 is not null or t1.k1 is not null";
SelectStmt stmt5 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql5, ctx);
stmt5.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
Assert.assertTrue(stmt5.toSql().contains("(`t2`.`k1` IS NOT NULL) OR (`t1`.`k1` IS NOT NULL)"));
Assert.assertEquals(2, stmt5.toSql().split(" OR ").length);
String sql6 = "select\n" +
" avg(t1.k4)\n" +
"from\n" +
" db1.tbl1 t1,\n" +
" db1.tbl1 t2\n" +
"where\n" +
" t2.k1 is not null or t1.k1 is not null and t1.k1 is not null";
SelectStmt stmt6 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql6, ctx);
stmt6.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
Assert.assertTrue(stmt6.toSql().contains("(`t2`.`k1` IS NOT NULL) OR (`t1`.`k1` IS NOT NULL)"));
Assert.assertEquals(2, stmt6.toSql().split(" OR ").length);
String sql7 = "select\n" +
" avg(t1.k4)\n" +
"from\n" +
" db1.tbl1 t1,\n" +
" db1.tbl1 t2\n" +
"where\n" +
" t2.k1 is not null or t1.k1 is not null and t1.k2 is not null";
SelectStmt stmt7 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql7, ctx);
stmt7.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
Assert.assertTrue(stmt7.toSql().contains("(`t2`.`k1` IS NOT NULL) OR ((`t1`.`k1` IS NOT NULL) " +
"AND (`t1`.`k2` IS NOT NULL))"));
String sql8 = "select\n" +
" avg(t1.k4)\n" +
"from\n" +
" db1.tbl1 t1,\n" +
" db1.tbl1 t2\n" +
"where\n" +
" t2.k1 is not null and t1.k1 is not null and t1.k1 is not null";
SelectStmt stmt8 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql8, ctx);
stmt8.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
Assert.assertTrue(stmt8.toSql().contains("((`t2`.`k1` IS NOT NULL) AND (`t1`.`k1` IS NOT NULL))" +
" AND (`t1`.`k1` IS NOT NULL)"));
}
}