diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughPartitionTopN.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughPartitionTopN.java index 4cb8076f7a..312c2bdac3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughPartitionTopN.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughPartitionTopN.java @@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet.Builder; +import java.util.HashSet; import java.util.Set; /** @@ -62,19 +63,17 @@ public class PushDownFilterThroughPartitionTopN extends OneRewriteRuleFactory { // PushdownFilterThroughWindow Builder bottomConjunctsBuilder = ImmutableSet.builder(); Builder upperConjunctsBuilder = ImmutableSet.builder(); - for (Expression expr : filter.getConjuncts()) { - boolean pushed = false; - Set exprInputSlots = expr.getInputSlots(); - for (Expression partitionKey : partitionTopN.getPartitionKeys()) { - if (partitionKey instanceof SlotReference - && exprInputSlots.size() == 1 - && partitionKey.getInputSlots().containsAll(exprInputSlots)) { - bottomConjunctsBuilder.add(expr); - pushed = true; - break; - } + Set partitionKeySlots = new HashSet<>(); + for (Expression partitionKey : partitionTopN.getPartitionKeys()) { + if (partitionKey instanceof SlotReference) { + partitionKeySlots.add((SlotReference) partitionKey); } - if (!pushed) { + } + for (Expression expr : filter.getConjuncts()) { + Set exprInputSlots = expr.getInputSlots(); + if (partitionKeySlots.containsAll(exprInputSlots)) { + bottomConjunctsBuilder.add(expr); + } else { upperConjunctsBuilder.add(expr); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughWindow.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughWindow.java index 0696bdd95c..a949cbc945 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughWindow.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughWindow.java @@ -64,18 +64,9 @@ public class PushDownFilterThroughWindow extends OneRewriteRuleFactory { Set bottomConjuncts = Sets.newHashSet(); Set upperConjuncts = Sets.newHashSet(); for (Expression expr : filter.getConjuncts()) { - boolean pushed = false; - for (Expression partitionKey : commonPartitionKeys) { - // partitionKey is a single slot reference, - // we want to push expressions which have only one input slot, and the input slot is used as - // partition key in all windowExpressions. - if (partitionKey.getInputSlots().containsAll(expr.getInputSlots())) { - bottomConjuncts.add(expr); - pushed = true; - break; - } - } - if (!pushed) { + if (commonPartitionKeys.containsAll(expr.getInputSlots())) { + bottomConjuncts.add(expr); + } else { upperConjuncts.add(expr); } } diff --git a/regression-test/data/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.out b/regression-test/data/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.out new file mode 100644 index 0000000000..9ecb96e1fc --- /dev/null +++ b/regression-test/data/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.out @@ -0,0 +1,31 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !multi_column_predicate_push_down_window_shape -- +PhysicalResultSink +--filter((num <= 2)) +----PhysicalWindow +------PhysicalQuickSort[LOCAL_SORT] +--------PhysicalPartitionTopN +----------PhysicalProject +------------filter((abs((id + value1)) < 4)) +--------------PhysicalOlapScan[push_down_multi_column_predicate_through_window_t] + +-- !multi_column_predicate_push_down_window -- +1 1 10 +1 2 20 + +-- !multi_column_or_predicate_push_down_window_shape -- +PhysicalResultSink +--filter((rc < 2)) +----PhysicalWindow +------PhysicalQuickSort[LOCAL_SORT] +--------PhysicalPartitionTopN +----------PhysicalProject +------------filter(((t.id > 1) OR (t.value1 > 2))) +--------------PhysicalOlapScan[push_down_multi_column_predicate_through_window_t] + +-- !multi_column_or_predicate_push_down_window -- +1 10 1 +2 20 1 +3 30 1 +4 40 1 + diff --git a/regression-test/suites/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.groovy b/regression-test/suites/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.groovy new file mode 100644 index 0000000000..75464a428e --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.groovy @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("push_down_filter_through_window") { + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + sql "set ignore_shape_nodes='PhysicalDistribute'" + sql "drop table if exists push_down_multi_column_predicate_through_window_t" + multi_sql """ + CREATE TABLE push_down_multi_column_predicate_through_window_t (id INT, value1 INT, value2 VARCHAR(50)) + properties("replication_num"="1"); + INSERT INTO push_down_multi_column_predicate_through_window_t (id, value1, value2) VALUES(1, 10, 'A'),(2, 20, 'B'),(3, 30, 'C'),(4, 40, 'D'); + """ + qt_multi_column_predicate_push_down_window_shape """ + explain shape plan + select * from (select row_number() over(partition by id,value1 order by value1) as num, id, value1 from push_down_multi_column_predicate_through_window_t ) t + where abs(id+value1)<4 and num<=2; + """ + qt_multi_column_predicate_push_down_window """ + select * from (select row_number() over(partition by id,value1 order by value1) as num, id, value1 from push_down_multi_column_predicate_through_window_t ) t + where abs(id+value1)<30 and num<=2 order by id,value1,num; + """ + qt_multi_column_or_predicate_push_down_window_shape """ + explain shape plan + select * from (select id,value1, row_number() over(partition by id,value1 order by value1) rc from push_down_multi_column_predicate_through_window_t ) t where (id>1 or value1>2) and rc<2; + """ + qt_multi_column_or_predicate_push_down_window """ + select * from (select id,value1, row_number() over(partition by id,value1 order by value1) rc from push_down_multi_column_predicate_through_window_t ) t where (id>1 or value1>2) and rc<2 order by 1,2 ; + """ +} \ No newline at end of file