[Feature](Nereids) Support CaseWhen with subquery (#16385)
Co-authored-by: jianghaochen <jianghaochen@meituan.com>
This commit is contained in:
@ -103,6 +103,7 @@ public enum RuleType {
|
||||
|
||||
// subquery analyze
|
||||
ANALYZE_FILTER_SUBQUERY(RuleTypeClass.REWRITE),
|
||||
ANALYZE_PROJECT_SUBQUERY(RuleTypeClass.REWRITE),
|
||||
// subquery rewrite rule
|
||||
ELIMINATE_LIMIT_UNDER_APPLY(RuleTypeClass.REWRITE),
|
||||
ELIMINATE_SORT_UNDER_APPLY(RuleTypeClass.REWRITE),
|
||||
|
||||
@ -20,6 +20,8 @@ package org.apache.doris.nereids.rules.analysis;
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.CaseWhen;
|
||||
import org.apache.doris.nereids.trees.expressions.Exists;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.InSubquery;
|
||||
@ -38,6 +40,7 @@ import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
@ -67,7 +70,30 @@ public class AnalyzeSubquery implements AnalysisRuleFactory {
|
||||
subqueryExprs, filter.child(), ctx.cascadesContext
|
||||
));
|
||||
})
|
||||
)
|
||||
),
|
||||
RuleType.ANALYZE_PROJECT_SUBQUERY.build(
|
||||
logicalProject().thenApply(ctx -> {
|
||||
LogicalProject<GroupPlan> project = ctx.root;
|
||||
Set<SubqueryExpr> subqueryExprs = new HashSet<>();
|
||||
project.getProjects().stream()
|
||||
.filter(Alias.class::isInstance)
|
||||
.map(Alias.class::cast)
|
||||
.filter(alias -> alias.child() instanceof CaseWhen)
|
||||
.forEach(alias -> alias.child().children().stream()
|
||||
.forEach(e ->
|
||||
subqueryExprs.addAll(e.collect(SubqueryExpr.class::isInstance))));
|
||||
if (subqueryExprs.isEmpty()) {
|
||||
return project;
|
||||
}
|
||||
|
||||
return new LogicalProject(project.getProjects().stream()
|
||||
.map(p -> p.withChildren(new ReplaceSubquery().replace(p)))
|
||||
.collect(ImmutableList.toImmutableList()),
|
||||
analyzedSubquery(
|
||||
subqueryExprs, project.child(), ctx.cascadesContext
|
||||
));
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@ -117,6 +143,10 @@ public class AnalyzeSubquery implements AnalysisRuleFactory {
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
}
|
||||
|
||||
public Expression replace(Expression expressions) {
|
||||
return expressions.accept(this, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitExistsSubquery(Exists exists, Void context) {
|
||||
return BooleanLiteral.TRUE;
|
||||
|
||||
@ -199,3 +199,12 @@
|
||||
1
|
||||
20
|
||||
|
||||
-- !case_when_subquery --
|
||||
4.0
|
||||
4.0
|
||||
20.0
|
||||
20.0
|
||||
20.0
|
||||
20.0
|
||||
20.0
|
||||
|
||||
|
||||
@ -262,4 +262,22 @@ suite ("sub_query_correlated") {
|
||||
qt_scalar_subquery_with_disjunctions """
|
||||
SELECT DISTINCT k1 FROM sub_query_correlated_subquery1 i1 WHERE ((SELECT count(*) FROM sub_query_correlated_subquery1 WHERE ((k1 = i1.k1) AND (k2 = 2)) or ((k1 = i1.k1) AND (k2 = 1)) ) > 0);
|
||||
"""
|
||||
|
||||
//--------subquery case when-----------
|
||||
qt_case_when_subquery """
|
||||
SELECT CASE
|
||||
WHEN (
|
||||
SELECT COUNT(*) / 2
|
||||
FROM sub_query_correlated_subquery3
|
||||
) > v1 THEN (
|
||||
SELECT AVG(v1)
|
||||
FROM sub_query_correlated_subquery3
|
||||
)
|
||||
ELSE (
|
||||
SELECT SUM(v2)
|
||||
FROM sub_query_correlated_subquery3
|
||||
)
|
||||
END AS kk4
|
||||
FROM sub_query_correlated_subquery3 ;
|
||||
"""
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user