From 71bc815b20ce589873a6baa0390bb2de36fa0dce Mon Sep 17 00:00:00 2001 From: yangzhg <780531911@qq.com> Date: Wed, 25 Mar 2020 17:12:54 +0800 Subject: [PATCH] [SQL] Support subquery in case when statement (#3135) #3153 Implement support for subqueries in CASE WHEN statements, such as: ``` SELECT CASE WHEN ( SELECT COUNT(*) / 2 FROM t ) > k4 THEN ( SELECT AVG(k4) FROM t ) ELSE ( SELECT SUM(k4) FROM t ) END AS kk4 FROM t; ``` This statement will be rewritten to: ``` SELECT CASE WHEN t1.a > k4 THEN t2.a ELSE t3.a END AS kk4 FROM t, ( SELECT COUNT(*) / 2 AS a FROM t ) t1, ( SELECT AVG(k4) AS a FROM t ) t2, ( SELECT SUM(k4) AS a FROM t ) t3; ``` --- .../org/apache/doris/analysis/FromClause.java | 4 +- .../org/apache/doris/analysis/SelectList.java | 13 +- .../org/apache/doris/analysis/SelectStmt.java | 116 ++++++++++++------ .../apache/doris/analysis/SelectStmtTest.java | 65 +++++++--- 4 files changed, 140 insertions(+), 58 deletions(-) diff --git a/fe/src/main/java/org/apache/doris/analysis/FromClause.java b/fe/src/main/java/org/apache/doris/analysis/FromClause.java index b8018df0d3..1120675580 100644 --- a/fe/src/main/java/org/apache/doris/analysis/FromClause.java +++ b/fe/src/main/java/org/apache/doris/analysis/FromClause.java @@ -132,9 +132,9 @@ public class FromClause implements ParseNode, Iterable { public String toSql() { StringBuilder builder = new StringBuilder(); if (!tableRefs_.isEmpty()) { - builder.append(" FROM "); + builder.append(" FROM"); for (int i = 0; i < tableRefs_.size(); ++i) { - builder.append(tableRefs_.get(i).toSql()); + builder.append(" " + tableRefs_.get(i).toSql()); } } return builder.toString(); diff --git a/fe/src/main/java/org/apache/doris/analysis/SelectList.java b/fe/src/main/java/org/apache/doris/analysis/SelectList.java index 083c136363..d56ed7d171 100644 --- a/fe/src/main/java/org/apache/doris/analysis/SelectList.java +++ b/fe/src/main/java/org/apache/doris/analysis/SelectList.java @@ -20,6 +20,7 @@ package org.apache.doris.analysis; 
import org.apache.doris.common.AnalysisException; import org.apache.doris.rewrite.ExprRewriter; +import com.google.common.base.Predicates; import com.google.common.collect.Lists; import java.util.List; @@ -83,7 +84,17 @@ class SelectList { public void rewriteExprs(ExprRewriter rewriter, Analyzer analyzer) throws AnalysisException { for (SelectListItem item : items) { - if (item.isStar()) continue; + if (item.isStar()) { + continue; + } + // rewrite subquery in select list + if (item.getExpr().contains(Predicates.instanceOf(Subquery.class))) { + List subqueryExprs = Lists.newArrayList(); + item.getExpr().collect(Subquery.class, subqueryExprs); + for (Subquery s : subqueryExprs) { + s.getStatement().rewriteExprs(rewriter); + } + } item.setExpr(rewriter.rewrite(item.getExpr(), analyzer)); } } diff --git a/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java b/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java index 5f8f827b1f..c86b40d8c7 100644 --- a/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java +++ b/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java @@ -52,6 +52,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Comparator; @@ -362,7 +363,8 @@ public class SelectStmt extends QueryStmt { // Analyze the resultExpr before generating a label to ensure enforcement // of expr child and depth limits (toColumn() label may call toSql()). 
item.getExpr().analyze(analyzer); - if (item.getExpr().contains(Predicates.instanceOf(Subquery.class))) { + if (!(item.getExpr() instanceof CaseExpr) && + item.getExpr().contains(Predicates.instanceOf(Subquery.class))) { throw new AnalysisException("Subquery is not supported in the select list."); } Expr expr = rewriteCountDistinctForBitmapOrHLL(item.getExpr(), analyzer); @@ -749,39 +751,6 @@ public class SelectStmt extends QueryStmt { } } - /** - * This select block might contain inline views. - * Substitute all exprs (result of the analysis) of this select block referencing any - * of our inlined views, including everything registered with the analyzer. - * Expressions created during parsing (such as whereClause) are not touched. - * - * @throws AnalysisException - */ - public void seondSubstituteInlineViewExprs(ExprSubstitutionMap sMap) throws AnalysisException { - // we might not have anything to substitute - if (sMap.size() == 0) { - return; - } - - // select - // Expr.substituteList(resultExprs, sMap); - - // aggregation (group by and aggregation expr) - if (aggInfo != null) { - aggInfo.substitute(sMap, analyzer); - } - - // having - if (havingPred != null) { - havingPred.substitute(sMap); - } - - // ordering - //if (sortInfo != null) { - // sortInfo.substitute(sMap); - //} - } - /** * Expand "*" select list item. 
*/ @@ -1255,7 +1224,7 @@ public class SelectStmt extends QueryStmt { @Override public void rewriteExprs(ExprRewriter rewriter) throws AnalysisException { Preconditions.checkState(isAnalyzed()); - selectList.rewriteExprs(rewriter, analyzer); + rewriteSelectList(rewriter); for (TableRef ref : fromClause_) { ref.rewriteExprs(rewriter, analyzer); } @@ -1284,6 +1253,83 @@ public class SelectStmt extends QueryStmt { } } + private void rewriteSelectList(ExprRewriter rewriter) throws AnalysisException { + for (SelectListItem item : selectList.getItems()) { + if (!(item.getExpr() instanceof CaseExpr)) { + continue; + } + if (!item.getExpr().contains(Predicates.instanceOf(Subquery.class))) { + continue; + } + item.setExpr(rewriteSubquery(item.getExpr(), analyzer)); + } + + selectList.rewriteExprs(rewriter, analyzer); + } + + /** rewrite subquery in case when to an inline view + * subquery in case when statement like + * + * SELECT CASE + * WHEN ( + * SELECT COUNT(*) / 2 + * FROM t + * ) > k4 THEN ( + * SELECT AVG(k4) + * FROM t + * ) + * ELSE ( + * SELECT SUM(k4) + * FROM t + * ) + * END AS kk4 + * FROM t; + * this statement will be rewrite to + * + * SELECT CASE + * WHEN t1.a > k4 THEN t2.a + * ELSE t3.a + * END AS kk4 + * FROM t, ( + * SELECT COUNT(*) / 2 AS a + * FROM t + * ) t1, ( + * SELECT AVG(k4) AS a + * FROM t + * ) t2, ( + * SELECT SUM(k4) AS a + * FROM t + * ) t3; + */ + private Expr rewriteSubquery(Expr expr, Analyzer analyzer) + throws AnalysisException { + if (expr instanceof Subquery) { + if (!(((Subquery) expr).getStatement() instanceof SelectStmt)) { + throw new AnalysisException("Only support select subquery in case statement."); + } + SelectStmt subquery = (SelectStmt) ((Subquery) expr).getStatement(); + if (subquery.resultExprs.size() != 1) { + throw new AnalysisException("Only support select subquery produce one column in case statement."); + } + subquery.reset(); + String alias = getTableAliasGenerator().getNextAlias(); + String colAlias = 
getColumnAliasGenerator().getNextAlias(); + InlineViewRef inlineViewRef = new InlineViewRef(alias, subquery, Arrays.asList(colAlias)); + try { + inlineViewRef.analyze(analyzer); + } catch (UserException e) { + throw new AnalysisException(e.getMessage()); + } + fromClause_.add(inlineViewRef); + expr = new SlotRef(inlineViewRef.getAliasAsName(), colAlias); + } else if (CollectionUtils.isNotEmpty(expr.getChildren())) { + for (int i = 0; i < expr.getChildren().size(); ++i) { + expr.setChild(i, rewriteSubquery(expr.getChild(i), analyzer)); + } + } + return expr; + } + @Override public String toSql() { if (sqlString_ != null) { diff --git a/fe/src/test/java/org/apache/doris/analysis/SelectStmtTest.java b/fe/src/test/java/org/apache/doris/analysis/SelectStmtTest.java index 9491d4f81c..d1e2829022 100644 --- a/fe/src/test/java/org/apache/doris/analysis/SelectStmtTest.java +++ b/fe/src/test/java/org/apache/doris/analysis/SelectStmtTest.java @@ -1,43 +1,44 @@ package org.apache.doris.analysis; -import java.io.File; -import java.util.UUID; - -import org.apache.doris.catalog.Catalog; import org.apache.doris.common.AnalysisException; import org.apache.doris.qe.ConnectContext; +import org.apache.doris.rewrite.ExprRewriter; +import org.apache.doris.utframe.DorisAssert; import org.apache.doris.utframe.UtFrameUtils; - -import org.apache.commons.io.FileUtils; import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import java.util.UUID; + public class SelectStmtTest { private static String runningDir = "fe/mocked/DemoTest/" + UUID.randomUUID().toString() + "/"; + private static DorisAssert dorisAssert; @Rule public ExpectedException expectedEx = ExpectedException.none(); @AfterClass - public static void afterClass() throws Exception { - FileUtils.deleteDirectory(new File(runningDir)); + public static void tearDown() throws Exception { + 
UtFrameUtils.cleanDorisFeDir(runningDir); + } + + @BeforeClass + public static void setUp() throws Exception { + UtFrameUtils.createMinDorisCluster(runningDir); + String createTblStmtStr = "create table db1.tbl1(k1 varchar(32), k2 varchar(32), k3 varchar(32), k4 int) " + + "AGGREGATE KEY(k1, k2,k3,k4) distributed by hash(k1) buckets 3 properties('replication_num' = '1');"; + dorisAssert = new DorisAssert(); + dorisAssert.withDatabase("db1").useDatabase("db1"); + dorisAssert.withTable(createTblStmtStr); } @Test public void testGroupingSets() throws Exception { ConnectContext ctx = UtFrameUtils.createDefaultCtx(); - UtFrameUtils.createMinDorisCluster(runningDir); - String createDbStmtStr = "create database db1;"; - CreateDbStmt createDbStmt = (CreateDbStmt) UtFrameUtils.parseAndAnalyzeStmt(createDbStmtStr, ctx); - Catalog.getCurrentCatalog().createDb(createDbStmt); - System.out.println(Catalog.getCurrentCatalog().getDbNames()); - // 3. create table tbl1 - String createTblStmtStr = "create table db1.tbl1(k1 varchar(32), k2 varchar(32), k3 varchar(32), k4 int) " - + "AGGREGATE KEY(k1, k2,k3,k4) distributed by hash(k1) buckets 3 properties('replication_num' = '1');"; - CreateTableStmt createTableStmt = (CreateTableStmt) UtFrameUtils.parseAndAnalyzeStmt(createTblStmtStr, ctx); - Catalog.getCurrentCatalog().createTable(createTableStmt); String selectStmtStr = "select k1,k2,MAX(k4) from db1.tbl1 GROUP BY GROUPING sets ((k1,k2),(k1),(k2),());"; UtFrameUtils.parseAndAnalyzeStmt(selectStmtStr, ctx); String selectStmtStr2 = "select k1,k4,MAX(k4) from db1.tbl1 GROUP BY GROUPING sets ((k1,k4),(k1),(k4),());"; @@ -52,5 +53,29 @@ public class SelectStmtTest { UtFrameUtils.parseAndAnalyzeStmt(selectStmtStr4, ctx); } - -} \ No newline at end of file + @Test + public void testSubqueryInCase() throws Exception { + ConnectContext ctx = UtFrameUtils.createDefaultCtx(); + String sql1 = "SELECT CASE\n" + + " WHEN (\n" + + " SELECT COUNT(*) / 2\n" + + " FROM db1.tbl1\n" + + " ) > k4 THEN 
(\n" + + " SELECT AVG(k4)\n" + + " FROM db1.tbl1\n" + + " )\n" + + " ELSE (\n" + + " SELECT SUM(k4)\n" + + " FROM db1.tbl1\n" + + " )\n" + + " END AS kk4\n" + + "FROM db1.tbl1;"; + SelectStmt stmt = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql1, ctx); + stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter()); + Assert.assertEquals("SELECT CASE WHEN `$a$1`.`$c$1` > `k4` THEN `$a$2`.`$c$2` ELSE `$a$3`.`$c$3` END" + + " AS `kk4` FROM `default_cluster:db1`.`tbl1` (SELECT count(*) / 2.0 AS `count(*) / 2.0` FROM " + + "`default_cluster:db1`.`tbl1`) $a$1 (SELECT avg(`k4`) AS `avg(``k4``)` FROM" + + " `default_cluster:db1`.`tbl1`) $a$2 (SELECT sum(`k4`) AS `sum(``k4``)` " + + "FROM `default_cluster:db1`.`tbl1`) $a$3", stmt.toSql()); + } +}