[feature](Nereids): Pushdown LimitDistinct Through Join (#25113)

Push down limit-distinct through left/right outer join or cross join.

such as: select distinct t1.c1 from t1 left join t2 on t1.c1 = t2.c1 limit 1;
This commit is contained in:
jakevin
2023-10-09 14:19:22 +08:00
committed by GitHub
parent 5a55e47acd
commit b41ec6a8a4
10 changed files with 447 additions and 21 deletions

View File

@ -91,6 +91,7 @@ import org.apache.doris.nereids.rules.rewrite.PushProjectIntoOneRowRelation;
import org.apache.doris.nereids.rules.rewrite.PushProjectThroughUnion;
import org.apache.doris.nereids.rules.rewrite.PushdownFilterThroughProject;
import org.apache.doris.nereids.rules.rewrite.PushdownLimit;
import org.apache.doris.nereids.rules.rewrite.PushdownLimitDistinctThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushdownTopNThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushdownTopNThroughWindow;
import org.apache.doris.nereids.rules.rewrite.ReorderJoin;
@ -280,6 +281,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
new SplitLimit(),
new PushdownLimit(),
new PushdownTopNThroughJoin(),
new PushdownLimitDistinctThroughJoin(),
new PushdownTopNThroughWindow(),
new CreatePartitionTopNFromWindow()
)

View File

@ -252,6 +252,9 @@ public enum RuleType {
PUSH_TOP_N_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
PUSH_TOP_N_THROUGH_PROJECT_WINDOW(RuleTypeClass.REWRITE),
PUSH_TOP_N_THROUGH_WINDOW(RuleTypeClass.REWRITE),
// limit distinct push down
PUSH_LIMIT_DISTINCT_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSH_LIMIT_DISTINCT_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
// adjust nullable
ADJUST_NULLABLE(RuleTypeClass.REWRITE),
ADJUST_CONJUNCTS_RETURN_TYPE(RuleTypeClass.REWRITE),

View File

@ -0,0 +1,109 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.stream.Collectors;
/**
* Same with PushdownLimit
*/
public class PushdownLimitDistinctThroughJoin implements RewriteRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
// limit -> distinct -> join
logicalLimit(logicalAggregate(logicalJoin())
.when(LogicalAggregate::isDistinct))
.then(limit -> {
LogicalAggregate<LogicalJoin<Plan, Plan>> agg = limit.child();
LogicalJoin<Plan, Plan> join = agg.child();
Plan newJoin = pushLimitThroughJoin(limit, join);
if (newJoin == null || join.children().equals(newJoin.children())) {
return null;
}
return limit.withChildren(agg.withChildren(newJoin));
})
.toRule(RuleType.PUSH_LIMIT_DISTINCT_THROUGH_JOIN),
// limit -> distinct -> project -> join
logicalLimit(logicalAggregate(logicalProject(logicalJoin()).when(LogicalProject::isAllSlots))
.when(LogicalAggregate::isDistinct))
.then(limit -> {
LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = limit.child();
LogicalProject<LogicalJoin<Plan, Plan>> project = agg.child();
LogicalJoin<Plan, Plan> join = project.child();
Plan newJoin = pushLimitThroughJoin(limit, join);
if (newJoin == null || join.children().equals(newJoin.children())) {
return null;
}
return limit.withChildren(agg.withChildren(project.withChildren(newJoin)));
}).toRule(RuleType.PUSH_LIMIT_DISTINCT_THROUGH_JOIN)
);
}
private Plan pushLimitThroughJoin(LogicalLimit<?> limit, LogicalJoin<Plan, Plan> join) {
LogicalAggregate<?> agg = (LogicalAggregate<?>) limit.child();
List<Slot> groupBySlots = agg.getGroupByExpressions().stream()
.flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toList());
switch (join.getJoinType()) {
case LEFT_OUTER_JOIN:
if (join.left().getOutputSet().containsAll(groupBySlots)
&& join.left().getOutputSet().equals(agg.getOutputSet())) {
return join.withChildren(limit.withLimitChild(limit.getLimit() + limit.getOffset(), 0,
agg.withChildren(join.left())), join.right());
}
return null;
case RIGHT_OUTER_JOIN:
if (join.right().getOutputSet().containsAll(groupBySlots)
&& join.right().getOutputSet().equals(agg.getOutputSet())) {
return join.withChildren(join.left(), limit.withLimitChild(limit.getLimit() + limit.getOffset(), 0,
agg.withChildren(join.right())));
}
return null;
case CROSS_JOIN:
if (join.left().getOutputSet().containsAll(groupBySlots)
&& join.left().getOutputSet().equals(agg.getOutputSet())) {
return join.withChildren(limit.withLimitChild(limit.getLimit() + limit.getOffset(), 0,
agg.withChildren(join.left())), join.right());
} else if (join.right().getOutputSet().containsAll(groupBySlots)
&& join.right().getOutputSet().equals(agg.getOutputSet())) {
return join.withChildren(join.left(), limit.withLimitChild(limit.getLimit() + limit.getOffset(), 0,
agg.withChildren(join.right())));
} else {
return null;
}
default:
return null;
}
}
}

View File

@ -25,12 +25,10 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/**
@ -75,28 +73,33 @@ public class PushdownTopNThroughJoin implements RewriteRuleFactory {
}
private Plan pushLimitThroughJoin(LogicalTopN<? extends Plan> topN, LogicalJoin<Plan, Plan> join) {
List<Slot> orderbySlots = topN.getOrderKeys().stream().map(OrderKey::getExpr)
.flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toList());
switch (join.getJoinType()) {
case LEFT_OUTER_JOIN:
Set<Slot> rightOutputSet = join.right().getOutputSet();
if (topN.getOrderKeys().stream().map(OrderKey::getExpr)
.anyMatch(e -> Utils.isIntersecting(rightOutputSet, e.getInputSlots()))) {
return null;
}
return join.withChildren(topN.withChildren(join.left()), join.right());
case RIGHT_OUTER_JOIN:
Set<Slot> leftOutputSet = join.left().getOutputSet();
if (topN.getOrderKeys().stream().map(OrderKey::getExpr)
.anyMatch(e -> Utils.isIntersecting(leftOutputSet, e.getInputSlots()))) {
return null;
}
return join.withChildren(join.left(), topN.withChildren(join.right()));
case CROSS_JOIN:
List<Slot> orderbySlots = topN.getOrderKeys().stream().map(OrderKey::getExpr)
.flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toList());
if (join.left().getOutputSet().containsAll(orderbySlots)) {
return join.withChildren(topN.withChildren(join.left()), join.right());
return join.withChildren(
topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0, join.left()),
join.right());
}
return null;
case RIGHT_OUTER_JOIN:
if (join.right().getOutputSet().containsAll(orderbySlots)) {
return join.withChildren(
join.left(),
topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0, join.right()));
}
return null;
case CROSS_JOIN:
if (join.left().getOutputSet().containsAll(orderbySlots)) {
return join.withChildren(
topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0, join.left()),
join.right());
} else if (join.right().getOutputSet().containsAll(orderbySlots)) {
return join.withChildren(join.left(), topN.withChildren(join.right()));
return join.withChildren(
join.left(),
topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0, join.right()));
} else {
return null;
}

View File

@ -161,7 +161,8 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
}
public boolean isDistinct() {
return outputExpressions.equals(groupByExpressions);
return outputExpressions.stream().allMatch(e -> e instanceof Slot)
&& groupByExpressions.stream().allMatch(e -> e instanceof Slot);
}
public boolean isGenerated() {

View File

@ -117,6 +117,12 @@ public class LogicalLimit<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TY
return ImmutableList.of();
}
/**
 * Returns a copy of this limit with new limit/offset values and a new child,
 * preserving the current limit phase.
 *
 * @param limit  the new row limit
 * @param offset the new row offset
 * @param child  the new child plan
 * @return a new {@link LogicalLimit} over {@code child}
 */
public LogicalLimit<Plan> withLimitChild(long limit, long offset, Plan child) {
    // Fixed copy-paste from LogicalTopN: the message must name LogicalLimit.
    Preconditions.checkArgument(children.size() == 1,
            "LogicalLimit should have 1 child, but input is %s", children.size());
    return new LogicalLimit<>(limit, offset, phase, child);
}
@Override
public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalLimit<>(limit, offset, phase, groupExpression, Optional.of(getLogicalProperties()), child());

View File

@ -122,6 +122,13 @@ public class LogicalTopN<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TYP
Optional.empty(), Optional.of(getLogicalProperties()), child());
}
/**
 * Returns a copy of this top-N with new limit/offset values and a new child,
 * keeping the original order keys.
 *
 * @param limit  the new row limit
 * @param offset the new row offset
 * @param child  the new child plan
 * @return a new {@link LogicalTopN} over {@code child}
 */
public LogicalTopN<Plan> withLimitChild(long limit, long offset, Plan child) {
    Preconditions.checkArgument(children.size() == 1,
            "LogicalTopN should have 1 child, but input is %s", children.size());
    // Group expression is dropped (Optional.empty()) because the rewritten
    // node is a fresh plan, while the logical properties are carried over.
    return new LogicalTopN<>(orderKeys, limit, offset,
            Optional.empty(), Optional.of(getLogicalProperties()), child);
}
@Override
public LogicalTopN<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1,

View File

@ -0,0 +1,167 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
/**
 * Tests for {@code PushdownLimitDistinctThroughJoin}: the limit + distinct
 * aggregate must be pushed below an outer/cross join only when the aggregate
 * reads from exactly one join side, and left untouched otherwise.
 */
class PushdownLimitDistinctThroughJoinTest extends TestWithFeService implements MemoPatternMatchSupported {
    // Two synthetic scans; in a scan1-join-scan2 plan, slot indexes 0..1
    // belong to t1 and 2..3 to t2 (used by distinct(...) below).
    private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
    private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);

    @Override
    protected void runBeforeAll() throws Exception {
        // Real tables are needed only for the SQL-based test (testJoinSql),
        // which runs the full analyze + rewrite pipeline.
        createDatabase("test");
        connectContext.setDatabase("default_cluster:test");
        createTable("CREATE TABLE `t1` (\n"
                + " `k1` int(11) NULL,\n"
                + " `k2` int(11) NULL\n"
                + ") ENGINE=OLAP\n"
                + "COMMENT 'OLAP'\n"
                + "DISTRIBUTED BY HASH(`k1`) BUCKETS 3\n"
                + "PROPERTIES (\n"
                + "\"replication_allocation\" = \"tag.location.default: 1\",\n"
                + "\"in_memory\" = \"false\",\n"
                + "\"storage_format\" = \"V2\",\n"
                + "\"disable_auto_compaction\" = \"false\"\n"
                + ");");
        createTable("CREATE TABLE `t2` (\n"
                + " `k1` int(11) NULL,\n"
                + " `k2` int(11) NULL\n"
                + ") ENGINE=OLAP\n"
                + "COMMENT 'OLAP'\n"
                + "DISTRIBUTED BY HASH(`k1`) BUCKETS 3\n"
                + "PROPERTIES (\n"
                + "\"replication_allocation\" = \"tag.location.default: 1\",\n"
                + "\"in_memory\" = \"false\",\n"
                + "\"storage_format\" = \"V2\",\n"
                + "\"disable_auto_compaction\" = \"false\"\n"
                + ");");
    }

    @Test
    void testJoin() {
        // LEFT OUTER, distinct on left-side slots (0,1): push to the left child.
        LogicalPlan plan = new LogicalPlanBuilder(scan1)
                .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0))
                .distinct(ImmutableList.of(0, 1))
                .limit(10)
                .build();
        PlanChecker.from(connectContext, plan)
                .applyTopDown(new PushdownLimitDistinctThroughJoin())
                .matches(
                        logicalJoin(
                                logicalLimit(logicalAggregate(logicalOlapScan())).when(l -> l.getLimit() == 10),
                                logicalOlapScan()
                        )
                );

        // RIGHT OUTER, distinct on right-side slots (2,3): push to the right child.
        plan = new LogicalPlanBuilder(scan1)
                .join(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0))
                .distinct(ImmutableList.of(2, 3))
                .limit(10)
                .build();
        PlanChecker.from(connectContext, plan)
                .applyTopDown(new PushdownLimitDistinctThroughJoin())
                .matches(
                        logicalJoin(
                                logicalOlapScan(),
                                logicalLimit(logicalAggregate(logicalOlapScan())).when(l -> l.getLimit() == 10)
                        )
                );

        // CROSS join: either side is eligible; left-side slots push left ...
        plan = new LogicalPlanBuilder(scan1)
                .join(scan2, JoinType.CROSS_JOIN, Pair.of(0, 0))
                .distinct(ImmutableList.of(0, 1))
                .limit(10)
                .build();
        PlanChecker.from(connectContext, plan)
                .applyTopDown(new PushdownLimitDistinctThroughJoin())
                .matches(
                        logicalJoin(
                                logicalLimit(logicalAggregate(logicalOlapScan())).when(l -> l.getLimit() == 10),
                                logicalOlapScan()
                        )
                );

        // ... and right-side slots push right.
        plan = new LogicalPlanBuilder(scan1)
                .join(scan2, JoinType.CROSS_JOIN, Pair.of(0, 0))
                .distinct(ImmutableList.of(2, 3))
                .limit(10)
                .build();
        PlanChecker.from(connectContext, plan)
                .applyTopDown(new PushdownLimitDistinctThroughJoin())
                .matches(
                        logicalJoin(
                                logicalOlapScan(),
                                logicalLimit(logicalAggregate(logicalOlapScan())).when(l -> l.getLimit() == 10)
                        )
                );
    }

    @Test
    void testJoinSql() {
        // End-to-end: GROUP BY t1.k1 with no aggregate functions is a distinct
        // aggregate, so the full rewrite pipeline should push limit+agg below
        // the left side of the outer join (through the pass-through project).
        PlanChecker.from(connectContext)
                .analyze("select t1.k1 from t1 left join t2 on t1.k1 = t2.k1 group by t1.k1 limit 10")
                .rewrite()
                .matches(
                        logicalProject(logicalJoin(
                                logicalLimit(logicalAggregate(logicalProject(logicalOlapScan())))
                                        .when(l -> l.getLimit() == 10),
                                logicalProject(logicalOlapScan())
                        ))
                );
    }

    @Test
    void badCaseJoinType() {
        // Distinct on a right-side slot (2) under a LEFT outer join: the rule
        // must not fire, so the join keeps its original scan children.
        LogicalPlan plan = new LogicalPlanBuilder(scan1)
                .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0))
                .distinct(ImmutableList.of(2))
                .limit(10)
                .build();
        PlanChecker.from(connectContext, plan)
                .applyTopDown(new PushdownLimitDistinctThroughJoin())
                .matches(logicalJoin(logicalOlapScan(), logicalOlapScan()));
    }

    @Test
    void badCaseOutput() {
        // distinct agg don't output all group by columns of left child
        // (agg output {slot 0} != left output {0, 1}), so no push-down.
        LogicalPlan plan = new LogicalPlanBuilder(scan1)
                .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0))
                .distinct(ImmutableList.of(0))
                .limit(10)
                .build();
        PlanChecker.from(connectContext, plan)
                .applyTopDown(new PushdownLimitDistinctThroughJoin())
                .matches(logicalJoin(logicalOlapScan(), logicalOlapScan()));
    }
}

View File

@ -0,0 +1,23 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !join1 --
1
2
-- !join3 --
0
1
-- !join5 --
1
1
-- !join6 --
1
-- !join7 --
0
0
-- !join8 --
0

View File

@ -0,0 +1,105 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Regression test for pushing limit-distinct and topN through outer joins:
// verifies that query results with the push-down rules enabled match the
// recorded expected output (the companion .out file).
suite("test_limit_join", "nereids_p0") {
    def DBname = "nereids_regression_test_limit_join"
    sql "DROP DATABASE IF EXISTS ${DBname}"
    sql "CREATE DATABASE IF NOT EXISTS ${DBname}"
    sql "use ${DBname}"

    // Force the Nereids planner so the new rewrite rules are exercised;
    // disable fallback so plan failures surface as test failures.
    sql "SET enable_nereids_planner=true"
    sql "SET enable_fallback_to_original_planner=false"

    def tbName1 = "t1"
    def tbName2 = "t2"

    sql "DROP TABLE IF EXISTS ${tbName1};"
    sql "DROP TABLE IF EXISTS ${tbName2};"
    sql """create table if not exists ${tbName1} (c1 int, c2 int) DISTRIBUTED BY HASH(c1) properties("replication_num" = "1");"""
    sql """create table if not exists ${tbName2} (c1 int, c2 int, c3 int) DISTRIBUTED BY HASH(c1) properties("replication_num" = "1");"""

    // Include NULL values so the outer-join padding paths are covered.
    sql "insert into ${tbName1} values (1,1);"
    sql "insert into ${tbName1} values (2,2);"
    sql "insert into ${tbName1} values (1,null);"
    sql "insert into ${tbName1} values (2,null);"
    sql "insert into ${tbName2} values (0,1,9999);"
    sql "insert into ${tbName2} values (1,1,9999);"
    sql "insert into ${tbName2} values (0,null,9999);"
    sql "insert into ${tbName2} values (1,null,9999);"

    /* test push limit-distinct through join */
    order_qt_join1 """
        SELECT t1.c1
        FROM ${tbName1} t1 left join ${tbName2} t2 on t1.c1 = t2.c1
        GROUP BY t1.c1
        limit 2;
    """
    // NOTE(review): run via plain `sql` (no qt_), so only successful execution
    // is checked — presumably because LIMIT ... OFFSET without a total order
    // returns a nondeterministic row; confirm this is intentional.
    sql """
        SELECT t1.c1
        FROM ${tbName1} t1 left join ${tbName2} t2 on t1.c1 = t2.c1
        GROUP BY t1.c1
        LIMIT 1 OFFSET 1;
    """
    order_qt_join3 """
        SELECT t2.c1
        FROM ${tbName1} t1 right join ${tbName2} t2 on t1.c1 = t2.c1
        GROUP BY t2.c1
        limit 2;
    """
    // Execution-only check, same nondeterminism caveat as above.
    sql """
        SELECT t2.c1
        FROM ${tbName1} t1 right join ${tbName2} t2 on t1.c1 = t2.c1
        GROUP BY t2.c1
        LIMIT 1 OFFSET 1;
    """

    /* test push topN through join */
    qt_join5 """
        SELECT t1.c1
        FROM ${tbName1} t1 left join ${tbName2} t2 on t1.c1 = t2.c1
        ORDER BY t1.c1
        limit 2;
    """
    qt_join6 """
        SELECT t1.c1
        FROM ${tbName1} t1 left join ${tbName2} t2 on t1.c1 = t2.c1
        ORDER BY t1.c1
        LIMIT 1 OFFSET 1;
    """
    qt_join7 """
        SELECT t2.c1
        FROM ${tbName1} t1 right join ${tbName2} t2 on t1.c1 = t2.c1
        ORDER BY t2.c1
        limit 2;
    """
    qt_join8 """
        SELECT t2.c1
        FROM ${tbName1} t1 right join ${tbName2} t2 on t1.c1 = t2.c1
        ORDER BY t2.c1
        LIMIT 1 OFFSET 1;
    """

    sql "DROP DATABASE IF EXISTS ${DBname};"
}