[SQL] Support subquery in case when statement (#3135)

#3153
implement subquery support for  sub query in case when statement like
```
SELECT CASE
        WHEN (
            SELECT COUNT(*) / 2
            FROM t
        ) > k4 THEN (
            SELECT AVG(k4)
            FROM t
        )
        ELSE (
            SELECT SUM(k4)
            FROM t
        )
    END AS kk4
FROM t;
```

this statement will be rewrite to 
```
SELECT CASE
        WHEN t1.a > k4 THEN t2.a
        ELSE t3.a
    END AS kk4
FROM t, (
        SELECT COUNT(*) / 2 AS a
        FROM t
    ) t1,  (
        SELECT AVG(k4) AS a
        FROM t
    ) t2,  (
        SELECT SUM(k4) AS a
        FROM t
    ) t3;
```
This commit is contained in:
yangzhg
2020-03-25 17:12:54 +08:00
committed by GitHub
parent b2518fc285
commit 71bc815b20
4 changed files with 140 additions and 58 deletions

View File

@ -132,9 +132,9 @@ public class FromClause implements ParseNode, Iterable<TableRef> {
public String toSql() {
StringBuilder builder = new StringBuilder();
if (!tableRefs_.isEmpty()) {
builder.append(" FROM ");
builder.append(" FROM");
for (int i = 0; i < tableRefs_.size(); ++i) {
builder.append(tableRefs_.get(i).toSql());
builder.append(" " + tableRefs_.get(i).toSql());
}
}
return builder.toString();

View File

@ -20,6 +20,7 @@ package org.apache.doris.analysis;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.rewrite.ExprRewriter;
import com.google.common.base.Predicates;
import com.google.common.collect.Lists;
import java.util.List;
@ -83,7 +84,17 @@ class SelectList {
public void rewriteExprs(ExprRewriter rewriter, Analyzer analyzer)
throws AnalysisException {
for (SelectListItem item : items) {
if (item.isStar()) continue;
if (item.isStar()) {
continue;
}
// rewrite subquery in select list
if (item.getExpr().contains(Predicates.instanceOf(Subquery.class))) {
List<Subquery> subqueryExprs = Lists.newArrayList();
item.getExpr().collect(Subquery.class, subqueryExprs);
for (Subquery s : subqueryExprs) {
s.getStatement().rewriteExprs(rewriter);
}
}
item.setExpr(rewriter.rewrite(item.getExpr(), analyzer));
}
}

View File

@ -52,6 +52,7 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
@ -362,7 +363,8 @@ public class SelectStmt extends QueryStmt {
// Analyze the resultExpr before generating a label to ensure enforcement
// of expr child and depth limits (toColumn() label may call toSql()).
item.getExpr().analyze(analyzer);
if (item.getExpr().contains(Predicates.instanceOf(Subquery.class))) {
if (!(item.getExpr() instanceof CaseExpr) &&
item.getExpr().contains(Predicates.instanceOf(Subquery.class))) {
throw new AnalysisException("Subquery is not supported in the select list.");
}
Expr expr = rewriteCountDistinctForBitmapOrHLL(item.getExpr(), analyzer);
@ -749,39 +751,6 @@ public class SelectStmt extends QueryStmt {
}
}
/**
* This select block might contain inline views.
* Substitute all exprs (result of the analysis) of this select block referencing any
* of our inlined views, including everything registered with the analyzer.
* Expressions created during parsing (such as whereClause) are not touched.
*
* @throws AnalysisException
*/
public void seondSubstituteInlineViewExprs(ExprSubstitutionMap sMap) throws AnalysisException {
// we might not have anything to substitute
if (sMap.size() == 0) {
return;
}
// select
// Expr.substituteList(resultExprs, sMap);
// aggregation (group by and aggregation expr)
if (aggInfo != null) {
aggInfo.substitute(sMap, analyzer);
}
// having
if (havingPred != null) {
havingPred.substitute(sMap);
}
// ordering
//if (sortInfo != null) {
// sortInfo.substitute(sMap);
//}
}
/**
* Expand "*" select list item.
*/
@ -1255,7 +1224,7 @@ public class SelectStmt extends QueryStmt {
@Override
public void rewriteExprs(ExprRewriter rewriter) throws AnalysisException {
Preconditions.checkState(isAnalyzed());
selectList.rewriteExprs(rewriter, analyzer);
rewriteSelectList(rewriter);
for (TableRef ref : fromClause_) {
ref.rewriteExprs(rewriter, analyzer);
}
@ -1284,6 +1253,83 @@ public class SelectStmt extends QueryStmt {
}
}
private void rewriteSelectList(ExprRewriter rewriter) throws AnalysisException {
for (SelectListItem item : selectList.getItems()) {
if (!(item.getExpr() instanceof CaseExpr)) {
continue;
}
if (!item.getExpr().contains(Predicates.instanceOf(Subquery.class))) {
continue;
}
item.setExpr(rewriteSubquery(item.getExpr(), analyzer));
}
selectList.rewriteExprs(rewriter, analyzer);
}
/** rewrite subquery in case when to an inline view
* subquery in case when statement like
*
* SELECT CASE
* WHEN (
* SELECT COUNT(*) / 2
* FROM t
* ) > k4 THEN (
* SELECT AVG(k4)
* FROM t
* )
* ELSE (
* SELECT SUM(k4)
* FROM t
* )
* END AS kk4
* FROM t;
* this statement will be rewrite to
*
* SELECT CASE
* WHEN t1.a > k4 THEN t2.a
* ELSE t3.a
* END AS kk4
* FROM t, (
* SELECT COUNT(*) / 2 AS a
* FROM t
* ) t1, (
* SELECT AVG(k4) AS a
* FROM t
* ) t2, (
* SELECT SUM(k4) AS a
* FROM t
* ) t3;
*/
private Expr rewriteSubquery(Expr expr, Analyzer analyzer)
throws AnalysisException {
if (expr instanceof Subquery) {
if (!(((Subquery) expr).getStatement() instanceof SelectStmt)) {
throw new AnalysisException("Only support select subquery in case statement.");
}
SelectStmt subquery = (SelectStmt) ((Subquery) expr).getStatement();
if (subquery.resultExprs.size() != 1) {
throw new AnalysisException("Only support select subquery produce one column in case statement.");
}
subquery.reset();
String alias = getTableAliasGenerator().getNextAlias();
String colAlias = getColumnAliasGenerator().getNextAlias();
InlineViewRef inlineViewRef = new InlineViewRef(alias, subquery, Arrays.asList(colAlias));
try {
inlineViewRef.analyze(analyzer);
} catch (UserException e) {
throw new AnalysisException(e.getMessage());
}
fromClause_.add(inlineViewRef);
expr = new SlotRef(inlineViewRef.getAliasAsName(), colAlias);
} else if (CollectionUtils.isNotEmpty(expr.getChildren())) {
for (int i = 0; i < expr.getChildren().size(); ++i) {
expr.setChild(i, rewriteSubquery(expr.getChild(i), analyzer));
}
}
return expr;
}
@Override
public String toSql() {
if (sqlString_ != null) {

View File

@ -1,43 +1,44 @@
package org.apache.doris.analysis;
import java.io.File;
import java.util.UUID;
import org.apache.doris.catalog.Catalog;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.rewrite.ExprRewriter;
import org.apache.doris.utframe.DorisAssert;
import org.apache.doris.utframe.UtFrameUtils;
import org.apache.commons.io.FileUtils;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import java.util.UUID;
public class SelectStmtTest {
private static String runningDir = "fe/mocked/DemoTest/" + UUID.randomUUID().toString() + "/";
private static DorisAssert dorisAssert;
@Rule
public ExpectedException expectedEx = ExpectedException.none();
@AfterClass
public static void afterClass() throws Exception {
FileUtils.deleteDirectory(new File(runningDir));
public static void tearDown() throws Exception {
UtFrameUtils.cleanDorisFeDir(runningDir);
}
@BeforeClass
public static void setUp() throws Exception {
UtFrameUtils.createMinDorisCluster(runningDir);
String createTblStmtStr = "create table db1.tbl1(k1 varchar(32), k2 varchar(32), k3 varchar(32), k4 int) "
+ "AGGREGATE KEY(k1, k2,k3,k4) distributed by hash(k1) buckets 3 properties('replication_num' = '1');";
dorisAssert = new DorisAssert();
dorisAssert.withDatabase("db1").useDatabase("db1");
dorisAssert.withTable(createTblStmtStr);
}
@Test
public void testGroupingSets() throws Exception {
ConnectContext ctx = UtFrameUtils.createDefaultCtx();
UtFrameUtils.createMinDorisCluster(runningDir);
String createDbStmtStr = "create database db1;";
CreateDbStmt createDbStmt = (CreateDbStmt) UtFrameUtils.parseAndAnalyzeStmt(createDbStmtStr, ctx);
Catalog.getCurrentCatalog().createDb(createDbStmt);
System.out.println(Catalog.getCurrentCatalog().getDbNames());
// 3. create table tbl1
String createTblStmtStr = "create table db1.tbl1(k1 varchar(32), k2 varchar(32), k3 varchar(32), k4 int) "
+ "AGGREGATE KEY(k1, k2,k3,k4) distributed by hash(k1) buckets 3 properties('replication_num' = '1');";
CreateTableStmt createTableStmt = (CreateTableStmt) UtFrameUtils.parseAndAnalyzeStmt(createTblStmtStr, ctx);
Catalog.getCurrentCatalog().createTable(createTableStmt);
String selectStmtStr = "select k1,k2,MAX(k4) from db1.tbl1 GROUP BY GROUPING sets ((k1,k2),(k1),(k2),());";
UtFrameUtils.parseAndAnalyzeStmt(selectStmtStr, ctx);
String selectStmtStr2 = "select k1,k4,MAX(k4) from db1.tbl1 GROUP BY GROUPING sets ((k1,k4),(k1),(k4),());";
@ -52,5 +53,29 @@ public class SelectStmtTest {
UtFrameUtils.parseAndAnalyzeStmt(selectStmtStr4, ctx);
}
}
@Test
public void testSubqueryInCase() throws Exception {
ConnectContext ctx = UtFrameUtils.createDefaultCtx();
String sql1 = "SELECT CASE\n" +
" WHEN (\n" +
" SELECT COUNT(*) / 2\n" +
" FROM db1.tbl1\n" +
" ) > k4 THEN (\n" +
" SELECT AVG(k4)\n" +
" FROM db1.tbl1\n" +
" )\n" +
" ELSE (\n" +
" SELECT SUM(k4)\n" +
" FROM db1.tbl1\n" +
" )\n" +
" END AS kk4\n" +
"FROM db1.tbl1;";
SelectStmt stmt = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql1, ctx);
stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
Assert.assertEquals("SELECT CASE WHEN `$a$1`.`$c$1` > `k4` THEN `$a$2`.`$c$2` ELSE `$a$3`.`$c$3` END" +
" AS `kk4` FROM `default_cluster:db1`.`tbl1` (SELECT count(*) / 2.0 AS `count(*) / 2.0` FROM " +
"`default_cluster:db1`.`tbl1`) $a$1 (SELECT avg(`k4`) AS `avg(``k4``)` FROM" +
" `default_cluster:db1`.`tbl1`) $a$2 (SELECT sum(`k4`) AS `sum(``k4``)` " +
"FROM `default_cluster:db1`.`tbl1`) $a$3", stmt.toSql());
}
}