[fix](Nereids): other conditions should be kept for each anti join when expanding an anti join (#31521)

This commit is contained in:
谢健
2024-02-28 18:14:47 +08:00
committed by yiguolei
parent ac38356058
commit 4de25ede85
3 changed files with 425 additions and 52 deletions

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
@ -150,11 +151,11 @@ public class OrExpansion extends OneExplorationRuleFactory {
}
// expand Anti Join:
// Left Anti join cond1 or cond2 Left Anti join cond1
// / \ / \
//left right ===> Anti join cond2 CTERight2
// / \
// CTELeft CTERight1
// Left Anti join cond1 or cond2, other Left Anti join cond1 and other
// / \ / \
//left right ===> Anti join cond2 and other CTERight2
// / \
// CTELeft CTERight1
private Plan expandLeftAntiJoin(CascadesContext ctx,
Pair<List<Expression>, List<Expression>> hashOtherConditions,
LogicalJoin<? extends Plan, ? extends Plan> originJoin,
@ -171,14 +172,14 @@ public class OrExpansion extends OneExplorationRuleFactory {
replaced.putAll(right.getProducerToConsumerOutputMap());
List<Expression> disjunctions = hashOtherConditions.first;
List<Expression> otherConditions = hashOtherConditions.second;
otherConditions = otherConditions.stream()
List<Expression> newOtherConditions = otherConditions.stream()
.map(e -> e.rewriteUp(s -> replaced.containsKey(s) ? replaced.get(s) : s)).collect(Collectors.toList());
Expression hashCond = disjunctions.get(0);
hashCond = hashCond.rewriteUp(s -> replaced.containsKey(s) ? replaced.get(s) : s);
Plan newPlan = new LogicalJoin<>(JoinType.LEFT_ANTI_JOIN, Lists.newArrayList(hashCond),
otherConditions, originJoin.getDistributeHint(),
originJoin.getMarkJoinSlotReference(), left, right, null);
newOtherConditions, originJoin.getDistributeHint(),
originJoin.getMarkJoinSlotReference(), left, right, JoinReorderContext.EMPTY);
if (hashCond.children().stream().anyMatch(e -> !(e instanceof Slot))) {
Plan normalizedPlan = PushDownExpressionsInHashCondition.pushDownHashExpression(
(LogicalJoin<? extends Plan, ? extends Plan>) newPlan);
@ -192,10 +193,13 @@ public class OrExpansion extends OneExplorationRuleFactory {
ctx.putCTEIdToConsumer(newRight);
Map<Slot, Slot> newReplaced = new HashMap<>(left.getProducerToConsumerOutputMap());
newReplaced.putAll(newRight.getProducerToConsumerOutputMap());
newOtherConditions = otherConditions.stream()
.map(e -> e.rewriteUp(s -> newReplaced.containsKey(s) ? newReplaced.get(s) : s))
.collect(Collectors.toList());
hashCond = hashCond.rewriteUp(s -> newReplaced.containsKey(s) ? newReplaced.get(s) : s);
newPlan = new LogicalJoin<>(JoinType.LEFT_ANTI_JOIN, Lists.newArrayList(hashCond),
new ArrayList<>(), originJoin.getDistributeHint(),
originJoin.getMarkJoinSlotReference(), newPlan, newRight, null);
newOtherConditions, originJoin.getDistributeHint(),
originJoin.getMarkJoinSlotReference(), newPlan, newRight, JoinReorderContext.EMPTY);
if (hashCond.children().stream().anyMatch(e -> !(e instanceof Slot))) {
newPlan = PushDownExpressionsInHashCondition.pushDownHashExpression(
(LogicalJoin<? extends Plan, ? extends Plan>) newPlan);

View File

@ -2,11 +2,18 @@
-- !order_ij --
\N \N
1 \N
1 \N
\N 1
1 1
1 1
1 1
1 1
1 1
1 1
-- !order_laj --
\N
\N
0
2
3
@ -20,8 +27,10 @@
-- !order_loj --
\N \N
\N \N
\N \N
0 \N
1 \N
1 \N
2 \N
3 \N
4 \N
@ -30,6 +39,11 @@
7 \N
8 \N
9 \N
\N 1
1 1
1 1
1 1
1 1
1 1
1 1
@ -37,8 +51,11 @@
\N \N
\N \N
\N \N
\N \N
\N \N
0 \N
1 \N
1 \N
2 \N
3 \N
4 \N
@ -47,6 +64,11 @@
7 \N
8 \N
9 \N
\N 1
1 1
1 1
1 1
1 1
1 1
1 1
\N 20
@ -60,3 +82,282 @@
\N 28
\N 29
-- !order_ij_multi_cond --
\N \N
1 \N
1 \N
\N 1
1 1
1 1
-- !order_laj_multi_cond --
\N
\N
0
1
2
3
4
5
6
7
8
9
-- !order_loj_multi_cond --
\N \N
\N \N
\N \N
0 \N
1 \N
1 \N
1 \N
2 \N
3 \N
4 \N
5 \N
6 \N
7 \N
8 \N
9 \N
\N 1
1 1
1 1
-- !order_foj_multi_cond --
\N \N
\N \N
\N \N
\N \N
\N \N
0 \N
1 \N
1 \N
1 \N
2 \N
3 \N
4 \N
5 \N
6 \N
7 \N
8 \N
9 \N
\N 1
\N 1
1 1
1 1
\N 20
\N 21
\N 22
\N 23
\N 24
\N 25
\N 26
\N 27
\N 28
\N 29
-- !order_loj_unary_cond --
\N \N
\N \N
\N \N
0 \N
1 \N
1 \N
1 \N
2 \N
3 \N
4 \N
5 \N
6 \N
7 \N
8 \N
9 \N
\N 1
1 1
1 1
1 1
1 1
-- !order_foj_unary_cond --
\N \N
\N \N
\N \N
\N \N
\N \N
0 \N
1 \N
1 \N
1 \N
2 \N
3 \N
4 \N
5 \N
6 \N
7 \N
8 \N
9 \N
\N 1
1 1
1 1
1 1
1 1
\N 20
\N 21
\N 22
\N 23
\N 24
\N 25
\N 26
\N 27
\N 28
\N 29
-- !order_loj_unary_cond --
\N \N
\N \N
\N \N
\N \N
\N \N
\N \N
\N \N
0 \N
1 \N
1 \N
1 \N
1 \N
1 \N
1 \N
2 \N
3 \N
4 \N
5 \N
6 \N
7 \N
8 \N
9 \N
\N 1
\N 1
\N 1
\N 1
1 1
1 1
1 1
1 1
1 1
1 1
\N 20
\N 20
1 20
1 20
\N 21
\N 21
1 21
1 21
\N 22
\N 22
1 22
1 22
\N 23
\N 23
1 23
1 23
\N 24
\N 24
1 24
1 24
\N 25
\N 25
1 25
1 25
\N 26
\N 26
1 26
1 26
\N 27
\N 27
1 27
1 27
\N 28
\N 28
1 28
1 28
\N 29
\N 29
1 29
1 29
-- !order_foj_unary_cond --
\N \N
\N \N
\N \N
\N \N
\N \N
\N \N
\N \N
0 \N
1 \N
1 \N
1 \N
1 \N
1 \N
1 \N
2 \N
3 \N
4 \N
5 \N
6 \N
7 \N
8 \N
9 \N
\N 1
\N 1
\N 1
\N 1
1 1
1 1
1 1
1 1
1 1
1 1
\N 20
\N 20
1 20
1 20
\N 21
\N 21
1 21
1 21
\N 22
\N 22
1 22
1 22
\N 23
\N 23
1 23
1 23
\N 24
\N 24
1 24
1 24
\N 25
\N 25
1 25
1 25
\N 26
\N 26
1 26
1 26
\N 27
\N 27
1 27
1 27
\N 28
\N 28
1 28
1 28
\N 29
\N 29
1 29
1 29

View File

@ -19,12 +19,14 @@ suite("or_expansion") {
sql "SET enable_nereids_planner=true"
sql "SET enable_fallback_to_original_planner=false"
sql "SET enable_pipeline_engine = true"
sql "drop table if exists oe1"
sql "drop table if exists oe2"
sql """
CREATE TABLE IF NOT EXISTS oe1 (
k0 bigint,
k1 bigint
k1 bigint,
k2 bigint
)
DUPLICATE KEY(k0)
DISTRIBUTED BY HASH(k0) BUCKETS 1
@ -36,7 +38,8 @@ suite("or_expansion") {
sql """
CREATE TABLE IF NOT EXISTS oe2 (
k0 bigint,
k1 bigint
k1 bigint,
k2 bigint
)
DUPLICATE KEY(k0)
DISTRIBUTED BY HASH(k0) BUCKETS 1
@ -57,53 +60,63 @@ suite("or_expansion") {
sql """
alter table oe2 modify column k1 set stats ('row_count'='1000', 'ndv'='1000', 'min_value'='1000', 'max_value'='2000', 'avg_size'='1000', 'max_size'='1000' )
"""
sql """
alter table oe1 modify column k2 set stats ('row_count'='1000', 'ndv'='1000', 'min_value'='1000', 'max_value'='2000', 'avg_size'='1000', 'max_size'='1000' )
"""
sql """
alter table oe2 modify column k2 set stats ('row_count'='1000', 'ndv'='1000', 'min_value'='1000', 'max_value'='2000', 'avg_size'='1000', 'max_size'='1000' )
"""
explain {
sql("""
select oe1.k0, oe2.k0
from oe1 inner join oe2
on oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2
""")
contains "VHASH JOIN"
}
// explain {
// sql("""
// select oe1.k0, oe2.k0
// from oe1 inner join oe2
// on oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2
// """)
// contains "VHASH JOIN"
// }
explain {
sql("""
select oe1.k0
from oe1 left anti join oe2
on oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2
""")
contains "VHASH JOIN"
}
// explain {
// sql("""
// select oe1.k0
// from oe1 left anti join oe2
// on oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2
// """)
// contains "VHASH JOIN"
// }
explain {
sql("""
select oe1.k0, oe2.k0
from oe1 left outer join oe2
on oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2
""")
contains "VHASH JOIN"
}
// explain {
// sql("""
// select oe1.k0, oe2.k0
// from oe1 left outer join oe2
// on oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2
// """)
// contains "VHASH JOIN"
// }
explain {
sql("""
select oe1.k0, oe2.k0
from oe1 full outer join oe2
on oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2
""")
contains "VHASH JOIN"
}
// explain {
// sql("""
// select oe1.k0, oe2.k0
// from oe1 full outer join oe2
// on oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2
// """)
// contains "VHASH JOIN"
// }
for (int i = 0; i < 10; i++) {
sql "insert into oe1 values(${i}, ${i})"
sql "insert into oe2 values(${i+20}, ${i+20})"
sql "insert into oe1 values(${i}, ${i}, ${i})"
sql "insert into oe2 values(${i+20}, ${i+20}, ${i+20})"
}
sql "insert into oe1 values(null, 1)"
sql "insert into oe1 values(1, null)"
sql "insert into oe1 values(null, null)"
sql "insert into oe2 values(null, 1)"
sql "insert into oe2 values(1, null)"
sql "insert into oe2 values(null, null)"
sql "insert into oe1 values(1, 1, 1)"
sql "insert into oe1 values(null, null, null)"
sql "insert into oe1 values(null, 1, 1)"
sql "insert into oe1 values(1, null, null)"
sql "insert into oe1 values(null, null, 1)"
sql "insert into oe2 values(1, 1, 1)"
sql "insert into oe2 values(null, null, null)"
sql "insert into oe2 values(null, 1, 1)"
sql "insert into oe2 values(1, null, null)"
sql "insert into oe2 values(null, null, 1)"
qt_order_ij """
select oe1.k0, oe2.k0
@ -132,4 +145,59 @@ suite("or_expansion") {
on oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2
order by oe2.k0, oe1.k0
"""
qt_order_ij_multi_cond """
select oe1.k0, oe2.k0
from oe1 inner join oe2
on (oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2) and oe1.k2 = oe2.k2
order by oe2.k0, oe1.k0
"""
qt_order_laj_multi_cond """
select oe1.k0
from oe1 left anti join oe2
on (oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2) and oe1.k2 = oe2.k2
order by oe1.k0
"""
qt_order_loj_multi_cond """
select oe1.k0, oe2.k0
from oe1 left outer join oe2
on (oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2) and oe1.k2 = oe2.k2
order by oe2.k0, oe1.k0
"""
qt_order_foj_multi_cond """
select oe1.k0, oe2.k0
from oe1 full outer join oe2
on (oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2) and oe1.k2 = oe2.k2
order by oe2.k0, oe1.k0
"""
qt_order_loj_unary_cond """
select oe1.k0, oe2.k0
from oe1 left outer join oe2
on (oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2) and oe1.k2 = 1
order by oe2.k0, oe1.k0
"""
qt_order_foj_unary_cond """
select oe1.k0, oe2.k0
from oe1 full outer join oe2
on (oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2) and oe1.k2 = 1
order by oe2.k0, oe1.k0
"""
qt_order_loj_unary_cond """
select oe1.k0, oe2.k0
from oe1 left outer join oe2
on (oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2) or oe1.k2 = 1
order by oe2.k0, oe1.k0
"""
qt_order_foj_unary_cond """
select oe1.k0, oe2.k0
from oe1 full outer join oe2
on (oe1.k0 = oe2.k0 or oe1.k1 + 1 = oe2.k1 * 2) or oe1.k2 = 1
order by oe2.k0, oe1.k0
"""
}