diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index c4fa3abd0b..98721f6c48 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -105,7 +105,6 @@ import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoOdbcScan; import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoin; import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide; import org.apache.doris.nereids.rules.rewrite.PushDownDistinctThroughJoin; -import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughAggregation; import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughProject; import org.apache.doris.nereids.rules.rewrite.PushDownLimit; import org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughJoin; @@ -171,51 +170,7 @@ public class Rewriter extends AbstractBatchJobExecutor { // after doing NormalizeAggregate in analysis job // we need run the following 2 rules to make AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION work bottomUp(new PullUpProjectUnderApply()), - topDown( - /* - * for subquery unnest, we need hand sql like - * - * SELECT * - * FROM table1 AS t1 - * WHERE EXISTS - * (SELECT `pk` - * FROM table2 AS t2 - * WHERE t1.pk = t2 .pk - * GROUP BY t2.pk - * HAVING t2.pk > 0) ; - * - * before: - * apply - * / \ - * child Filter(t2.pk > 0) - * | - * Project(t2.pk) - * | - * agg - * | - * Project(t2.pk) - * | - * Filter(t1.pk=t2.pk) - * | - * child - * - * after: - * apply - * / \ - * child agg - * | - * Project(t2.pk) - * | - * Filter(t1.pk=t2.pk and t2.pk >0) - * | - * child - * - * then PullUpCorrelatedFilterUnderApplyAggregateProject rule can match the node pattern - */ - new PushDownFilterThroughAggregation(), - new PushDownFilterThroughProject(), - new MergeFilters() - ), + topDown(new PushDownFilterThroughProject()), custom(RuleType.AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION, AggScalarSubQueryToWindowFunction::new), bottomUp( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 1fda32a400..4515aaf55f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -137,7 +137,9 @@ public enum RuleType { UN_CORRELATED_APPLY_FILTER(RuleTypeClass.REWRITE), UN_CORRELATED_APPLY_PROJECT_FILTER(RuleTypeClass.REWRITE), UN_CORRELATED_APPLY_AGGREGATE_FILTER(RuleTypeClass.REWRITE), + UN_CORRELATED_APPLY_FILTER_AGGREGATE_FILTER(RuleTypeClass.REWRITE), PULL_UP_CORRELATED_FILTER_UNDER_APPLY_AGGREGATE_PROJECT(RuleTypeClass.REWRITE), + PULL_UP_CORRELATED_FILTER_UNDER_APPLY_FILTER_AGGREGATE_PROJECT(RuleTypeClass.REWRITE), SCALAR_APPLY_TO_JOIN(RuleTypeClass.REWRITE), IN_APPLY_TO_JOIN(RuleTypeClass.REWRITE), EXISTS_APPLY_TO_JOIN(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpCorrelatedFilterUnderApplyAggregateProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpCorrelatedFilterUnderApplyAggregateProject.java index e56f552f9f..309bd9a78b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpCorrelatedFilterUnderApplyAggregateProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpCorrelatedFilterUnderApplyAggregateProject.java @@ -36,8 +36,10 @@ import java.util.List; *
  * before:
  *              apply
- *         /              \
- * Input(output:b)        agg
+ *         /             \
+ *  Input(output:b)     Filter(this node's existence depends on having clause's existence)
+ *                         |
+ *                        agg
  *                         |
  *                  Project(output:a)
  *                         |
@@ -47,8 +49,10 @@ import java.util.List;
  *
  * after:
  *              apply
- *         /              \
- * Input(output:b)        agg
+ *         /             \
+ *  Input(output:b)     Filter(this node's existence depends on having clause's existence)
+ *                         |
+ *                        agg
  *                         |
  *              Filter(correlated predicate(Input.e = this.f)/Unapply predicate)
  *                         |
@@ -57,27 +61,43 @@ import java.util.List;
  *                         child
  * 
*/ -public class PullUpCorrelatedFilterUnderApplyAggregateProject extends OneRewriteRuleFactory { +public class PullUpCorrelatedFilterUnderApplyAggregateProject implements RewriteRuleFactory { @Override - public Rule build() { - return logicalApply(any(), logicalAggregate(logicalProject(logicalFilter()))) - .when(LogicalApply::isCorrelated).then(apply -> { - LogicalAggregate>> agg = apply.right(); + public List buildRules() { + return ImmutableList.of(logicalApply(any(), logicalAggregate( + logicalProject(logicalFilter()))).when(LogicalApply::isCorrelated).then( + PullUpCorrelatedFilterUnderApplyAggregateProject::pullUpCorrelatedFilter) + .toRule(RuleType.PULL_UP_CORRELATED_FILTER_UNDER_APPLY_AGGREGATE_PROJECT), + logicalApply(any(), logicalFilter((logicalAggregate( + logicalProject(logicalFilter()))))).when(LogicalApply::isCorrelated).then( + PullUpCorrelatedFilterUnderApplyAggregateProject::pullUpCorrelatedFilter) + .toRule(RuleType.PULL_UP_CORRELATED_FILTER_UNDER_APPLY_FILTER_AGGREGATE_PROJECT)); + } - LogicalProject> project = agg.child(); - LogicalFilter filter = project.child(); - List newProjects = Lists.newArrayList(); - newProjects.addAll(project.getProjects()); - filter.child().getOutput().forEach(slot -> { - if (!newProjects.contains(slot)) { - newProjects.add(slot); - } - }); + private static LogicalApply pullUpCorrelatedFilter(LogicalApply apply) { + boolean isRightChildAgg = apply.right() instanceof LogicalAggregate; + // locate agg node + LogicalAggregate>> agg = isRightChildAgg + ? (LogicalAggregate>>) (apply.right()) + : (LogicalAggregate>>) (apply.right().child(0)); - LogicalProject newProject = project.withProjectsAndChild(newProjects, filter.child()); - LogicalFilter newFilter = new LogicalFilter<>(filter.getConjuncts(), newProject); - LogicalAggregate newAgg = agg.withChildren(ImmutableList.of(newFilter)); - return apply.withChildren(apply.left(), newAgg); - }).toRule(RuleType.PULL_UP_CORRELATED_FILTER_UNDER_APPLY_AGGREGATE_PROJECT); + // pull up filter under the project + LogicalProject> project = agg.child(); + LogicalFilter filter = project.child(); + List newProjects = Lists.newArrayList(); + newProjects.addAll(project.getProjects()); + + // filter may use all slots from its child, so add all the slots to newProjects + filter.child().getOutput().forEach(slot -> { + if (!newProjects.contains(slot)) { + newProjects.add(slot); + } + }); + + LogicalProject newProject = project.withProjectsAndChild(newProjects, filter.child()); + LogicalFilter newFilter = new LogicalFilter<>(filter.getConjuncts(), newProject); + LogicalAggregate newAgg = agg.withChildren(ImmutableList.of(newFilter)); + return (LogicalApply) (apply.withChildren(apply.left(), + isRightChildAgg ? newAgg : apply.right().withChildren(newAgg))); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyAggregateFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyAggregateFilter.java index b0b62f2e9b..211e76710c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyAggregateFilter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyAggregateFilter.java @@ -44,52 +44,69 @@ import java.util.Map; * the output column is the correlated column and the input column. *
  * before:
- *              apply
- *          /              \
- * Input(output:b)    agg(output:fn; group by:null)
+ *                 apply
+ *             /          \
+ *     Input(output:b)   Filter(this node's existence depends on having clause's existence)
+ *                              |
+ *                         agg(output:fn; group by:null)
  *                              |
  *              Filter(correlated predicate(Input.e = this.f)/Unapply predicate)
  *
  * end:
  *          apply(correlated predicate(Input.e = this.f))
  *         /              \
- * Input(output:b)    agg(output:fn,this.f; group by:this.f)
+ * Input(output:b)   Filter(this node's existence depends on having clause's existence)
+ *                             |
+ *                        agg(output:fn,this.f; group by:this.f)
  *                              |
  *                    Filter(Uncorrelated predicate)
  * 
*/ -public class UnCorrelatedApplyAggregateFilter extends OneRewriteRuleFactory { +public class UnCorrelatedApplyAggregateFilter implements RewriteRuleFactory { @Override - public Rule build() { - return logicalApply(any(), logicalAggregate(logicalFilter())).when(LogicalApply::isCorrelated).then(apply -> { - LogicalAggregate> agg = apply.right(); - LogicalFilter filter = agg.child(); - Map> split = Utils.splitCorrelatedConjuncts( - filter.getConjuncts(), apply.getCorrelationSlot()); - List correlatedPredicate = split.get(true); - List unCorrelatedPredicate = split.get(false); + public List buildRules() { + return ImmutableList.of( + logicalApply(any(), logicalAggregate(logicalFilter())) + .when(LogicalApply::isCorrelated) + .then(UnCorrelatedApplyAggregateFilter::pullUpCorrelatedFilter) + .toRule(RuleType.UN_CORRELATED_APPLY_AGGREGATE_FILTER), + logicalApply(any(), logicalFilter(logicalAggregate(logicalFilter()))) + .when(LogicalApply::isCorrelated) + .then(UnCorrelatedApplyAggregateFilter::pullUpCorrelatedFilter) + .toRule(RuleType.UN_CORRELATED_APPLY_FILTER_AGGREGATE_FILTER)); + } - // the representative has experienced the rule and added the correlated predicate to the apply node - if (correlatedPredicate.isEmpty()) { - return apply; - } + private static LogicalApply pullUpCorrelatedFilter(LogicalApply apply) { + boolean isRightChildAgg = apply.right() instanceof LogicalAggregate; + // locate agg node + LogicalAggregate> agg = + isRightChildAgg ? (LogicalAggregate>) (apply.right()) + : (LogicalAggregate>) (apply.right().child(0)); + LogicalFilter filter = agg.child(); + // split filter conjuncts to correlated and unCorrelated ones + Map> split = + Utils.splitCorrelatedConjuncts(filter.getConjuncts(), apply.getCorrelationSlot()); + List correlatedPredicate = split.get(true); + List unCorrelatedPredicate = split.get(false); - List newAggOutput = new ArrayList<>(agg.getOutputExpressions()); - List newGroupby = Utils.getCorrelatedSlots(correlatedPredicate, - apply.getCorrelationSlot()); - newGroupby.addAll(agg.getGroupByExpressions()); - newAggOutput.addAll(newGroupby.stream() - .map(NamedExpression.class::cast) - .collect(ImmutableList.toImmutableList())); - LogicalAggregate newAgg = new LogicalAggregate<>( - newGroupby, newAggOutput, - PlanUtils.filterOrSelf(ImmutableSet.copyOf(unCorrelatedPredicate), filter.child())); - return new LogicalApply<>(apply.getCorrelationSlot(), - apply.getSubqueryExpr(), - ExpressionUtils.optionalAnd(correlatedPredicate), - apply.getMarkJoinSlotReference(), - apply.isNeedAddSubOutputToProjects(), - apply.isInProject(), apply.isMarkJoinSlotNotNull(), apply.left(), newAgg); - }).toRule(RuleType.UN_CORRELATED_APPLY_AGGREGATE_FILTER); + // the representative has experienced the rule and added the correlated predicate to the apply node + if (correlatedPredicate.isEmpty()) { + return apply; + } + + // pull up correlated filter into apply node + List newAggOutput = new ArrayList<>(agg.getOutputExpressions()); + List newGroupby = + Utils.getCorrelatedSlots(correlatedPredicate, apply.getCorrelationSlot()); + newGroupby.addAll(agg.getGroupByExpressions()); + newAggOutput.addAll(newGroupby.stream().map(NamedExpression.class::cast) + .collect(ImmutableList.toImmutableList())); + LogicalAggregate newAgg = new LogicalAggregate<>(newGroupby, newAggOutput, + PlanUtils.filterOrSelf(ImmutableSet.copyOf(unCorrelatedPredicate), filter.child())); + return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(), + ExpressionUtils.optionalAnd(correlatedPredicate), apply.getMarkJoinSlotReference(), + apply.isNeedAddSubOutputToProjects(), apply.isInProject(), + apply.isMarkJoinSlotNotNull(), apply.left(), + isRightChildAgg ? newAgg : apply.right().withChildren(newAgg)); } } diff --git a/regression-test/data/nereids_p0/subquery/subquery_unnesting.out b/regression-test/data/nereids_p0/subquery/subquery_unnesting.out index e262f19b21..5124b5f970 100644 --- a/regression-test/data/nereids_p0/subquery/subquery_unnesting.out +++ b/regression-test/data/nereids_p0/subquery/subquery_unnesting.out @@ -1508,3 +1508,14 @@ -- !select62 -- +-- !select63 -- +1 \N +1 2 +1 3 +2 4 +2 5 +3 3 +3 4 +20 2 +22 3 +24 4 diff --git a/regression-test/suites/nereids_p0/subquery/subquery_unnesting.groovy b/regression-test/suites/nereids_p0/subquery/subquery_unnesting.groovy index 174a72df62..397c5cfbf2 100644 --- a/regression-test/suites/nereids_p0/subquery/subquery_unnesting.groovy +++ b/regression-test/suites/nereids_p0/subquery/subquery_unnesting.groovy @@ -132,4 +132,5 @@ suite ("subquery_unnesting") { qt_select60 """select * from t1 where exists(select distinct k1 from t2 where t1.k1 > t2.k3 or t1.k2 < t2.v1) order by t1.k1, t1.k2;""" qt_select61 """SELECT * FROM t1 AS t1 WHERE EXISTS (SELECT k1 FROM t1 AS t2 WHERE t1.k1 <> t2.k1 + 7 GROUP BY k1 HAVING k1 >= 100);""" qt_select62 """select * from t1 left semi join ( select * from t1 where t1.k1 < -1 ) l on true;""" + qt_select63 """SELECT * FROM t1 AS t1 WHERE EXISTS (SELECT k1 FROM t1 AS t2 WHERE t1.k1 <> t2.k1 + 7 GROUP BY k1 HAVING sum(k2) >= 1) order by t1.k1, t1.k2;""" }