[SQL] Support subquery in case when statement (#3135)
#3153 implement subquery support for sub query in case when statement like ``` SELECT CASE WHEN ( SELECT COUNT(*) / 2 FROM t ) > k4 THEN ( SELECT AVG(k4) FROM t ) ELSE ( SELECT SUM(k4) FROM t ) END AS kk4 FROM t; ``` this statement will be rewrite to ``` SELECT CASE WHEN t1.a > k4 THEN t2.a ELSE t3.a END AS kk4 FROM t, ( SELECT COUNT(*) / 2 AS a FROM t ) t1, ( SELECT AVG(k4) AS a FROM t ) t2, ( SELECT SUM(k4) AS a FROM t ) t3; ```
This commit is contained in:
@ -1,43 +1,44 @@
|
||||
package org.apache.doris.analysis;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.UUID;
|
||||
|
||||
import org.apache.doris.catalog.Catalog;
|
||||
import org.apache.doris.common.AnalysisException;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.rewrite.ExprRewriter;
|
||||
import org.apache.doris.utframe.DorisAssert;
|
||||
import org.apache.doris.utframe.UtFrameUtils;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.Assert;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
|
||||
import java.util.UUID;
|
||||
|
||||
public class SelectStmtTest {
|
||||
private static String runningDir = "fe/mocked/DemoTest/" + UUID.randomUUID().toString() + "/";
|
||||
private static DorisAssert dorisAssert;
|
||||
|
||||
@Rule
|
||||
public ExpectedException expectedEx = ExpectedException.none();
|
||||
|
||||
@AfterClass
|
||||
public static void afterClass() throws Exception {
|
||||
FileUtils.deleteDirectory(new File(runningDir));
|
||||
public static void tearDown() throws Exception {
|
||||
UtFrameUtils.cleanDorisFeDir(runningDir);
|
||||
}
|
||||
|
||||
@BeforeClass
|
||||
public static void setUp() throws Exception {
|
||||
UtFrameUtils.createMinDorisCluster(runningDir);
|
||||
String createTblStmtStr = "create table db1.tbl1(k1 varchar(32), k2 varchar(32), k3 varchar(32), k4 int) "
|
||||
+ "AGGREGATE KEY(k1, k2,k3,k4) distributed by hash(k1) buckets 3 properties('replication_num' = '1');";
|
||||
dorisAssert = new DorisAssert();
|
||||
dorisAssert.withDatabase("db1").useDatabase("db1");
|
||||
dorisAssert.withTable(createTblStmtStr);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGroupingSets() throws Exception {
|
||||
ConnectContext ctx = UtFrameUtils.createDefaultCtx();
|
||||
UtFrameUtils.createMinDorisCluster(runningDir);
|
||||
String createDbStmtStr = "create database db1;";
|
||||
CreateDbStmt createDbStmt = (CreateDbStmt) UtFrameUtils.parseAndAnalyzeStmt(createDbStmtStr, ctx);
|
||||
Catalog.getCurrentCatalog().createDb(createDbStmt);
|
||||
System.out.println(Catalog.getCurrentCatalog().getDbNames());
|
||||
// 3. create table tbl1
|
||||
String createTblStmtStr = "create table db1.tbl1(k1 varchar(32), k2 varchar(32), k3 varchar(32), k4 int) "
|
||||
+ "AGGREGATE KEY(k1, k2,k3,k4) distributed by hash(k1) buckets 3 properties('replication_num' = '1');";
|
||||
CreateTableStmt createTableStmt = (CreateTableStmt) UtFrameUtils.parseAndAnalyzeStmt(createTblStmtStr, ctx);
|
||||
Catalog.getCurrentCatalog().createTable(createTableStmt);
|
||||
String selectStmtStr = "select k1,k2,MAX(k4) from db1.tbl1 GROUP BY GROUPING sets ((k1,k2),(k1),(k2),());";
|
||||
UtFrameUtils.parseAndAnalyzeStmt(selectStmtStr, ctx);
|
||||
String selectStmtStr2 = "select k1,k4,MAX(k4) from db1.tbl1 GROUP BY GROUPING sets ((k1,k4),(k1),(k4),());";
|
||||
@ -52,5 +53,29 @@ public class SelectStmtTest {
|
||||
UtFrameUtils.parseAndAnalyzeStmt(selectStmtStr4, ctx);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@Test
|
||||
public void testSubqueryInCase() throws Exception {
|
||||
ConnectContext ctx = UtFrameUtils.createDefaultCtx();
|
||||
String sql1 = "SELECT CASE\n" +
|
||||
" WHEN (\n" +
|
||||
" SELECT COUNT(*) / 2\n" +
|
||||
" FROM db1.tbl1\n" +
|
||||
" ) > k4 THEN (\n" +
|
||||
" SELECT AVG(k4)\n" +
|
||||
" FROM db1.tbl1\n" +
|
||||
" )\n" +
|
||||
" ELSE (\n" +
|
||||
" SELECT SUM(k4)\n" +
|
||||
" FROM db1.tbl1\n" +
|
||||
" )\n" +
|
||||
" END AS kk4\n" +
|
||||
"FROM db1.tbl1;";
|
||||
SelectStmt stmt = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql1, ctx);
|
||||
stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
|
||||
Assert.assertEquals("SELECT CASE WHEN `$a$1`.`$c$1` > `k4` THEN `$a$2`.`$c$2` ELSE `$a$3`.`$c$3` END" +
|
||||
" AS `kk4` FROM `default_cluster:db1`.`tbl1` (SELECT count(*) / 2.0 AS `count(*) / 2.0` FROM " +
|
||||
"`default_cluster:db1`.`tbl1`) $a$1 (SELECT avg(`k4`) AS `avg(``k4``)` FROM" +
|
||||
" `default_cluster:db1`.`tbl1`) $a$2 (SELECT sum(`k4`) AS `sum(``k4``)` " +
|
||||
"FROM `default_cluster:db1`.`tbl1`) $a$3", stmt.toSql());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user