diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index ab79b19de1..028a0428d4 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -99,6 +99,7 @@ import org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownMinMaxThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownSumThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownTopNThroughJoin;
+import org.apache.doris.nereids.rules.rewrite.PushDownTopNThroughUnion;
import org.apache.doris.nereids.rules.rewrite.PushDownTopNThroughWindow;
import org.apache.doris.nereids.rules.rewrite.PushFilterInsideJoin;
import org.apache.doris.nereids.rules.rewrite.PushProjectIntoOneRowRelation;
@@ -296,7 +297,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
new PushDownLimit(),
new PushDownTopNThroughJoin(),
new PushDownLimitDistinctThroughJoin(),
- new PushDownTopNThroughWindow()
+ new PushDownTopNThroughWindow(),
+ new PushDownTopNThroughUnion()
),
topDown(new CreatePartitionTopNFromWindow()),
topDown(
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index e91b91360e..7afc0123aa 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -256,10 +256,11 @@ public enum RuleType {
PUSH_LIMIT_THROUGH_WINDOW(RuleTypeClass.REWRITE),
LIMIT_SORT_TO_TOP_N(RuleTypeClass.REWRITE),
// topN push down
- PUSH_TOP_N_THROUGH_JOIN(RuleTypeClass.REWRITE),
- PUSH_TOP_N_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
- PUSH_TOP_N_THROUGH_PROJECT_WINDOW(RuleTypeClass.REWRITE),
- PUSH_TOP_N_THROUGH_WINDOW(RuleTypeClass.REWRITE),
+ PUSH_DOWN_TOP_N_THROUGH_JOIN(RuleTypeClass.REWRITE),
+ PUSH_DOWN_TOP_N_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
+ PUSH_DOWN_TOP_N_THROUGH_PROJECT_WINDOW(RuleTypeClass.REWRITE),
+ PUSH_DOWN_TOP_N_THROUGH_WINDOW(RuleTypeClass.REWRITE),
+ PUSH_DOWN_TOP_N_THROUGH_UNION(RuleTypeClass.REWRITE),
// limit distinct push down
PUSH_LIMIT_DISTINCT_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSH_LIMIT_DISTINCT_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNThroughJoin.java
index de4d5c9725..28a7f2688b 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNThroughJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNThroughJoin.java
@@ -53,7 +53,7 @@ public class PushDownTopNThroughJoin implements RewriteRuleFactory {
}
return topN.withChildren(newJoin);
})
- .toRule(RuleType.PUSH_TOP_N_THROUGH_JOIN),
+ .toRule(RuleType.PUSH_DOWN_TOP_N_THROUGH_JOIN),
// topN -> project -> join
logicalTopN(logicalProject(logicalJoin()))
@@ -79,7 +79,7 @@ public class PushDownTopNThroughJoin implements RewriteRuleFactory {
return null;
}
return topN.withChildren(project.withChildren(newJoin));
- }).toRule(RuleType.PUSH_TOP_N_THROUGH_PROJECT_JOIN)
+ }).toRule(RuleType.PUSH_DOWN_TOP_N_THROUGH_PROJECT_JOIN)
);
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNThroughUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNThroughUnion.java
new file mode 100644
index 0000000000..c13c1143ef
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNThroughUnion.java
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.properties.OrderKey;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ *
+ * TopN
+ * -> Union All
+ * -> child plan1
+ * -> child plan2
+ * -> child plan3
+ *
+ * rewritten to
+ *
+ * -> Union All
+ * -> TopN
+ * -> child plan1
+ * -> TopN
+ * -> child plan2
+ * -> TopN
+ * -> child plan3
+ *
+ */
+public class PushDownTopNThroughUnion implements RewriteRuleFactory {
+
+ @Override
+ public List buildRules() {
+ return ImmutableList.of(
+ logicalTopN(logicalUnion().when(union -> union.getQualifier() == Qualifier.ALL))
+ .then(topN -> {
+ LogicalUnion union = topN.child();
+ List newChildren = new ArrayList<>();
+ for (Plan child : union.children()) {
+ Map replaceMap = new HashMap<>();
+ for (int i = 0; i < union.getOutputs().size(); ++i) {
+ NamedExpression output = union.getOutputs().get(i);
+ replaceMap.put(output, child.getOutput().get(i));
+ }
+
+ List orderKeys = topN.getOrderKeys().stream()
+ .map(orderKey -> orderKey.withExpression(
+ ExpressionUtils.replace(orderKey.getExpr(), replaceMap)))
+ .collect(ImmutableList.toImmutableList());
+ newChildren.add(
+ new LogicalTopN<>(orderKeys, topN.getLimit() + topN.getOffset(), 0, child));
+ }
+ if (union.children().equals(newChildren)) {
+ return null;
+ }
+ return topN.withChildren(union.withChildren(newChildren));
+ })
+ .toRule(RuleType.PUSH_DOWN_TOP_N_THROUGH_UNION)
+ );
+ }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNThroughWindow.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNThroughWindow.java
index 11dd2b7959..7a0eb06887 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNThroughWindow.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNThroughWindow.java
@@ -59,7 +59,7 @@ public class PushDownTopNThroughWindow implements RewriteRuleFactory {
return topn;
}
return topn.withChildren(newWindow.get());
- }).toRule(RuleType.PUSH_TOP_N_THROUGH_WINDOW),
+ }).toRule(RuleType.PUSH_DOWN_TOP_N_THROUGH_WINDOW),
// topn -> projection -> window
logicalTopN(logicalProject(logicalWindow())).then(topn -> {
@@ -79,7 +79,7 @@ public class PushDownTopNThroughWindow implements RewriteRuleFactory {
return topn;
}
return topn.withChildren(project.withChildren(newWindow.get()));
- }).toRule(RuleType.PUSH_TOP_N_THROUGH_PROJECT_WINDOW)
+ }).toRule(RuleType.PUSH_DOWN_TOP_N_THROUGH_PROJECT_WINDOW)
);
}
diff --git a/regression-test/data/nereids_rules_p0/push_down_topn/push_down_topn_through_union.out b/regression-test/data/nereids_rules_p0/push_down_topn/push_down_topn_through_union.out
new file mode 100644
index 0000000000..9c20ca80d5
--- /dev/null
+++ b/regression-test/data/nereids_rules_p0/push_down_topn/push_down_topn_through_union.out
@@ -0,0 +1,194 @@
+-- This file is automatically generated. You should know what you did if you want to edit this
+-- !push_down_topn_through_union --
+PhysicalResultSink
+--PhysicalTopN[MERGE_SORT]
+----PhysicalDistribute
+------PhysicalTopN[LOCAL_SORT]
+--------PhysicalUnion
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------PhysicalOlapScan[t]
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------PhysicalOlapScan[t]
+
+-- !push_down_topn_union_with_conditions --
+PhysicalResultSink
+--PhysicalTopN[MERGE_SORT]
+----PhysicalDistribute
+------PhysicalTopN[LOCAL_SORT]
+--------PhysicalUnion
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------filter((t1.score > 10))
+--------------------PhysicalOlapScan[t]
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------filter((t2.name = 'Test'))
+--------------------PhysicalOlapScan[t]
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------filter((t3.id < 5))
+--------------------PhysicalOlapScan[t]
+
+-- !push_down_topn_union_with_order_by --
+PhysicalResultSink
+--PhysicalTopN[MERGE_SORT]
+----PhysicalDistribute
+------PhysicalTopN[LOCAL_SORT]
+--------PhysicalUnion
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------PhysicalOlapScan[t]
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------PhysicalOlapScan[t]
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------PhysicalOlapScan[t]
+
+-- !push_down_topn_nested_union --
+PhysicalResultSink
+--PhysicalTopN[MERGE_SORT]
+----PhysicalDistribute
+------PhysicalTopN[LOCAL_SORT]
+--------PhysicalUnion
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------PhysicalOlapScan[t]
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------PhysicalOlapScan[t]
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------PhysicalOlapScan[t]
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------PhysicalOlapScan[t]
+
+-- !push_down_topn_union_after_join --
+PhysicalResultSink
+--PhysicalTopN[MERGE_SORT]
+----PhysicalDistribute
+------PhysicalTopN[LOCAL_SORT]
+--------PhysicalUnion
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------PhysicalProject
+--------------------hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id)) otherCondition=()
+----------------------PhysicalProject
+------------------------PhysicalOlapScan[t] apply RFs: RF0
+----------------------PhysicalProject
+------------------------PhysicalOlapScan[t]
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------PhysicalProject
+--------------------PhysicalOlapScan[t]
+
+-- !push_down_topn_union_different_projections --
+PhysicalResultSink
+--PhysicalTopN[MERGE_SORT]
+----PhysicalDistribute
+------PhysicalTopN[LOCAL_SORT]
+--------PhysicalUnion
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------PhysicalProject
+--------------------PhysicalOlapScan[t]
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------PhysicalProject
+--------------------PhysicalOlapScan[t]
+
+-- !push_down_topn_union_with_subquery --
+PhysicalResultSink
+--PhysicalTopN[MERGE_SORT]
+----PhysicalDistribute
+------PhysicalTopN[LOCAL_SORT]
+--------PhysicalUnion
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------PhysicalProject
+--------------------filter((t.score > 20))
+----------------------PhysicalOlapScan[t]
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------PhysicalProject
+--------------------PhysicalOlapScan[t]
+
+-- !push_down_topn_union_with_limit --
+PhysicalResultSink
+--PhysicalTopN[MERGE_SORT]
+----PhysicalDistribute
+------PhysicalTopN[LOCAL_SORT]
+--------PhysicalUnion
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalTopN[LOCAL_SORT]
+----------------PhysicalLimit[GLOBAL]
+------------------PhysicalDistribute
+--------------------PhysicalLimit[LOCAL]
+----------------------PhysicalOlapScan[t]
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalTopN[LOCAL_SORT]
+----------------PhysicalLimit[GLOBAL]
+------------------PhysicalDistribute
+--------------------PhysicalLimit[LOCAL]
+----------------------PhysicalOlapScan[t]
+
+-- !push_down_topn_union_complex_conditions --
+PhysicalResultSink
+--PhysicalTopN[MERGE_SORT]
+----PhysicalDistribute
+------PhysicalTopN[LOCAL_SORT]
+--------PhysicalUnion
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------filter((t1.name = 'Test') and (t1.score > 10))
+--------------------PhysicalOlapScan[t]
+----------PhysicalDistribute
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalDistribute
+----------------PhysicalTopN[LOCAL_SORT]
+------------------filter((t2.id < 5) and (t2.score < 20))
+--------------------PhysicalOlapScan[t]
+
diff --git a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query23.out b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query23.out
index 4a9c85f1b2..d9af7d72f2 100644
--- a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query23.out
+++ b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query23.out
@@ -158,54 +158,62 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
--------PhysicalDistribute
----------PhysicalTopN[LOCAL_SORT]
------------PhysicalUnion
---------------hashAgg[GLOBAL]
-----------------PhysicalDistribute
-------------------hashAgg[LOCAL]
---------------------PhysicalProject
-----------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((catalog_sales.cs_item_sk = frequent_ss_items.item_sk)) otherCondition=()
+--------------PhysicalDistribute
+----------------PhysicalTopN[MERGE_SORT]
+------------------PhysicalDistribute
+--------------------PhysicalTopN[LOCAL_SORT]
+----------------------hashAgg[GLOBAL]
------------------------PhysicalDistribute
---------------------------PhysicalProject
-----------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
-------------------------PhysicalDistribute
---------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF7 c_customer_sk->[cs_bill_customer_sk]
-----------------------------hashJoin[LEFT_SEMI_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = best_ss_customer.c_customer_sk)) otherCondition=() build RFs:RF6 c_customer_sk->[cs_bill_customer_sk]
-------------------------------PhysicalDistribute
---------------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF5 d_date_sk->[cs_sold_date_sk]
+--------------------------hashAgg[LOCAL]
+----------------------------PhysicalProject
+------------------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((catalog_sales.cs_item_sk = frequent_ss_items.item_sk)) otherCondition=()
+--------------------------------PhysicalDistribute
----------------------------------PhysicalProject
-------------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF5 RF6 RF7
-----------------------------------PhysicalDistribute
-------------------------------------PhysicalProject
---------------------------------------filter((date_dim.d_moy = 7) and (date_dim.d_year = 2000))
-----------------------------------------PhysicalOlapScan[date_dim]
-------------------------------PhysicalDistribute
---------------------------------PhysicalProject
-----------------------------------PhysicalCteConsumer ( cteId=CTEId#2 )
-----------------------------PhysicalDistribute
-------------------------------PhysicalProject
---------------------------------PhysicalOlapScan[customer]
---------------hashAgg[GLOBAL]
-----------------PhysicalDistribute
-------------------hashAgg[LOCAL]
---------------------PhysicalProject
-----------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((web_sales.ws_item_sk = frequent_ss_items.item_sk)) otherCondition=()
+------------------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
+--------------------------------PhysicalDistribute
+----------------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF7 c_customer_sk->[cs_bill_customer_sk]
+------------------------------------hashJoin[LEFT_SEMI_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = best_ss_customer.c_customer_sk)) otherCondition=() build RFs:RF6 c_customer_sk->[cs_bill_customer_sk]
+--------------------------------------PhysicalDistribute
+----------------------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF5 d_date_sk->[cs_sold_date_sk]
+------------------------------------------PhysicalProject
+--------------------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF5 RF6 RF7
+------------------------------------------PhysicalDistribute
+--------------------------------------------PhysicalProject
+----------------------------------------------filter((date_dim.d_moy = 7) and (date_dim.d_year = 2000))
+------------------------------------------------PhysicalOlapScan[date_dim]
+--------------------------------------PhysicalDistribute
+----------------------------------------PhysicalProject
+------------------------------------------PhysicalCteConsumer ( cteId=CTEId#2 )
+------------------------------------PhysicalDistribute
+--------------------------------------PhysicalProject
+----------------------------------------PhysicalOlapScan[customer]
+--------------PhysicalDistribute
+----------------PhysicalTopN[MERGE_SORT]
+------------------PhysicalDistribute
+--------------------PhysicalTopN[LOCAL_SORT]
+----------------------hashAgg[GLOBAL]
------------------------PhysicalDistribute
---------------------------PhysicalProject
-----------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
-------------------------PhysicalDistribute
---------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF10 ws_bill_customer_sk->[c_customer_sk]
-----------------------------PhysicalDistribute
-------------------------------PhysicalProject
---------------------------------PhysicalOlapScan[customer] apply RFs: RF10
-----------------------------hashJoin[LEFT_SEMI_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = best_ss_customer.c_customer_sk)) otherCondition=() build RFs:RF9 c_customer_sk->[ws_bill_customer_sk]
-------------------------------PhysicalDistribute
---------------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF8 d_date_sk->[ws_sold_date_sk]
+--------------------------hashAgg[LOCAL]
+----------------------------PhysicalProject
+------------------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((web_sales.ws_item_sk = frequent_ss_items.item_sk)) otherCondition=()
+--------------------------------PhysicalDistribute
----------------------------------PhysicalProject
-------------------------------------PhysicalOlapScan[web_sales] apply RFs: RF8 RF9
-----------------------------------PhysicalDistribute
-------------------------------------PhysicalProject
---------------------------------------filter((date_dim.d_moy = 7) and (date_dim.d_year = 2000))
-----------------------------------------PhysicalOlapScan[date_dim]
-------------------------------PhysicalDistribute
---------------------------------PhysicalProject
-----------------------------------PhysicalCteConsumer ( cteId=CTEId#2 )
+------------------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
+--------------------------------PhysicalDistribute
+----------------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF10 ws_bill_customer_sk->[c_customer_sk]
+------------------------------------PhysicalDistribute
+--------------------------------------PhysicalProject
+----------------------------------------PhysicalOlapScan[customer] apply RFs: RF10
+------------------------------------hashJoin[LEFT_SEMI_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = best_ss_customer.c_customer_sk)) otherCondition=() build RFs:RF9 c_customer_sk->[ws_bill_customer_sk]
+--------------------------------------PhysicalDistribute
+----------------------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF8 d_date_sk->[ws_sold_date_sk]
+------------------------------------------PhysicalProject
+--------------------------------------------PhysicalOlapScan[web_sales] apply RFs: RF8 RF9
+------------------------------------------PhysicalDistribute
+--------------------------------------------PhysicalProject
+----------------------------------------------filter((date_dim.d_moy = 7) and (date_dim.d_year = 2000))
+------------------------------------------------PhysicalOlapScan[date_dim]
+--------------------------------------PhysicalDistribute
+----------------------------------------PhysicalProject
+------------------------------------------PhysicalCteConsumer ( cteId=CTEId#2 )
diff --git a/regression-test/suites/nereids_rules_p0/push_down_topn/push_down_topn_through_union.groovy b/regression-test/suites/nereids_rules_p0/push_down_topn/push_down_topn_through_union.groovy
new file mode 100644
index 0000000000..afc26a51cf
--- /dev/null
+++ b/regression-test/suites/nereids_rules_p0/push_down_topn/push_down_topn_through_union.groovy
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("push_down_topn_through_union") {
+ sql "SET enable_nereids_planner=true"
+ sql "SET enable_fallback_to_original_planner=false"
+
+ sql """
+ DROP TABLE IF EXISTS t1;
+ """
+ sql """
+ DROP TABLE IF EXISTS t2;
+ """
+ sql """
+ DROP TABLE IF EXISTS t3;
+ """
+ sql """
+ DROP TABLE IF EXISTS t4;
+ """
+
+ sql """
+ CREATE TABLE IF NOT EXISTS t(
+ `id` int(32) NULL,
+ `score` int(64) NULL,
+ `name` varchar(64) NULL
+ ) ENGINE = OLAP
+ DISTRIBUTED BY HASH(id) BUCKETS 4
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1"
+ );
+ """
+
+ qt_push_down_topn_through_union """
+ explain shape plan select * from (select * from t t1 union all select * from t t2) t order by id limit 10;
+ """
+
+ qt_push_down_topn_union_with_conditions """
+ explain shape plan select * from (select * from t t1 where t1.score > 10 union all select * from t t2 where t2.name = 'Test' union all select * from t t3 where t3.id < 5) sub order by id limit 10;
+ """
+
+ qt_push_down_topn_union_with_order_by """
+ explain shape plan select * from (select * from t t1 union all select * from t t2 union all select * from t t3 order by score) sub order by id limit 10;
+ """
+
+ qt_push_down_topn_nested_union """
+ explain shape plan select * from ((select * from t t1 union all select * from t t2) union all (select * from t t3 union all select * from t t4)) sub order by id limit 10;
+ """
+
+ qt_push_down_topn_union_after_join """
+ explain shape plan select * from (select t1.id from t t1 join t t2 on t1.id = t2.id union all select id from t t3) sub order by id limit 10;
+ """
+
+ qt_push_down_topn_union_different_projections """
+ explain shape plan select * from (select id from t t1 union all select name from t t2) sub order by id limit 10;
+ """
+
+ qt_push_down_topn_union_with_subquery """
+ explain shape plan select * from (select id from (select * from t where score > 20) t1 union all select id from t t2) sub order by id limit 10;
+ """
+
+ qt_push_down_topn_union_with_limit """
+ explain shape plan select * from (select * from t t1 limit 5 union all select * from t t2 limit 5) sub order by id limit 10;
+ """
+
+ qt_push_down_topn_union_complex_conditions """
+ explain shape plan select * from (select * from t t1 where t1.score > 10 and t1.name = 'Test' union all select * from t t2 where t2.id < 5 and t2.score < 20) sub order by id limit 10;
+ """
+}
\ No newline at end of file