[fix](Nereids) join reorder lead to circle in memo (#19935)

If a join is the root node of the plan, then after some join reorder rules are applied, the root Group in the Memo can contain a GroupExpression whose plan is a LogicalProject and whose child is its own ownerGroup, which forms a cycle.
This PR adds a rewrite rule that ensures there is a Project on top of the top Join of the plan, to avoid such a cycle in the Memo.
This commit is contained in:
morrySnow
2023-05-23 15:22:32 +08:00
committed by GitHub
parent 14de2a5c0e
commit 7247ac9b75
7 changed files with 321 additions and 137 deletions

View File

@ -48,6 +48,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.EliminateNotNull;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateNullAwareLeftAntiJoin;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateOrderByConstant;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateUnnecessaryProject;
import org.apache.doris.nereids.rules.rewrite.logical.EnsureProjectOnTopJoin;
import org.apache.doris.nereids.rules.rewrite.logical.ExtractAndNormalizeWindowExpression;
import org.apache.doris.nereids.rules.rewrite.logical.ExtractFilterFromCrossJoin;
import org.apache.doris.nereids.rules.rewrite.logical.ExtractSingleTableExpressionFromDisjunction;
@ -263,6 +264,11 @@ public class NereidsRewriter extends BatchRewriteJob {
// this rule batch must keep at the end of rewrite to do some plan check
topic("Final rewrite and check",
custom(RuleType.ENSURE_PROJECT_ON_TOP_JOIN, EnsureProjectOnTopJoin::new),
topDown(
new PushdownFilterThroughProject(),
new MergeProjects()
),
custom(RuleType.ADJUST_NULLABLE, AdjustNullable::new),
bottomUp(
new ExpressionRewrite(CheckLegalityAfterRewrite.INSTANCE),

View File

@ -208,6 +208,8 @@ public enum RuleType {
PUSH_LIMIT_INTO_SORT(RuleTypeClass.REWRITE),
// adjust nullable
ADJUST_NULLABLE(RuleTypeClass.REWRITE),
// ensure having project on the top join
ENSURE_PROJECT_ON_TOP_JOIN(RuleTypeClass.REWRITE),
REWRITE_SENTINEL(RuleTypeClass.REWRITE),

View File

@ -0,0 +1,61 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Puts an explicit {@link LogicalProject} on top of a join, listing every output of that join,
 * so that the output of the whole plan is stable and no cycle is generated in the memo.
 *
 * <p>Aggregates and projects are returned as they are without visiting their children, and a
 * join is wrapped without visiting its children. Every other plan node falls back to the
 * behavior inherited from {@link DefaultPlanRewriter}.
 */
public class EnsureProjectOnTopJoin extends DefaultPlanRewriter<Void> implements CustomRewriter {

    /**
     * Entry point of the custom rewriter: runs this visitor over the given plan.
     *
     * @param plan the plan to rewrite
     * @param jobContext not used by this rule
     * @return the plan produced by the visitor
     */
    @Override
    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
        // This class declares no fields, so no context object is passed to the visitor.
        return plan.accept(this, null);
    }

    /** Wraps the join in a project over all of the join's outputs; the join itself is kept as is. */
    @Override
    public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Void context) {
        return new LogicalProject<>(outputsOf(join), join);
    }

    /** A project is returned unchanged and its children are not visited. */
    @Override
    public Plan visitLogicalProject(LogicalProject<? extends Plan> project, Void context) {
        return project;
    }

    /** An aggregate is returned unchanged and its children are not visited. */
    @Override
    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) {
        return aggregate;
    }

    /** Collects the output of the join as a list of named expressions, in the original order. */
    private static List<NamedExpression> outputsOf(LogicalJoin<? extends Plan, ? extends Plan> join) {
        return join.getOutput().stream()
                .map(slot -> (NamedExpression) slot)
                .collect(Collectors.toList());
    }
}

View File

@ -82,13 +82,15 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalProject(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
)
)
);
}
@ -101,9 +103,11 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalProject(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
)
);
}
@ -116,11 +120,13 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
logicalOlapScan()
logicalProject(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
logicalOlapScan()
)
)
);
}
@ -133,11 +139,13 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
logicalOlapScan()
logicalProject(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
logicalOlapScan()
)
)
);
}
@ -150,18 +158,20 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalProject(
logicalJoin(
logicalJoin(
logicalFilter(
logicalOlapScan()
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
),
logicalFilter(
),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1"))
).when(filter -> filter.getPredicate().toSql().contains("id > 1"))
)
)
);
}
@ -174,18 +184,20 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalProject(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1"))
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1"))
)
)
);
}
@ -198,13 +210,15 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalProject(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
)
)
);
}
@ -217,11 +231,13 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalOlapScan(),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
logicalProject(
logicalJoin(
logicalOlapScan(),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
)
)
);
}
@ -235,13 +251,15 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalProject(
logicalJoin(
logicalFilter(
logicalOlapScan()
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
)
)
);
}
@ -254,15 +272,17 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalProject(
logicalProject(
logicalJoin(
logicalProject(
logicalFilter(
logicalOlapScan()
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1"))
),
logicalFilter(
),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
)
)
);
}
@ -275,13 +295,15 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalProject(
logicalJoin(
logicalProject(
logicalOlapScan()
),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
)
)
);
}
@ -294,16 +316,18 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicate().toSql().contains("id > 1")),
logicalAggregate(
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicate().toSql().contains("sid > 1"))
))
logicalProject(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicate().toSql().contains("id > 1")),
logicalAggregate(
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicate().toSql().contains("sid > 1"))
))
)
)
);
}
@ -316,7 +340,8 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalProject(
logicalJoin(
logicalProject(
logicalFilter(
logicalOlapScan()
@ -325,6 +350,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid = 1"))
)
)
);
}
@ -337,7 +363,8 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalProject(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
@ -346,6 +373,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
)
)
)
);
}
@ -358,14 +386,16 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
logicalProject(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
)
)
)
);
@ -379,12 +409,14 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalOlapScan(),
logicalProject(
logicalProject(
logicalJoin(
logicalOlapScan(),
logicalProject(
logicalFilter(
logicalOlapScan()
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
)
)
)
);
@ -398,12 +430,14 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalOlapScan(),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
logicalProject(
logicalJoin(
logicalOlapScan(),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
)
)
)
);
@ -417,14 +451,16 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
logicalProject(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
)
)
)
);
@ -460,28 +496,29 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("k1 = 3")),
logicalProject(
logicalJoin(
logicalProject(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("k1 = 3")),
logicalProject(
logicalJoin(
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("k3 = 3"))
),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("k1 = 3"))
)
),
logicalAggregate(
logicalProject(
logicalOlapScan()
logicalJoin(
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("k3 = 3"))
),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("k1 = 3"))
)
),
logicalAggregate(
logicalProject(
logicalOlapScan()
)
)
)
)
@ -498,18 +535,20 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
innerLogicalJoin(
logicalProject(
innerLogicalJoin(
innerLogicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1"))
).when(filter -> filter.getPredicate().toSql().contains("id > 1"))
)
)
);
}
@ -522,18 +561,20 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalProject(
logicalJoin(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
),
logicalFilter(
logicalOlapScan()
logicalOlapScan()
).when(filter -> filter.getPredicate().toSql().contains("id > 1"))
)
)
);
}
@ -549,6 +590,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalProject(
logicalJoin(
logicalFilter(
logicalOlapScan()
@ -559,6 +601,7 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter
).when(filter -> filter.getPredicate().toSql().contains("sid > 1"))
)
)
)
);
}
}

View File

@ -96,10 +96,12 @@ class ReorderJoinTest implements MemoPatternMatchSupported {
PlanChecker.from(MemoTestUtils.createConnectContext(), plan2)
.rewrite()
.matchesFromRoot(
logicalProject(
logicalJoin(
logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
logicalOlapScan()
logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
logicalOlapScan()
).whenNot(join -> join.getJoinType().isCrossJoin())
)
);
}
@ -124,10 +126,12 @@ class ReorderJoinTest implements MemoPatternMatchSupported {
PlanChecker.from(MemoTestUtils.createConnectContext(), plan2)
.rewrite()
.matchesFromRoot(
logicalProject(
innerLogicalJoin(
leftSemiLogicalJoin(),
logicalOlapScan()
leftSemiLogicalJoin(),
logicalOlapScan()
)
)
);
}
@ -183,10 +187,12 @@ class ReorderJoinTest implements MemoPatternMatchSupported {
.rewrite()
.printlnTree()
.matchesFromRoot(
logicalProject(
logicalJoin(
logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
leafPlan()
logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
leafPlan()
).whenNot(join -> join.getJoinType().isCrossJoin())
)
);
}
}

View File

@ -30,10 +30,12 @@ public class InferTest extends SqlTestBase {
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalProject(
innerLogicalJoin(
logicalFilter().when(f -> f.getPredicate().toString().equals("(id#0 = 4)")),
logicalFilter().when(f -> f.getPredicate().toString().equals("(id#2 = 4)"))
logicalFilter().when(f -> f.getPredicate().toString().equals("(id#0 = 4)")),
logicalFilter().when(f -> f.getPredicate().toString().equals("(id#2 = 4)"))
)
)
);
}
@ -46,11 +48,13 @@ public class InferTest extends SqlTestBase {
.rewrite()
.printlnTree()
.matchesFromRoot(
logicalProject(
innerLogicalJoin(
logicalFilter().when(
f -> f.getPredicate().toString().equals("((id#0 = 4) OR (id#0 > 4))")),
logicalOlapScan()
logicalFilter().when(
f -> f.getPredicate().toString().equals("((id#0 = 4) OR (id#0 > 4))")),
logicalOlapScan()
)
)
);
}
@ -62,14 +66,16 @@ public class InferTest extends SqlTestBase {
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalProject(
logicalFilter(
leftOuterLogicalJoin(
logicalFilter().when(
f -> f.getPredicate().toString().equals("((id#0 = 4) OR (id#0 > 4))")),
logicalOlapScan()
)
leftOuterLogicalJoin(
logicalFilter().when(
f -> f.getPredicate().toString().equals("((id#0 = 4) OR (id#0 > 4))")),
logicalOlapScan()
)
).when(f -> f.getPredicate().toString()
.equals("((id#0 = 4) OR ((id#0 > 4) AND score#3 IS NULL))"))
)
);
}

View File

@ -211,4 +211,64 @@ suite("join") {
logger.info(exception.message)
}
}
// Drop every fixture table created below (test_memo_1..3) so the suite can be re-run
// against a database that still holds the tables from a previous run.
sql """drop table if exists test_memo_1"""
sql """drop table if exists test_memo_2"""
sql """drop table if exists test_memo_3"""
sql """ CREATE TABLE `test_memo_1` (
`c_bigint` bigint(20) NULL,
`c_long_decimal` decimal(27, 9) NULL
) ENGINE=OLAP
DUPLICATE KEY(`c_bigint`)
COMMENT 'OLAP'
DISTRIBUTED BY HASH(`c_bigint`) BUCKETS 1
PROPERTIES (
"replication_allocation" = "tag.location.default: 1",
"storage_format" = "V2",
"light_schema_change" = "true",
"disable_auto_compaction" = "false"
);
"""
sql """ CREATE TABLE `test_memo_2` (
`sk` bigint(20) NULL,
`id` int(11) NULL
) ENGINE=OLAP
UNIQUE KEY(`sk`)
COMMENT 'OLAP'
DISTRIBUTED BY HASH(`sk`) BUCKETS 1
PROPERTIES (
"replication_allocation" = "tag.location.default: 1",
"storage_format" = "V2",
"light_schema_change" = "true",
"disable_auto_compaction" = "false"
);
"""
sql """ CREATE TABLE `test_memo_3` (
`id` bigint(20) NOT NULL,
`c1` varchar(150) NULL
) ENGINE=OLAP
UNIQUE KEY(`id`)
DISTRIBUTED BY HASH(`id`) BUCKETS 1
PROPERTIES (
"replication_allocation" = "tag.location.default: 1",
"storage_format" = "V2",
"light_schema_change" = "true",
"disable_auto_compaction" = "false"
);
"""
sql """
select
ref_1.`c_long_decimal` as c0,
ref_3.`c1` as c1
from
test_memo_1 as ref_1
inner join test_memo_2 as ref_2 on (case when true then 5 else 5 end is not NULL)
inner join test_memo_3 as ref_3 on (version() is not NULL)
where
ref_2.`id` is not NULL
"""
}