[SQL] Support subquery in case when statement (#3135)
#3153 implement subquery support for sub query in case when statement like ``` SELECT CASE WHEN ( SELECT COUNT(*) / 2 FROM t ) > k4 THEN ( SELECT AVG(k4) FROM t ) ELSE ( SELECT SUM(k4) FROM t ) END AS kk4 FROM t; ``` this statement will be rewrite to ``` SELECT CASE WHEN t1.a > k4 THEN t2.a ELSE t3.a END AS kk4 FROM t, ( SELECT COUNT(*) / 2 AS a FROM t ) t1, ( SELECT AVG(k4) AS a FROM t ) t2, ( SELECT SUM(k4) AS a FROM t ) t3; ```
This commit is contained in:
@ -132,9 +132,9 @@ public class FromClause implements ParseNode, Iterable<TableRef> {
|
||||
public String toSql() {
|
||||
StringBuilder builder = new StringBuilder();
|
||||
if (!tableRefs_.isEmpty()) {
|
||||
builder.append(" FROM ");
|
||||
builder.append(" FROM");
|
||||
for (int i = 0; i < tableRefs_.size(); ++i) {
|
||||
builder.append(tableRefs_.get(i).toSql());
|
||||
builder.append(" " + tableRefs_.get(i).toSql());
|
||||
}
|
||||
}
|
||||
return builder.toString();
|
||||
|
||||
@ -20,6 +20,7 @@ package org.apache.doris.analysis;
|
||||
import org.apache.doris.common.AnalysisException;
|
||||
import org.apache.doris.rewrite.ExprRewriter;
|
||||
|
||||
import com.google.common.base.Predicates;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.List;
|
||||
@ -83,7 +84,17 @@ class SelectList {
|
||||
public void rewriteExprs(ExprRewriter rewriter, Analyzer analyzer)
|
||||
throws AnalysisException {
|
||||
for (SelectListItem item : items) {
|
||||
if (item.isStar()) continue;
|
||||
if (item.isStar()) {
|
||||
continue;
|
||||
}
|
||||
// rewrite subquery in select list
|
||||
if (item.getExpr().contains(Predicates.instanceOf(Subquery.class))) {
|
||||
List<Subquery> subqueryExprs = Lists.newArrayList();
|
||||
item.getExpr().collect(Subquery.class, subqueryExprs);
|
||||
for (Subquery s : subqueryExprs) {
|
||||
s.getStatement().rewriteExprs(rewriter);
|
||||
}
|
||||
}
|
||||
item.setExpr(rewriter.rewrite(item.getExpr(), analyzer));
|
||||
}
|
||||
}
|
||||
|
||||
@ -52,6 +52,7 @@ import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
@ -362,7 +363,8 @@ public class SelectStmt extends QueryStmt {
|
||||
// Analyze the resultExpr before generating a label to ensure enforcement
|
||||
// of expr child and depth limits (toColumn() label may call toSql()).
|
||||
item.getExpr().analyze(analyzer);
|
||||
if (item.getExpr().contains(Predicates.instanceOf(Subquery.class))) {
|
||||
if (!(item.getExpr() instanceof CaseExpr) &&
|
||||
item.getExpr().contains(Predicates.instanceOf(Subquery.class))) {
|
||||
throw new AnalysisException("Subquery is not supported in the select list.");
|
||||
}
|
||||
Expr expr = rewriteCountDistinctForBitmapOrHLL(item.getExpr(), analyzer);
|
||||
@ -749,39 +751,6 @@ public class SelectStmt extends QueryStmt {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This select block might contain inline views.
|
||||
* Substitute all exprs (result of the analysis) of this select block referencing any
|
||||
* of our inlined views, including everything registered with the analyzer.
|
||||
* Expressions created during parsing (such as whereClause) are not touched.
|
||||
*
|
||||
* @throws AnalysisException
|
||||
*/
|
||||
public void seondSubstituteInlineViewExprs(ExprSubstitutionMap sMap) throws AnalysisException {
|
||||
// we might not have anything to substitute
|
||||
if (sMap.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// select
|
||||
// Expr.substituteList(resultExprs, sMap);
|
||||
|
||||
// aggregation (group by and aggregation expr)
|
||||
if (aggInfo != null) {
|
||||
aggInfo.substitute(sMap, analyzer);
|
||||
}
|
||||
|
||||
// having
|
||||
if (havingPred != null) {
|
||||
havingPred.substitute(sMap);
|
||||
}
|
||||
|
||||
// ordering
|
||||
//if (sortInfo != null) {
|
||||
// sortInfo.substitute(sMap);
|
||||
//}
|
||||
}
|
||||
|
||||
/**
|
||||
* Expand "*" select list item.
|
||||
*/
|
||||
@ -1255,7 +1224,7 @@ public class SelectStmt extends QueryStmt {
|
||||
@Override
|
||||
public void rewriteExprs(ExprRewriter rewriter) throws AnalysisException {
|
||||
Preconditions.checkState(isAnalyzed());
|
||||
selectList.rewriteExprs(rewriter, analyzer);
|
||||
rewriteSelectList(rewriter);
|
||||
for (TableRef ref : fromClause_) {
|
||||
ref.rewriteExprs(rewriter, analyzer);
|
||||
}
|
||||
@ -1284,6 +1253,83 @@ public class SelectStmt extends QueryStmt {
|
||||
}
|
||||
}
|
||||
|
||||
private void rewriteSelectList(ExprRewriter rewriter) throws AnalysisException {
|
||||
for (SelectListItem item : selectList.getItems()) {
|
||||
if (!(item.getExpr() instanceof CaseExpr)) {
|
||||
continue;
|
||||
}
|
||||
if (!item.getExpr().contains(Predicates.instanceOf(Subquery.class))) {
|
||||
continue;
|
||||
}
|
||||
item.setExpr(rewriteSubquery(item.getExpr(), analyzer));
|
||||
}
|
||||
|
||||
selectList.rewriteExprs(rewriter, analyzer);
|
||||
}
|
||||
|
||||
/** rewrite subquery in case when to an inline view
|
||||
* subquery in case when statement like
|
||||
*
|
||||
* SELECT CASE
|
||||
* WHEN (
|
||||
* SELECT COUNT(*) / 2
|
||||
* FROM t
|
||||
* ) > k4 THEN (
|
||||
* SELECT AVG(k4)
|
||||
* FROM t
|
||||
* )
|
||||
* ELSE (
|
||||
* SELECT SUM(k4)
|
||||
* FROM t
|
||||
* )
|
||||
* END AS kk4
|
||||
* FROM t;
|
||||
* this statement will be rewrite to
|
||||
*
|
||||
* SELECT CASE
|
||||
* WHEN t1.a > k4 THEN t2.a
|
||||
* ELSE t3.a
|
||||
* END AS kk4
|
||||
* FROM t, (
|
||||
* SELECT COUNT(*) / 2 AS a
|
||||
* FROM t
|
||||
* ) t1, (
|
||||
* SELECT AVG(k4) AS a
|
||||
* FROM t
|
||||
* ) t2, (
|
||||
* SELECT SUM(k4) AS a
|
||||
* FROM t
|
||||
* ) t3;
|
||||
*/
|
||||
private Expr rewriteSubquery(Expr expr, Analyzer analyzer)
|
||||
throws AnalysisException {
|
||||
if (expr instanceof Subquery) {
|
||||
if (!(((Subquery) expr).getStatement() instanceof SelectStmt)) {
|
||||
throw new AnalysisException("Only support select subquery in case statement.");
|
||||
}
|
||||
SelectStmt subquery = (SelectStmt) ((Subquery) expr).getStatement();
|
||||
if (subquery.resultExprs.size() != 1) {
|
||||
throw new AnalysisException("Only support select subquery produce one column in case statement.");
|
||||
}
|
||||
subquery.reset();
|
||||
String alias = getTableAliasGenerator().getNextAlias();
|
||||
String colAlias = getColumnAliasGenerator().getNextAlias();
|
||||
InlineViewRef inlineViewRef = new InlineViewRef(alias, subquery, Arrays.asList(colAlias));
|
||||
try {
|
||||
inlineViewRef.analyze(analyzer);
|
||||
} catch (UserException e) {
|
||||
throw new AnalysisException(e.getMessage());
|
||||
}
|
||||
fromClause_.add(inlineViewRef);
|
||||
expr = new SlotRef(inlineViewRef.getAliasAsName(), colAlias);
|
||||
} else if (CollectionUtils.isNotEmpty(expr.getChildren())) {
|
||||
for (int i = 0; i < expr.getChildren().size(); ++i) {
|
||||
expr.setChild(i, rewriteSubquery(expr.getChild(i), analyzer));
|
||||
}
|
||||
}
|
||||
return expr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toSql() {
|
||||
if (sqlString_ != null) {
|
||||
|
||||
@ -1,43 +1,44 @@
|
||||
package org.apache.doris.analysis;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.UUID;
|
||||
|
||||
import org.apache.doris.catalog.Catalog;
|
||||
import org.apache.doris.common.AnalysisException;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.rewrite.ExprRewriter;
|
||||
import org.apache.doris.utframe.DorisAssert;
|
||||
import org.apache.doris.utframe.UtFrameUtils;
|
||||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.Assert;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
|
||||
import java.util.UUID;
|
||||
|
||||
public class SelectStmtTest {
|
||||
private static String runningDir = "fe/mocked/DemoTest/" + UUID.randomUUID().toString() + "/";
|
||||
private static DorisAssert dorisAssert;
|
||||
|
||||
@Rule
|
||||
public ExpectedException expectedEx = ExpectedException.none();
|
||||
|
||||
@AfterClass
|
||||
public static void afterClass() throws Exception {
|
||||
FileUtils.deleteDirectory(new File(runningDir));
|
||||
public static void tearDown() throws Exception {
|
||||
UtFrameUtils.cleanDorisFeDir(runningDir);
|
||||
}
|
||||
|
||||
@BeforeClass
|
||||
public static void setUp() throws Exception {
|
||||
UtFrameUtils.createMinDorisCluster(runningDir);
|
||||
String createTblStmtStr = "create table db1.tbl1(k1 varchar(32), k2 varchar(32), k3 varchar(32), k4 int) "
|
||||
+ "AGGREGATE KEY(k1, k2,k3,k4) distributed by hash(k1) buckets 3 properties('replication_num' = '1');";
|
||||
dorisAssert = new DorisAssert();
|
||||
dorisAssert.withDatabase("db1").useDatabase("db1");
|
||||
dorisAssert.withTable(createTblStmtStr);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGroupingSets() throws Exception {
|
||||
ConnectContext ctx = UtFrameUtils.createDefaultCtx();
|
||||
UtFrameUtils.createMinDorisCluster(runningDir);
|
||||
String createDbStmtStr = "create database db1;";
|
||||
CreateDbStmt createDbStmt = (CreateDbStmt) UtFrameUtils.parseAndAnalyzeStmt(createDbStmtStr, ctx);
|
||||
Catalog.getCurrentCatalog().createDb(createDbStmt);
|
||||
System.out.println(Catalog.getCurrentCatalog().getDbNames());
|
||||
// 3. create table tbl1
|
||||
String createTblStmtStr = "create table db1.tbl1(k1 varchar(32), k2 varchar(32), k3 varchar(32), k4 int) "
|
||||
+ "AGGREGATE KEY(k1, k2,k3,k4) distributed by hash(k1) buckets 3 properties('replication_num' = '1');";
|
||||
CreateTableStmt createTableStmt = (CreateTableStmt) UtFrameUtils.parseAndAnalyzeStmt(createTblStmtStr, ctx);
|
||||
Catalog.getCurrentCatalog().createTable(createTableStmt);
|
||||
String selectStmtStr = "select k1,k2,MAX(k4) from db1.tbl1 GROUP BY GROUPING sets ((k1,k2),(k1),(k2),());";
|
||||
UtFrameUtils.parseAndAnalyzeStmt(selectStmtStr, ctx);
|
||||
String selectStmtStr2 = "select k1,k4,MAX(k4) from db1.tbl1 GROUP BY GROUPING sets ((k1,k4),(k1),(k4),());";
|
||||
@ -52,5 +53,29 @@ public class SelectStmtTest {
|
||||
UtFrameUtils.parseAndAnalyzeStmt(selectStmtStr4, ctx);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@Test
|
||||
public void testSubqueryInCase() throws Exception {
|
||||
ConnectContext ctx = UtFrameUtils.createDefaultCtx();
|
||||
String sql1 = "SELECT CASE\n" +
|
||||
" WHEN (\n" +
|
||||
" SELECT COUNT(*) / 2\n" +
|
||||
" FROM db1.tbl1\n" +
|
||||
" ) > k4 THEN (\n" +
|
||||
" SELECT AVG(k4)\n" +
|
||||
" FROM db1.tbl1\n" +
|
||||
" )\n" +
|
||||
" ELSE (\n" +
|
||||
" SELECT SUM(k4)\n" +
|
||||
" FROM db1.tbl1\n" +
|
||||
" )\n" +
|
||||
" END AS kk4\n" +
|
||||
"FROM db1.tbl1;";
|
||||
SelectStmt stmt = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql1, ctx);
|
||||
stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
|
||||
Assert.assertEquals("SELECT CASE WHEN `$a$1`.`$c$1` > `k4` THEN `$a$2`.`$c$2` ELSE `$a$3`.`$c$3` END" +
|
||||
" AS `kk4` FROM `default_cluster:db1`.`tbl1` (SELECT count(*) / 2.0 AS `count(*) / 2.0` FROM " +
|
||||
"`default_cluster:db1`.`tbl1`) $a$1 (SELECT avg(`k4`) AS `avg(``k4``)` FROM" +
|
||||
" `default_cluster:db1`.`tbl1`) $a$2 (SELECT sum(`k4`) AS `sum(``k4``)` " +
|
||||
"FROM `default_cluster:db1`.`tbl1`) $a$3", stmt.toSql());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user