[Feature](Nereids) Support CaseWhen with subquery (#16385)

Co-authored-by: jianghaochen <jianghaochen@meituan.com>
This commit is contained in:
zhengshiJ
2023-02-03 18:20:47 +08:00
committed by GitHub
parent 3891083474
commit 929b31bd3c
4 changed files with 59 additions and 1 deletions

View File

@ -103,6 +103,7 @@ public enum RuleType {
// subquery analyze
ANALYZE_FILTER_SUBQUERY(RuleTypeClass.REWRITE),
ANALYZE_PROJECT_SUBQUERY(RuleTypeClass.REWRITE),
// subquery rewrite rule
ELIMINATE_LIMIT_UNDER_APPLY(RuleTypeClass.REWRITE),
ELIMINATE_SORT_UNDER_APPLY(RuleTypeClass.REWRITE),

View File

@ -20,6 +20,8 @@ package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Exists;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
@ -38,6 +40,7 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
@ -67,7 +70,30 @@ public class AnalyzeSubquery implements AnalysisRuleFactory {
subqueryExprs, filter.child(), ctx.cascadesContext
));
})
)
),
RuleType.ANALYZE_PROJECT_SUBQUERY.build(
logicalProject().thenApply(ctx -> {
LogicalProject<GroupPlan> project = ctx.root;
Set<SubqueryExpr> subqueryExprs = new HashSet<>();
project.getProjects().stream()
.filter(Alias.class::isInstance)
.map(Alias.class::cast)
.filter(alias -> alias.child() instanceof CaseWhen)
.forEach(alias -> alias.child().children().stream()
.forEach(e ->
subqueryExprs.addAll(e.collect(SubqueryExpr.class::isInstance))));
if (subqueryExprs.isEmpty()) {
return project;
}
return new LogicalProject(project.getProjects().stream()
.map(p -> p.withChildren(new ReplaceSubquery().replace(p)))
.collect(ImmutableList.toImmutableList()),
analyzedSubquery(
subqueryExprs, project.child(), ctx.cascadesContext
));
})
)
);
}
@ -117,6 +143,10 @@ public class AnalyzeSubquery implements AnalysisRuleFactory {
.collect(ImmutableSet.toImmutableSet());
}
public Expression replace(Expression expressions) {
return expressions.accept(this, null);
}
@Override
public Expression visitExistsSubquery(Exists exists, Void context) {
return BooleanLiteral.TRUE;

View File

@ -199,3 +199,12 @@
1
20
-- !case_when_subquery --
4.0
4.0
20.0
20.0
20.0
20.0
20.0

View File

@ -262,4 +262,22 @@ suite ("sub_query_correlated") {
qt_scalar_subquery_with_disjunctions """
SELECT DISTINCT k1 FROM sub_query_correlated_subquery1 i1 WHERE ((SELECT count(*) FROM sub_query_correlated_subquery1 WHERE ((k1 = i1.k1) AND (k2 = 2)) or ((k1 = i1.k1) AND (k2 = 1)) ) > 0);
"""
//--------subquery case when-----------
qt_case_when_subquery """
SELECT CASE
WHEN (
SELECT COUNT(*) / 2
FROM sub_query_correlated_subquery3
) > v1 THEN (
SELECT AVG(v1)
FROM sub_query_correlated_subquery3
)
ELSE (
SELECT SUM(v2)
FROM sub_query_correlated_subquery3
)
END AS kk4
FROM sub_query_correlated_subquery3 ;
"""
}