diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index c4fa3abd0b..98721f6c48 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -105,7 +105,6 @@ import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoOdbcScan;
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide;
import org.apache.doris.nereids.rules.rewrite.PushDownDistinctThroughJoin;
-import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughAggregation;
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughProject;
import org.apache.doris.nereids.rules.rewrite.PushDownLimit;
import org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughJoin;
@@ -171,51 +170,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
// after doing NormalizeAggregate in analysis job
// we need run the following 2 rules to make AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION work
bottomUp(new PullUpProjectUnderApply()),
- topDown(
- /*
- * for subquery unnest, we need hand sql like
- *
- * SELECT *
- * FROM table1 AS t1
- * WHERE EXISTS
- * (SELECT `pk`
- * FROM table2 AS t2
- * WHERE t1.pk = t2 .pk
- * GROUP BY t2.pk
- * HAVING t2.pk > 0) ;
- *
- * before:
- * apply
- * / \
- * child Filter(t2.pk > 0)
- * |
- * Project(t2.pk)
- * |
- * agg
- * |
- * Project(t2.pk)
- * |
- * Filter(t1.pk=t2.pk)
- * |
- * child
- *
- * after:
- * apply
- * / \
- * child agg
- * |
- * Project(t2.pk)
- * |
- * Filter(t1.pk=t2.pk and t2.pk >0)
- * |
- * child
- *
- * then PullUpCorrelatedFilterUnderApplyAggregateProject rule can match the node pattern
- */
- new PushDownFilterThroughAggregation(),
- new PushDownFilterThroughProject(),
- new MergeFilters()
- ),
+ topDown(new PushDownFilterThroughProject()),
custom(RuleType.AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION,
AggScalarSubQueryToWindowFunction::new),
bottomUp(
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 1fda32a400..4515aaf55f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -137,7 +137,9 @@ public enum RuleType {
UN_CORRELATED_APPLY_FILTER(RuleTypeClass.REWRITE),
UN_CORRELATED_APPLY_PROJECT_FILTER(RuleTypeClass.REWRITE),
UN_CORRELATED_APPLY_AGGREGATE_FILTER(RuleTypeClass.REWRITE),
+ UN_CORRELATED_APPLY_FILTER_AGGREGATE_FILTER(RuleTypeClass.REWRITE),
PULL_UP_CORRELATED_FILTER_UNDER_APPLY_AGGREGATE_PROJECT(RuleTypeClass.REWRITE),
+ PULL_UP_CORRELATED_FILTER_UNDER_APPLY_FILTER_AGGREGATE_PROJECT(RuleTypeClass.REWRITE),
SCALAR_APPLY_TO_JOIN(RuleTypeClass.REWRITE),
IN_APPLY_TO_JOIN(RuleTypeClass.REWRITE),
EXISTS_APPLY_TO_JOIN(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpCorrelatedFilterUnderApplyAggregateProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpCorrelatedFilterUnderApplyAggregateProject.java
index e56f552f9f..309bd9a78b 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpCorrelatedFilterUnderApplyAggregateProject.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpCorrelatedFilterUnderApplyAggregateProject.java
@@ -36,8 +36,10 @@ import java.util.List;
*
* before:
* apply
- * / \
- * Input(output:b) agg
+ * / \
+ * Input(output:b) Filter(this node's existence depends on having clause's existence)
+ * |
+ * agg
* |
* Project(output:a)
* |
@@ -47,8 +49,10 @@ import java.util.List;
*
* after:
* apply
- * / \
- * Input(output:b) agg
+ * / \
+ * Input(output:b) Filter(this node's existence depends on having clause's existence)
+ * |
+ * agg
* |
* Filter(correlated predicate(Input.e = this.f)/Unapply predicate)
* |
@@ -57,27 +61,51 @@ import java.util.List;
* child
*
*/
-public class PullUpCorrelatedFilterUnderApplyAggregateProject extends OneRewriteRuleFactory {
+public class PullUpCorrelatedFilterUnderApplyAggregateProject implements RewriteRuleFactory {
@Override
- public Rule build() {
- return logicalApply(any(), logicalAggregate(logicalProject(logicalFilter())))
- .when(LogicalApply::isCorrelated).then(apply -> {
- LogicalAggregate<LogicalProject<LogicalFilter<Plan>>> agg = apply.right();
+ public List<Rule> buildRules() {
+ return ImmutableList.of(logicalApply(any(), logicalAggregate(
+ logicalProject(logicalFilter()))).when(LogicalApply::isCorrelated).then(
+ PullUpCorrelatedFilterUnderApplyAggregateProject::pullUpCorrelatedFilter)
+ .toRule(RuleType.PULL_UP_CORRELATED_FILTER_UNDER_APPLY_AGGREGATE_PROJECT),
+ logicalApply(any(), logicalFilter(logicalAggregate(
+ logicalProject(logicalFilter())))).when(LogicalApply::isCorrelated).then(
+ PullUpCorrelatedFilterUnderApplyAggregateProject::pullUpCorrelatedFilter)
+ .toRule(RuleType.PULL_UP_CORRELATED_FILTER_UNDER_APPLY_FILTER_AGGREGATE_PROJECT));
+ }

- LogicalProject<LogicalFilter<Plan>> project = agg.child();
- LogicalFilter<Plan> filter = project.child();
- List<NamedExpression> newProjects = Lists.newArrayList();
- newProjects.addAll(project.getProjects());
- filter.child().getOutput().forEach(slot -> {
- if (!newProjects.contains(slot)) {
- newProjects.add(slot);
- }
- });
+ /**
+ * Pulls the filter that sits under the aggregate's project up above that project, widening
+ * the project so that it also outputs every slot of the filter's child.
+ *
+ * @param apply correlated apply whose right child is either the aggregate itself or a
+ * filter (HAVING) directly on top of the aggregate
+ * @return the apply rebuilt on the rewritten aggregate; a filter above the aggregate is kept
+ */
+ private static LogicalApply<?, ?> pullUpCorrelatedFilter(LogicalApply<?, ?> apply) {
+ boolean isRightChildAgg = apply.right() instanceof LogicalAggregate;
+ // locate agg node: it is the right child itself, or the child of the filter on top of it
+ LogicalAggregate<LogicalProject<LogicalFilter<Plan>>> agg = isRightChildAgg
+ ? (LogicalAggregate<LogicalProject<LogicalFilter<Plan>>>) (apply.right())
+ : (LogicalAggregate<LogicalProject<LogicalFilter<Plan>>>) (apply.right().child(0));

- LogicalProject<Plan> newProject = project.withProjectsAndChild(newProjects, filter.child());
- LogicalFilter<Plan> newFilter = new LogicalFilter<>(filter.getConjuncts(), newProject);
- LogicalAggregate<Plan> newAgg = agg.withChildren(ImmutableList.of(newFilter));
- return apply.withChildren(apply.left(), newAgg);
- }).toRule(RuleType.PULL_UP_CORRELATED_FILTER_UNDER_APPLY_AGGREGATE_PROJECT);
+ // pull up filter under the project
+ LogicalProject<LogicalFilter<Plan>> project = agg.child();
+ LogicalFilter<Plan> filter = project.child();
+ List<NamedExpression> newProjects = Lists.newArrayList();
+ newProjects.addAll(project.getProjects());
+
+ // filter may use all slots from its child, so add all the slots to newProjects
+ filter.child().getOutput().forEach(slot -> {
+ if (!newProjects.contains(slot)) {
+ newProjects.add(slot);
+ }
+ });
+
+ LogicalProject<Plan> newProject = project.withProjectsAndChild(newProjects, filter.child());
+ LogicalFilter<Plan> newFilter = new LogicalFilter<>(filter.getConjuncts(), newProject);
+ LogicalAggregate<Plan> newAgg = agg.withChildren(ImmutableList.of(newFilter));
+ return (LogicalApply<?, ?>) (apply.withChildren(apply.left(),
+ isRightChildAgg ? newAgg : apply.right().withChildren(newAgg)));
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyAggregateFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyAggregateFilter.java
index b0b62f2e9b..211e76710c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyAggregateFilter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyAggregateFilter.java
@@ -44,52 +44,78 @@ import java.util.Map;
* the output column is the correlated column and the input column.
*
* before:
- * apply
- * / \
- * Input(output:b) agg(output:fn; group by:null)
+ * apply
+ * / \
+ * Input(output:b) Filter(this node's existence depends on having clause's existence)
+ * |
+ * agg(output:fn; group by:null)
* |
* Filter(correlated predicate(Input.e = this.f)/Unapply predicate)
*
* end:
* apply(correlated predicate(Input.e = this.f))
* / \
- * Input(output:b) agg(output:fn,this.f; group by:this.f)
+ * Input(output:b) Filter(this node's existence depends on having clause's existence)
+ * |
+ * agg(output:fn,this.f; group by:this.f)
* |
* Filter(Uncorrelated predicate)
*
*/
-public class UnCorrelatedApplyAggregateFilter extends OneRewriteRuleFactory {
+public class UnCorrelatedApplyAggregateFilter implements RewriteRuleFactory {
@Override
- public Rule build() {
- return logicalApply(any(), logicalAggregate(logicalFilter())).when(LogicalApply::isCorrelated).then(apply -> {
- LogicalAggregate<LogicalFilter<Plan>> agg = apply.right();
- LogicalFilter<Plan> filter = agg.child();
- Map<Boolean, List<Expression>> split = Utils.splitCorrelatedConjuncts(
- filter.getConjuncts(), apply.getCorrelationSlot());
- List<Expression> correlatedPredicate = split.get(true);
- List<Expression> unCorrelatedPredicate = split.get(false);
+ public List<Rule> buildRules() {
+ return ImmutableList.of(
+ logicalApply(any(), logicalAggregate(logicalFilter()))
+ .when(LogicalApply::isCorrelated)
+ .then(UnCorrelatedApplyAggregateFilter::pullUpCorrelatedFilter)
+ .toRule(RuleType.UN_CORRELATED_APPLY_AGGREGATE_FILTER),
+ logicalApply(any(), logicalFilter(logicalAggregate(logicalFilter())))
+ .when(LogicalApply::isCorrelated)
+ .then(UnCorrelatedApplyAggregateFilter::pullUpCorrelatedFilter)
+ .toRule(RuleType.UN_CORRELATED_APPLY_FILTER_AGGREGATE_FILTER));
+ }

- // the representative has experienced the rule and added the correlated predicate to the apply node
- if (correlatedPredicate.isEmpty()) {
- return apply;
- }
+ /**
+ * Moves the correlated conjuncts of the filter under the aggregate onto the apply node, and
+ * adds the slots that Utils.getCorrelatedSlots extracts from them to the aggregate's
+ * group-by keys and output.
+ *
+ * @param apply correlated apply whose right child is either the aggregate itself or a
+ * filter (HAVING) directly on top of the aggregate
+ * @return the input apply when the filter has no correlated conjunct, otherwise a new apply
+ * carrying the correlated predicate over the rewritten right child
+ */
+ private static LogicalApply<?, ?> pullUpCorrelatedFilter(LogicalApply<?, ?> apply) {
+ boolean isRightChildAgg = apply.right() instanceof LogicalAggregate;
+ // locate agg node: it is the right child itself, or the child of the filter on top of it
+ LogicalAggregate<LogicalFilter<Plan>> agg =
+ isRightChildAgg ? (LogicalAggregate<LogicalFilter<Plan>>) (apply.right())
+ : (LogicalAggregate<LogicalFilter<Plan>>) (apply.right().child(0));
+ LogicalFilter<Plan> filter = agg.child();
+ // split filter conjuncts to correlated and unCorrelated ones
+ Map<Boolean, List<Expression>> split =
+ Utils.splitCorrelatedConjuncts(filter.getConjuncts(), apply.getCorrelationSlot());
+ List<Expression> correlatedPredicate = split.get(true);
+ List<Expression> unCorrelatedPredicate = split.get(false);

- List<NamedExpression> newAggOutput = new ArrayList<>(agg.getOutputExpressions());
- List<Expression> newGroupby = Utils.getCorrelatedSlots(correlatedPredicate,
- apply.getCorrelationSlot());
- newGroupby.addAll(agg.getGroupByExpressions());
- newAggOutput.addAll(newGroupby.stream()
- .map(NamedExpression.class::cast)
- .collect(ImmutableList.toImmutableList()));
- LogicalAggregate<Plan> newAgg = new LogicalAggregate<>(
- newGroupby, newAggOutput,
- PlanUtils.filterOrSelf(ImmutableSet.copyOf(unCorrelatedPredicate), filter.child()));
- return new LogicalApply<>(apply.getCorrelationSlot(),
- apply.getSubqueryExpr(),
- ExpressionUtils.optionalAnd(correlatedPredicate),
- apply.getMarkJoinSlotReference(),
- apply.isNeedAddSubOutputToProjects(),
- apply.isInProject(), apply.isMarkJoinSlotNotNull(), apply.left(), newAgg);
- }).toRule(RuleType.UN_CORRELATED_APPLY_AGGREGATE_FILTER);
+ // no correlated conjunct is left under the aggregate (e.g. this rule already ran), so keep the plan as is
+ if (correlatedPredicate.isEmpty()) {
+ return apply;
+ }
+
+ // pull up correlated filter into apply node
+ List<NamedExpression> newAggOutput = new ArrayList<>(agg.getOutputExpressions());
+ List<Expression> newGroupby =
+ Utils.getCorrelatedSlots(correlatedPredicate, apply.getCorrelationSlot());
+ newGroupby.addAll(agg.getGroupByExpressions());
+ newAggOutput.addAll(newGroupby.stream().map(NamedExpression.class::cast)
+ .collect(ImmutableList.toImmutableList()));
+ LogicalAggregate<Plan> newAgg = new LogicalAggregate<>(newGroupby, newAggOutput,
+ PlanUtils.filterOrSelf(ImmutableSet.copyOf(unCorrelatedPredicate), filter.child()));
+ return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),
+ ExpressionUtils.optionalAnd(correlatedPredicate), apply.getMarkJoinSlotReference(),
+ apply.isNeedAddSubOutputToProjects(), apply.isInProject(),
+ apply.isMarkJoinSlotNotNull(), apply.left(),
+ isRightChildAgg ? newAgg : apply.right().withChildren(newAgg));
}
}
diff --git a/regression-test/data/nereids_p0/subquery/subquery_unnesting.out b/regression-test/data/nereids_p0/subquery/subquery_unnesting.out
index e262f19b21..5124b5f970 100644
--- a/regression-test/data/nereids_p0/subquery/subquery_unnesting.out
+++ b/regression-test/data/nereids_p0/subquery/subquery_unnesting.out
@@ -1508,3 +1508,14 @@
-- !select62 --
+-- !select63 --
+1 \N
+1 2
+1 3
+2 4
+2 5
+3 3
+3 4
+20 2
+22 3
+24 4
diff --git a/regression-test/suites/nereids_p0/subquery/subquery_unnesting.groovy b/regression-test/suites/nereids_p0/subquery/subquery_unnesting.groovy
index 174a72df62..397c5cfbf2 100644
--- a/regression-test/suites/nereids_p0/subquery/subquery_unnesting.groovy
+++ b/regression-test/suites/nereids_p0/subquery/subquery_unnesting.groovy
@@ -132,4 +132,5 @@ suite ("subquery_unnesting") {
qt_select60 """select * from t1 where exists(select distinct k1 from t2 where t1.k1 > t2.k3 or t1.k2 < t2.v1) order by t1.k1, t1.k2;"""
qt_select61 """SELECT * FROM t1 AS t1 WHERE EXISTS (SELECT k1 FROM t1 AS t2 WHERE t1.k1 <> t2.k1 + 7 GROUP BY k1 HAVING k1 >= 100);"""
qt_select62 """select * from t1 left semi join ( select * from t1 where t1.k1 < -1 ) l on true;"""
+ qt_select63 """SELECT * FROM t1 AS t1 WHERE EXISTS (SELECT k1 FROM t1 AS t2 WHERE t1.k1 <> t2.k1 + 7 GROUP BY k1 HAVING sum(k2) >= 1) order by t1.k1, t1.k2;"""
}