[SQL] Support subquery in case when statement (#3135)

#3153
implement subquery support for  sub query in case when statement like
```
SELECT CASE
        WHEN (
            SELECT COUNT(*) / 2
            FROM t
        ) > k4 THEN (
            SELECT AVG(k4)
            FROM t
        )
        ELSE (
            SELECT SUM(k4)
            FROM t
        )
    END AS kk4
FROM t;
```

this statement will be rewrite to 
```
SELECT CASE
        WHEN t1.a > k4 THEN t2.a
        ELSE t3.a
    END AS kk4
FROM t, (
        SELECT COUNT(*) / 2 AS a
        FROM t
    ) t1,  (
        SELECT AVG(k4) AS a
        FROM t
    ) t2,  (
        SELECT SUM(k4) AS a
        FROM t
    ) t3;
```
This commit is contained in:
yangzhg
2020-03-25 17:12:54 +08:00
committed by GitHub
parent b2518fc285
commit 71bc815b20
4 changed files with 140 additions and 58 deletions

View File

@ -1,43 +1,44 @@
package org.apache.doris.analysis;
import java.io.File;
import java.util.UUID;
import org.apache.doris.catalog.Catalog;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.rewrite.ExprRewriter;
import org.apache.doris.utframe.DorisAssert;
import org.apache.doris.utframe.UtFrameUtils;
import org.apache.commons.io.FileUtils;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import java.util.UUID;
public class SelectStmtTest {
private static String runningDir = "fe/mocked/DemoTest/" + UUID.randomUUID().toString() + "/";
private static DorisAssert dorisAssert;
@Rule
public ExpectedException expectedEx = ExpectedException.none();
@AfterClass
public static void afterClass() throws Exception {
FileUtils.deleteDirectory(new File(runningDir));
public static void tearDown() throws Exception {
UtFrameUtils.cleanDorisFeDir(runningDir);
}
@BeforeClass
public static void setUp() throws Exception {
UtFrameUtils.createMinDorisCluster(runningDir);
String createTblStmtStr = "create table db1.tbl1(k1 varchar(32), k2 varchar(32), k3 varchar(32), k4 int) "
+ "AGGREGATE KEY(k1, k2,k3,k4) distributed by hash(k1) buckets 3 properties('replication_num' = '1');";
dorisAssert = new DorisAssert();
dorisAssert.withDatabase("db1").useDatabase("db1");
dorisAssert.withTable(createTblStmtStr);
}
@Test
public void testGroupingSets() throws Exception {
ConnectContext ctx = UtFrameUtils.createDefaultCtx();
UtFrameUtils.createMinDorisCluster(runningDir);
String createDbStmtStr = "create database db1;";
CreateDbStmt createDbStmt = (CreateDbStmt) UtFrameUtils.parseAndAnalyzeStmt(createDbStmtStr, ctx);
Catalog.getCurrentCatalog().createDb(createDbStmt);
System.out.println(Catalog.getCurrentCatalog().getDbNames());
// 3. create table tbl1
String createTblStmtStr = "create table db1.tbl1(k1 varchar(32), k2 varchar(32), k3 varchar(32), k4 int) "
+ "AGGREGATE KEY(k1, k2,k3,k4) distributed by hash(k1) buckets 3 properties('replication_num' = '1');";
CreateTableStmt createTableStmt = (CreateTableStmt) UtFrameUtils.parseAndAnalyzeStmt(createTblStmtStr, ctx);
Catalog.getCurrentCatalog().createTable(createTableStmt);
String selectStmtStr = "select k1,k2,MAX(k4) from db1.tbl1 GROUP BY GROUPING sets ((k1,k2),(k1),(k2),());";
UtFrameUtils.parseAndAnalyzeStmt(selectStmtStr, ctx);
String selectStmtStr2 = "select k1,k4,MAX(k4) from db1.tbl1 GROUP BY GROUPING sets ((k1,k4),(k1),(k4),());";
@ -52,5 +53,29 @@ public class SelectStmtTest {
UtFrameUtils.parseAndAnalyzeStmt(selectStmtStr4, ctx);
}
}
@Test
public void testSubqueryInCase() throws Exception {
ConnectContext ctx = UtFrameUtils.createDefaultCtx();
String sql1 = "SELECT CASE\n" +
" WHEN (\n" +
" SELECT COUNT(*) / 2\n" +
" FROM db1.tbl1\n" +
" ) > k4 THEN (\n" +
" SELECT AVG(k4)\n" +
" FROM db1.tbl1\n" +
" )\n" +
" ELSE (\n" +
" SELECT SUM(k4)\n" +
" FROM db1.tbl1\n" +
" )\n" +
" END AS kk4\n" +
"FROM db1.tbl1;";
SelectStmt stmt = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql1, ctx);
stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
Assert.assertEquals("SELECT CASE WHEN `$a$1`.`$c$1` > `k4` THEN `$a$2`.`$c$2` ELSE `$a$3`.`$c$3` END" +
" AS `kk4` FROM `default_cluster:db1`.`tbl1` (SELECT count(*) / 2.0 AS `count(*) / 2.0` FROM " +
"`default_cluster:db1`.`tbl1`) $a$1 (SELECT avg(`k4`) AS `avg(``k4``)` FROM" +
" `default_cluster:db1`.`tbl1`) $a$2 (SELECT sum(`k4`) AS `sum(``k4``)` " +
"FROM `default_cluster:db1`.`tbl1`) $a$3", stmt.toSql());
}
}