[SQL Function] Calculate 'case when expr' when possible (#3396)

Calculate 'case when expr' when possible
This commit is contained in:
wangbo
2020-05-07 22:04:09 +08:00
committed by GitHub
parent 94539e7120
commit d60bb81cb0
3 changed files with 217 additions and 27 deletions

View File

@ -26,6 +26,7 @@ import org.apache.doris.thrift.TExprNodeType;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
/**
@ -101,6 +102,7 @@ public class CaseExpr extends Expr {
CaseExpr expr = (CaseExpr) obj;
return hasCaseExpr == expr.hasCaseExpr && hasElseExpr == expr.hasElseExpr;
}
public boolean hasCaseExpr() {
return hasCaseExpr;
}
@ -251,4 +253,84 @@ public class CaseExpr extends Expr {
}
return exprs;
}
// this method just compare literal value and not completely consistent with be,for two cases
// 1 not deal float
// 2 just compare literal value with same type. for a example sql 'select case when 123 then '1' else '2' end as col'
// for be will return '1', because be only regard 0 as false
// but for current LiteralExpr.compareLiteral, `123`' won't be regard as true
// the case which two values has different type left to be
public static Expr computeCaseExpr(CaseExpr expr) {
LiteralExpr caseExpr;
int startIndex = 0;
int endIndex = expr.getChildren().size();
if (expr.hasCaseExpr()) {
// just deal literal here
// and avoid `float compute` in java,float should be dealt in be
Expr caseChildExpr = expr.getChild(0);
if (!caseChildExpr.isLiteral()
|| caseChildExpr instanceof DecimalLiteral || caseChildExpr instanceof FloatLiteral) {
return expr;
}
caseExpr = (LiteralExpr) expr.getChild(0);
startIndex++;
} else {
caseExpr = new BoolLiteral(true);
}
if (caseExpr instanceof NullLiteral) {
if (expr.hasElseExpr) {
return expr.getChild(expr.getChildren().size() - 1);
} else {
return new NullLiteral();
}
}
if (expr.hasElseExpr) {
endIndex--;
}
// early return when the `when expr` can't be converted to constants
Expr startExpr = expr.getChild(startIndex);
if ((!startExpr.isLiteral() || startExpr instanceof DecimalLiteral || startExpr instanceof FloatLiteral)
|| (!(startExpr instanceof NullLiteral) && !startExpr.getClass().toString().equals(caseExpr.getClass().toString()))) {
return expr;
}
for (int i = startIndex; i < endIndex; i = i + 2) {
Expr currentWhenExpr = expr.getChild(i);
// skip null literal
if (currentWhenExpr instanceof NullLiteral) {
continue;
}
// stop convert in three cases
// 1 not literal
// 2 float
// 3 `case expr` and `when expr` don't have same type
if ((!currentWhenExpr.isLiteral() || currentWhenExpr instanceof DecimalLiteral || currentWhenExpr instanceof FloatLiteral)
|| !currentWhenExpr.getClass().toString().equals(caseExpr.getClass().toString())) {
// remove the expr which has been evaluated
List<Expr> exprLeft = new ArrayList<>();
if (expr.hasCaseExpr()) {
exprLeft.add(caseExpr);
}
for (int j = i; j < expr.getChildren().size(); j++) {
exprLeft.add(expr.getChild(j));
}
Expr retCaseExpr = expr.clone();
retCaseExpr.getChildren().clear();
retCaseExpr.addChildren(exprLeft);
return retCaseExpr;
} else if (caseExpr.compareLiteral((LiteralExpr) currentWhenExpr) == 0) {
return expr.getChild(i + 1);
}
}
if (expr.hasElseExpr) {
return expr.getChild(expr.getChildren().size() - 1);
} else {
return new NullLiteral();
}
}
}

View File

@ -19,6 +19,7 @@ package org.apache.doris.rewrite;
import org.apache.doris.analysis.Analyzer;
import org.apache.doris.analysis.CaseExpr;
import org.apache.doris.analysis.CastExpr;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.NullLiteral;
@ -48,6 +49,11 @@ public class FoldConstantsRule implements ExprRewriteRule {
@Override
public Expr apply(Expr expr, Analyzer analyzer) throws AnalysisException {
// evaluate `case when expr` when possible
if (expr instanceof CaseExpr) {
return CaseExpr.computeCaseExpr((CaseExpr) expr);
}
// Avoid calling Expr.isConstant() because that would lead to repeated traversals
// of the Expr tree. Assumes the bottom-up application of this rule. Constant
// children should have been folded at this point.

View File

@ -17,6 +17,7 @@
package org.apache.doris.planner;
import org.apache.commons.lang3.StringUtils;
import org.apache.doris.analysis.CreateDbStmt;
import org.apache.doris.analysis.CreateTableStmt;
import org.apache.doris.analysis.DropDbStmt;
@ -242,31 +243,31 @@ public class QueryPlanTest {
"PROPERTIES (\n" +
" \"replication_num\" = \"1\"\n" +
");");
createTable("CREATE TABLE test.`pushdown_test` (\n" +
" `k1` tinyint(4) NULL COMMENT \"\",\n" +
" `k2` smallint(6) NULL COMMENT \"\",\n" +
" `k3` int(11) NULL COMMENT \"\",\n" +
" `k4` bigint(20) NULL COMMENT \"\",\n" +
" `k5` decimal(9, 3) NULL COMMENT \"\",\n" +
" `k6` char(5) NULL COMMENT \"\",\n" +
" `k10` date NULL COMMENT \"\",\n" +
" `k11` datetime NULL COMMENT \"\",\n" +
" `k7` varchar(20) NULL COMMENT \"\",\n" +
" `k8` double MAX NULL COMMENT \"\",\n" +
" `k9` float SUM NULL COMMENT \"\"\n" +
") ENGINE=OLAP\n" +
"AGGREGATE KEY(`k1`, `k2`, `k3`, `k4`, `k5`, `k6`, `k10`, `k11`, `k7`)\n" +
"COMMENT \"OLAP\"\n" +
"PARTITION BY RANGE(`k1`)\n" +
"(PARTITION p1 VALUES [(\"-128\"), (\"-64\")),\n" +
"PARTITION p2 VALUES [(\"-64\"), (\"0\")),\n" +
"PARTITION p3 VALUES [(\"0\"), (\"64\")))\n" +
"DISTRIBUTED BY HASH(`k1`) BUCKETS 5\n" +
"PROPERTIES (\n" +
"\"replication_num\" = \"1\",\n" +
"\"in_memory\" = \"false\",\n" +
"\"storage_format\" = \"DEFAULT\"\n" +
" `k1` tinyint(4) NULL COMMENT \"\",\n" +
" `k2` smallint(6) NULL COMMENT \"\",\n" +
" `k3` int(11) NULL COMMENT \"\",\n" +
" `k4` bigint(20) NULL COMMENT \"\",\n" +
" `k5` decimal(9, 3) NULL COMMENT \"\",\n" +
" `k6` char(5) NULL COMMENT \"\",\n" +
" `k10` date NULL COMMENT \"\",\n" +
" `k11` datetime NULL COMMENT \"\",\n" +
" `k7` varchar(20) NULL COMMENT \"\",\n" +
" `k8` double MAX NULL COMMENT \"\",\n" +
" `k9` float SUM NULL COMMENT \"\"\n" +
") ENGINE=OLAP\n" +
"AGGREGATE KEY(`k1`, `k2`, `k3`, `k4`, `k5`, `k6`, `k10`, `k11`, `k7`)\n" +
"COMMENT \"OLAP\"\n" +
"PARTITION BY RANGE(`k1`)\n" +
"(PARTITION p1 VALUES [(\"-128\"), (\"-64\")),\n" +
"PARTITION p2 VALUES [(\"-64\"), (\"0\")),\n" +
"PARTITION p3 VALUES [(\"0\"), (\"64\")))\n" +
"DISTRIBUTED BY HASH(`k1`) BUCKETS 5\n" +
"PROPERTIES (\n" +
"\"replication_num\" = \"1\",\n" +
"\"in_memory\" = \"false\",\n" +
"\"storage_format\" = \"DEFAULT\"\n" +
");");
}
@ -711,18 +712,119 @@ public class QueryPlanTest {
Assert.assertTrue(explainString.contains("PREDICATES: `join1`.`id` > 1"));
Assert.assertFalse(explainString.contains("PREDICATES: `join2`.`id` > 1"));
}
@Test
public void testConvertCaseWhenToConstant() throws Exception {
// basic test
String caseWhenSql = "select "
+ "case when date_format(now(),'%H%i') < 123 then 1 else 0 end as col "
+ "from test.test1 "
+ "where time = case when date_format(now(),'%H%i') < 123 then date_format(date_sub(now(),2),'%Y%m%d') else date_format(date_sub(now(),1),'%Y%m%d') end";
Assert.assertTrue(!StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + caseWhenSql), "CASE WHEN"));
// test 1: case when then
// 1.1 multi when in on `case when` and can be converted to constants
String sql11 = "select case when false then 2 when true then 3 else 0 end as col11;";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql11), "constant exprs: \n 3"));
// 1.2 multi `when expr` in on `case when` ,`when expr` can not be converted to constants
String sql121 = "select case when false then 2 when substr(k7,2,1) then 3 else 0 end as col121 from test.baseall";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql121),
"OUTPUT EXPRS:CASE WHEN substr(`k7`, 2, 1) THEN 3 ELSE 0 END"));
// 1.2.2 when expr which can not be converted to constants in the first
String sql122 = "select case when substr(k7,2,1) then 2 when false then 3 else 0 end as col122 from test.baseall";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql122),
"OUTPUT EXPRS:CASE WHEN substr(`k7`, 2, 1) THEN 2 WHEN FALSE THEN 3 ELSE 0 END"));
// 1.2.3 test return `then expr` in the middle
String sql124 = "select case when false then 1 when true then 2 when false then 3 else 'other' end as col124";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql124), "constant exprs: \n '2'"));
// 1.3 test return null
String sql3 = "select case when false then 2 end as col3";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql3), "constant exprs: \n NULL"));
// 1.3.1 test return else expr
String sql131 = "select case when false then 2 when false then 3 else 4 end as col131";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql131), "constant exprs: \n 4"));
// 1.4 nest `case when` and can be converted to constants
String sql14 = "select case when (case when true then true else false end) then 2 when false then 3 else 0 end as col";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql14), "constant exprs: \n 2"));
// 1.5 nest `case when` and can not be converted to constants
String sql15 = "select case when case when substr(k7,2,1) then true else false end then 2 when false then 3 else 0 end as col from test.baseall";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql15),
"OUTPUT EXPRS:CASE WHEN CASE WHEN substr(`k7`, 2, 1) THEN TRUE ELSE FALSE END THEN 2 WHEN FALSE THEN 3 ELSE 0 END"));
// 1.6 test when expr is null
String sql16 = "select case when null then 1 else 2 end as col16;";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql16), "constant exprs: \n 2"));
// test 2: case xxx when then
// 2.1 test equal
String sql2 = "select case 1 when 1 then 'a' when 2 then 'b' else 'other' end as col2;";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql2), "constant exprs: \n 'a'"));
// 2.1.2 test not equal
String sql212 = "select case 'a' when 1 then 'a' when 'a' then 'b' else 'other' end as col212;";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql212), "constant exprs: \n 'b'"));
// 2.2 test return null
String sql22 = "select case 'a' when 1 then 'a' when 'b' then 'b' end as col22;";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql22), "constant exprs: \n NULL"));
// 2.2.2 test return else
String sql222 = "select case 1 when 2 then 'a' when 3 then 'b' else 'other' end as col222;";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql222), "constant exprs: \n 'other'"));
// 2.3 test can not convert to constant,middle when expr is not constant
String sql23 = "select case 'a' when 'b' then 'a' when substr(k7,2,1) then 2 when false then 3 else 0 end as col23 from test.baseall";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql23),
"OUTPUT EXPRS:CASE'a' WHEN substr(`k7`, 2, 1) THEN '2' WHEN '0' THEN '3' ELSE '0' END"));
// 2.3.1 first when expr is not constant
String sql231 = "select case 'a' when substr(k7,2,1) then 2 when 1 then 'a' when false then 3 else 0 end as col231 from test.baseall";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql231),
"OUTPUT EXPRS:CASE'a' WHEN substr(`k7`, 2, 1) THEN '2' WHEN '1' THEN 'a' WHEN '0' THEN '3' ELSE '0' END"));
// 2.3.2 case expr is not constant
String sql232 = "select case k1 when substr(k7,2,1) then 2 when 1 then 'a' when false then 3 else 0 end as col232 from test.baseall";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql232),
"OUTPUT EXPRS:CASE`k1` WHEN substr(`k7`, 2, 1) THEN '2' WHEN '1' THEN 'a' WHEN '0' THEN '3' ELSE '0' END"));
// 3.1 test float,float in case expr
String sql31 = "select case cast(100 as float) when 1 then 'a' when 2 then 'b' else 'other' end as col31;";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql31),
"constant exprs: \n CASE100.0 WHEN 1.0 THEN 'a' WHEN 2.0 THEN 'b' ELSE 'other' END"));
// 4.1 test null in case expr return else
String sql41 = "select case null when 1 then 'a' when 2 then 'b' else 'other' end as col41";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql41), "constant exprs: \n 'other'"));
// 4.1.2 test null in case expr return null
String sql412 = "select case null when 1 then 'a' when 2 then 'b' end as col41";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql412), "constant exprs: \n NULL"));
// 4.2.1 test null in when expr
String sql421 = "select case 'a' when null then 'a' else 'other' end as col421";
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql421), "constant exprs: \n 'other'"));
}
@Test
public void testJoinPredicateTransitivityWithSubqueryInWhereClause() throws Exception {
connectContext.setDatabase("default_cluster:test");
String sql = "SELECT *\n" +
String sql = "SELECT *\n" +
"FROM test.pushdown_test\n" +
"WHERE 0 < (\n" +
" SELECT MAX(k9)\n" +
" SELECT MAX(k9)\n" +
" FROM test.pushdown_test);";
String explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql);
Assert.assertTrue(explainString.contains("PLAN FRAGMENT"));
Assert.assertTrue(explainString.contains("CROSS JOIN"));
Assert.assertTrue(!explainString.contains("PREDICATES"));
}
}