[SQL Function] Calculate 'case when expr' when possible (#3396)
Calculate 'case when expr' when possible
This commit is contained in:
@ -26,6 +26,7 @@ import org.apache.doris.thrift.TExprNodeType;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
@ -101,6 +102,7 @@ public class CaseExpr extends Expr {
|
||||
CaseExpr expr = (CaseExpr) obj;
|
||||
return hasCaseExpr == expr.hasCaseExpr && hasElseExpr == expr.hasElseExpr;
|
||||
}
|
||||
|
||||
public boolean hasCaseExpr() {
|
||||
return hasCaseExpr;
|
||||
}
|
||||
@ -251,4 +253,84 @@ public class CaseExpr extends Expr {
|
||||
}
|
||||
return exprs;
|
||||
}
|
||||
|
||||
// this method just compare literal value and not completely consistent with be,for two cases
|
||||
// 1 not deal float
|
||||
// 2 just compare literal value with same type. for a example sql 'select case when 123 then '1' else '2' end as col'
|
||||
// for be will return '1', because be only regard 0 as false
|
||||
// but for current LiteralExpr.compareLiteral, `123`' won't be regard as true
|
||||
// the case which two values has different type left to be
|
||||
public static Expr computeCaseExpr(CaseExpr expr) {
|
||||
LiteralExpr caseExpr;
|
||||
int startIndex = 0;
|
||||
int endIndex = expr.getChildren().size();
|
||||
if (expr.hasCaseExpr()) {
|
||||
// just deal literal here
|
||||
// and avoid `float compute` in java,float should be dealt in be
|
||||
Expr caseChildExpr = expr.getChild(0);
|
||||
if (!caseChildExpr.isLiteral()
|
||||
|| caseChildExpr instanceof DecimalLiteral || caseChildExpr instanceof FloatLiteral) {
|
||||
return expr;
|
||||
}
|
||||
caseExpr = (LiteralExpr) expr.getChild(0);
|
||||
startIndex++;
|
||||
} else {
|
||||
caseExpr = new BoolLiteral(true);
|
||||
}
|
||||
|
||||
if (caseExpr instanceof NullLiteral) {
|
||||
if (expr.hasElseExpr) {
|
||||
return expr.getChild(expr.getChildren().size() - 1);
|
||||
} else {
|
||||
return new NullLiteral();
|
||||
}
|
||||
}
|
||||
|
||||
if (expr.hasElseExpr) {
|
||||
endIndex--;
|
||||
}
|
||||
|
||||
// early return when the `when expr` can't be converted to constants
|
||||
Expr startExpr = expr.getChild(startIndex);
|
||||
if ((!startExpr.isLiteral() || startExpr instanceof DecimalLiteral || startExpr instanceof FloatLiteral)
|
||||
|| (!(startExpr instanceof NullLiteral) && !startExpr.getClass().toString().equals(caseExpr.getClass().toString()))) {
|
||||
return expr;
|
||||
}
|
||||
|
||||
for (int i = startIndex; i < endIndex; i = i + 2) {
|
||||
Expr currentWhenExpr = expr.getChild(i);
|
||||
// skip null literal
|
||||
if (currentWhenExpr instanceof NullLiteral) {
|
||||
continue;
|
||||
}
|
||||
// stop convert in three cases
|
||||
// 1 not literal
|
||||
// 2 float
|
||||
// 3 `case expr` and `when expr` don't have same type
|
||||
if ((!currentWhenExpr.isLiteral() || currentWhenExpr instanceof DecimalLiteral || currentWhenExpr instanceof FloatLiteral)
|
||||
|| !currentWhenExpr.getClass().toString().equals(caseExpr.getClass().toString())) {
|
||||
// remove the expr which has been evaluated
|
||||
List<Expr> exprLeft = new ArrayList<>();
|
||||
if (expr.hasCaseExpr()) {
|
||||
exprLeft.add(caseExpr);
|
||||
}
|
||||
for (int j = i; j < expr.getChildren().size(); j++) {
|
||||
exprLeft.add(expr.getChild(j));
|
||||
}
|
||||
Expr retCaseExpr = expr.clone();
|
||||
retCaseExpr.getChildren().clear();
|
||||
retCaseExpr.addChildren(exprLeft);
|
||||
return retCaseExpr;
|
||||
} else if (caseExpr.compareLiteral((LiteralExpr) currentWhenExpr) == 0) {
|
||||
return expr.getChild(i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (expr.hasElseExpr) {
|
||||
return expr.getChild(expr.getChildren().size() - 1);
|
||||
} else {
|
||||
return new NullLiteral();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ package org.apache.doris.rewrite;
|
||||
|
||||
|
||||
import org.apache.doris.analysis.Analyzer;
|
||||
import org.apache.doris.analysis.CaseExpr;
|
||||
import org.apache.doris.analysis.CastExpr;
|
||||
import org.apache.doris.analysis.Expr;
|
||||
import org.apache.doris.analysis.NullLiteral;
|
||||
@ -48,6 +49,11 @@ public class FoldConstantsRule implements ExprRewriteRule {
|
||||
|
||||
@Override
|
||||
public Expr apply(Expr expr, Analyzer analyzer) throws AnalysisException {
|
||||
// evaluate `case when expr` when possible
|
||||
if (expr instanceof CaseExpr) {
|
||||
return CaseExpr.computeCaseExpr((CaseExpr) expr);
|
||||
}
|
||||
|
||||
// Avoid calling Expr.isConstant() because that would lead to repeated traversals
|
||||
// of the Expr tree. Assumes the bottom-up application of this rule. Constant
|
||||
// children should have been folded at this point.
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
|
||||
package org.apache.doris.planner;
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.doris.analysis.CreateDbStmt;
|
||||
import org.apache.doris.analysis.CreateTableStmt;
|
||||
import org.apache.doris.analysis.DropDbStmt;
|
||||
@ -242,31 +243,31 @@ public class QueryPlanTest {
|
||||
"PROPERTIES (\n" +
|
||||
" \"replication_num\" = \"1\"\n" +
|
||||
");");
|
||||
|
||||
|
||||
createTable("CREATE TABLE test.`pushdown_test` (\n" +
|
||||
" `k1` tinyint(4) NULL COMMENT \"\",\n" +
|
||||
" `k2` smallint(6) NULL COMMENT \"\",\n" +
|
||||
" `k3` int(11) NULL COMMENT \"\",\n" +
|
||||
" `k4` bigint(20) NULL COMMENT \"\",\n" +
|
||||
" `k5` decimal(9, 3) NULL COMMENT \"\",\n" +
|
||||
" `k6` char(5) NULL COMMENT \"\",\n" +
|
||||
" `k10` date NULL COMMENT \"\",\n" +
|
||||
" `k11` datetime NULL COMMENT \"\",\n" +
|
||||
" `k7` varchar(20) NULL COMMENT \"\",\n" +
|
||||
" `k8` double MAX NULL COMMENT \"\",\n" +
|
||||
" `k9` float SUM NULL COMMENT \"\"\n" +
|
||||
") ENGINE=OLAP\n" +
|
||||
"AGGREGATE KEY(`k1`, `k2`, `k3`, `k4`, `k5`, `k6`, `k10`, `k11`, `k7`)\n" +
|
||||
"COMMENT \"OLAP\"\n" +
|
||||
"PARTITION BY RANGE(`k1`)\n" +
|
||||
"(PARTITION p1 VALUES [(\"-128\"), (\"-64\")),\n" +
|
||||
"PARTITION p2 VALUES [(\"-64\"), (\"0\")),\n" +
|
||||
"PARTITION p3 VALUES [(\"0\"), (\"64\")))\n" +
|
||||
"DISTRIBUTED BY HASH(`k1`) BUCKETS 5\n" +
|
||||
"PROPERTIES (\n" +
|
||||
"\"replication_num\" = \"1\",\n" +
|
||||
"\"in_memory\" = \"false\",\n" +
|
||||
"\"storage_format\" = \"DEFAULT\"\n" +
|
||||
" `k1` tinyint(4) NULL COMMENT \"\",\n" +
|
||||
" `k2` smallint(6) NULL COMMENT \"\",\n" +
|
||||
" `k3` int(11) NULL COMMENT \"\",\n" +
|
||||
" `k4` bigint(20) NULL COMMENT \"\",\n" +
|
||||
" `k5` decimal(9, 3) NULL COMMENT \"\",\n" +
|
||||
" `k6` char(5) NULL COMMENT \"\",\n" +
|
||||
" `k10` date NULL COMMENT \"\",\n" +
|
||||
" `k11` datetime NULL COMMENT \"\",\n" +
|
||||
" `k7` varchar(20) NULL COMMENT \"\",\n" +
|
||||
" `k8` double MAX NULL COMMENT \"\",\n" +
|
||||
" `k9` float SUM NULL COMMENT \"\"\n" +
|
||||
") ENGINE=OLAP\n" +
|
||||
"AGGREGATE KEY(`k1`, `k2`, `k3`, `k4`, `k5`, `k6`, `k10`, `k11`, `k7`)\n" +
|
||||
"COMMENT \"OLAP\"\n" +
|
||||
"PARTITION BY RANGE(`k1`)\n" +
|
||||
"(PARTITION p1 VALUES [(\"-128\"), (\"-64\")),\n" +
|
||||
"PARTITION p2 VALUES [(\"-64\"), (\"0\")),\n" +
|
||||
"PARTITION p3 VALUES [(\"0\"), (\"64\")))\n" +
|
||||
"DISTRIBUTED BY HASH(`k1`) BUCKETS 5\n" +
|
||||
"PROPERTIES (\n" +
|
||||
"\"replication_num\" = \"1\",\n" +
|
||||
"\"in_memory\" = \"false\",\n" +
|
||||
"\"storage_format\" = \"DEFAULT\"\n" +
|
||||
");");
|
||||
}
|
||||
|
||||
@ -711,18 +712,119 @@ public class QueryPlanTest {
|
||||
Assert.assertTrue(explainString.contains("PREDICATES: `join1`.`id` > 1"));
|
||||
Assert.assertFalse(explainString.contains("PREDICATES: `join2`.`id` > 1"));
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testConvertCaseWhenToConstant() throws Exception {
|
||||
// basic test
|
||||
String caseWhenSql = "select "
|
||||
+ "case when date_format(now(),'%H%i') < 123 then 1 else 0 end as col "
|
||||
+ "from test.test1 "
|
||||
+ "where time = case when date_format(now(),'%H%i') < 123 then date_format(date_sub(now(),2),'%Y%m%d') else date_format(date_sub(now(),1),'%Y%m%d') end";
|
||||
Assert.assertTrue(!StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + caseWhenSql), "CASE WHEN"));
|
||||
|
||||
// test 1: case when then
|
||||
// 1.1 multi when in on `case when` and can be converted to constants
|
||||
String sql11 = "select case when false then 2 when true then 3 else 0 end as col11;";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql11), "constant exprs: \n 3"));
|
||||
|
||||
// 1.2 multi `when expr` in on `case when` ,`when expr` can not be converted to constants
|
||||
String sql121 = "select case when false then 2 when substr(k7,2,1) then 3 else 0 end as col121 from test.baseall";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql121),
|
||||
"OUTPUT EXPRS:CASE WHEN substr(`k7`, 2, 1) THEN 3 ELSE 0 END"));
|
||||
|
||||
// 1.2.2 when expr which can not be converted to constants in the first
|
||||
String sql122 = "select case when substr(k7,2,1) then 2 when false then 3 else 0 end as col122 from test.baseall";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql122),
|
||||
"OUTPUT EXPRS:CASE WHEN substr(`k7`, 2, 1) THEN 2 WHEN FALSE THEN 3 ELSE 0 END"));
|
||||
|
||||
// 1.2.3 test return `then expr` in the middle
|
||||
String sql124 = "select case when false then 1 when true then 2 when false then 3 else 'other' end as col124";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql124), "constant exprs: \n '2'"));
|
||||
|
||||
// 1.3 test return null
|
||||
String sql3 = "select case when false then 2 end as col3";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql3), "constant exprs: \n NULL"));
|
||||
|
||||
// 1.3.1 test return else expr
|
||||
String sql131 = "select case when false then 2 when false then 3 else 4 end as col131";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql131), "constant exprs: \n 4"));
|
||||
|
||||
// 1.4 nest `case when` and can be converted to constants
|
||||
String sql14 = "select case when (case when true then true else false end) then 2 when false then 3 else 0 end as col";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql14), "constant exprs: \n 2"));
|
||||
|
||||
// 1.5 nest `case when` and can not be converted to constants
|
||||
String sql15 = "select case when case when substr(k7,2,1) then true else false end then 2 when false then 3 else 0 end as col from test.baseall";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql15),
|
||||
"OUTPUT EXPRS:CASE WHEN CASE WHEN substr(`k7`, 2, 1) THEN TRUE ELSE FALSE END THEN 2 WHEN FALSE THEN 3 ELSE 0 END"));
|
||||
|
||||
// 1.6 test when expr is null
|
||||
String sql16 = "select case when null then 1 else 2 end as col16;";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql16), "constant exprs: \n 2"));
|
||||
|
||||
// test 2: case xxx when then
|
||||
// 2.1 test equal
|
||||
String sql2 = "select case 1 when 1 then 'a' when 2 then 'b' else 'other' end as col2;";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql2), "constant exprs: \n 'a'"));
|
||||
|
||||
// 2.1.2 test not equal
|
||||
String sql212 = "select case 'a' when 1 then 'a' when 'a' then 'b' else 'other' end as col212;";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql212), "constant exprs: \n 'b'"));
|
||||
|
||||
// 2.2 test return null
|
||||
String sql22 = "select case 'a' when 1 then 'a' when 'b' then 'b' end as col22;";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql22), "constant exprs: \n NULL"));
|
||||
|
||||
// 2.2.2 test return else
|
||||
String sql222 = "select case 1 when 2 then 'a' when 3 then 'b' else 'other' end as col222;";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql222), "constant exprs: \n 'other'"));
|
||||
|
||||
// 2.3 test can not convert to constant,middle when expr is not constant
|
||||
String sql23 = "select case 'a' when 'b' then 'a' when substr(k7,2,1) then 2 when false then 3 else 0 end as col23 from test.baseall";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql23),
|
||||
"OUTPUT EXPRS:CASE'a' WHEN substr(`k7`, 2, 1) THEN '2' WHEN '0' THEN '3' ELSE '0' END"));
|
||||
|
||||
// 2.3.1 first when expr is not constant
|
||||
String sql231 = "select case 'a' when substr(k7,2,1) then 2 when 1 then 'a' when false then 3 else 0 end as col231 from test.baseall";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql231),
|
||||
"OUTPUT EXPRS:CASE'a' WHEN substr(`k7`, 2, 1) THEN '2' WHEN '1' THEN 'a' WHEN '0' THEN '3' ELSE '0' END"));
|
||||
|
||||
// 2.3.2 case expr is not constant
|
||||
String sql232 = "select case k1 when substr(k7,2,1) then 2 when 1 then 'a' when false then 3 else 0 end as col232 from test.baseall";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql232),
|
||||
"OUTPUT EXPRS:CASE`k1` WHEN substr(`k7`, 2, 1) THEN '2' WHEN '1' THEN 'a' WHEN '0' THEN '3' ELSE '0' END"));
|
||||
|
||||
// 3.1 test float,float in case expr
|
||||
String sql31 = "select case cast(100 as float) when 1 then 'a' when 2 then 'b' else 'other' end as col31;";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql31),
|
||||
"constant exprs: \n CASE100.0 WHEN 1.0 THEN 'a' WHEN 2.0 THEN 'b' ELSE 'other' END"));
|
||||
|
||||
// 4.1 test null in case expr return else
|
||||
String sql41 = "select case null when 1 then 'a' when 2 then 'b' else 'other' end as col41";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql41), "constant exprs: \n 'other'"));
|
||||
|
||||
// 4.1.2 test null in case expr return null
|
||||
String sql412 = "select case null when 1 then 'a' when 2 then 'b' end as col41";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql412), "constant exprs: \n NULL"));
|
||||
|
||||
// 4.2.1 test null in when expr
|
||||
String sql421 = "select case 'a' when null then 'a' else 'other' end as col421";
|
||||
Assert.assertTrue(StringUtils.containsIgnoreCase(UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql421), "constant exprs: \n 'other'"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testJoinPredicateTransitivityWithSubqueryInWhereClause() throws Exception {
|
||||
connectContext.setDatabase("default_cluster:test");
|
||||
String sql = "SELECT *\n" +
|
||||
String sql = "SELECT *\n" +
|
||||
"FROM test.pushdown_test\n" +
|
||||
"WHERE 0 < (\n" +
|
||||
" SELECT MAX(k9)\n" +
|
||||
" SELECT MAX(k9)\n" +
|
||||
" FROM test.pushdown_test);";
|
||||
String explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql);
|
||||
Assert.assertTrue(explainString.contains("PLAN FRAGMENT"));
|
||||
Assert.assertTrue(explainString.contains("CROSS JOIN"));
|
||||
Assert.assertTrue(!explainString.contains("PREDICATES"));
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user