From 929b31bd3c13d8fbc69680706125faf6bdfe2bfb Mon Sep 17 00:00:00 2001 From: zhengshiJ <32082872+zhengshiJ@users.noreply.github.com> Date: Fri, 3 Feb 2023 18:20:47 +0800 Subject: [PATCH] [Feature](Nereids) Support CaseWhen with subquery (#16385) Co-authored-by: jianghaochen --- .../apache/doris/nereids/rules/RuleType.java | 1 + .../rules/analysis/AnalyzeSubquery.java | 32 ++++++++++++++++++- .../sub_query_correlated.out | 9 ++++++ .../sub_query_correlated.groovy | 18 +++++++++++ 4 files changed, 59 insertions(+), 1 deletion(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 52c0756a82..56edfa23c4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -103,6 +103,7 @@ public enum RuleType { // subquery analyze ANALYZE_FILTER_SUBQUERY(RuleTypeClass.REWRITE), + ANALYZE_PROJECT_SUBQUERY(RuleTypeClass.REWRITE), // subquery rewrite rule ELIMINATE_LIMIT_UNDER_APPLY(RuleTypeClass.REWRITE), ELIMINATE_SORT_UNDER_APPLY(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubquery.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubquery.java index d2d427f4ff..1e85e87797 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubquery.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubquery.java @@ -20,6 +20,8 @@ package org.apache.doris.nereids.rules.analysis; import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Exists; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InSubquery; @@ -38,6 +40,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Optional; import java.util.Set; @@ -67,7 +70,30 @@ public class AnalyzeSubquery implements AnalysisRuleFactory { subqueryExprs, filter.child(), ctx.cascadesContext )); }) - ) + ), + RuleType.ANALYZE_PROJECT_SUBQUERY.build( + logicalProject().thenApply(ctx -> { + LogicalProject project = ctx.root; + Set subqueryExprs = new HashSet<>(); + project.getProjects().stream() + .filter(Alias.class::isInstance) + .map(Alias.class::cast) + .filter(alias -> alias.child() instanceof CaseWhen) + .forEach(alias -> alias.child().children().stream() + .forEach(e -> + subqueryExprs.addAll(e.collect(SubqueryExpr.class::isInstance)))); + if (subqueryExprs.isEmpty()) { + return project; + } + + return new LogicalProject(project.getProjects().stream() + .map(p -> p.withChildren(new ReplaceSubquery().replace(p))) + .collect(ImmutableList.toImmutableList()), + analyzedSubquery( + subqueryExprs, project.child(), ctx.cascadesContext + )); + }) + ) ); } @@ -117,6 +143,10 @@ public class AnalyzeSubquery implements AnalysisRuleFactory { .collect(ImmutableSet.toImmutableSet()); } + public Expression replace(Expression expressions) { + return expressions.accept(this, null); + } + @Override public Expression visitExistsSubquery(Exists exists, Void context) { return BooleanLiteral.TRUE; diff --git a/regression-test/data/nereids_syntax_p0/sub_query_correlated.out b/regression-test/data/nereids_syntax_p0/sub_query_correlated.out index 30700befa7..42a9bb5907 100644 --- a/regression-test/data/nereids_syntax_p0/sub_query_correlated.out +++ b/regression-test/data/nereids_syntax_p0/sub_query_correlated.out @@ -199,3 +199,12 @@ 1 20 +-- !case_when_subquery -- +4.0 +4.0 +20.0 +20.0 +20.0 +20.0 +20.0 + diff --git a/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy b/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy index 480043ca02..40117fe13c 100644 --- a/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy +++ b/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy @@ -262,4 +262,22 @@ suite ("sub_query_correlated") { qt_scalar_subquery_with_disjunctions """ SELECT DISTINCT k1 FROM sub_query_correlated_subquery1 i1 WHERE ((SELECT count(*) FROM sub_query_correlated_subquery1 WHERE ((k1 = i1.k1) AND (k2 = 2)) or ((k1 = i1.k1) AND (k2 = 1)) ) > 0); """ + + //--------subquery case when----------- + qt_case_when_subquery """ + SELECT CASE + WHEN ( + SELECT COUNT(*) / 2 + FROM sub_query_correlated_subquery3 + ) > v1 THEN ( + SELECT AVG(v1) + FROM sub_query_correlated_subquery3 + ) + ELSE ( + SELECT SUM(v2) + FROM sub_query_correlated_subquery3 + ) + END AS kk4 + FROM sub_query_correlated_subquery3 ; + """ }