From 045dd05f2aecf11dc9651edca5e0b3e1c212abc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E5=81=A5?= Date: Mon, 8 Apr 2024 11:08:05 +0800 Subject: [PATCH] [fix](Nereids): don't transpose agg and join if join is mark join (#33312) --- .../exploration/TransposeAggSemiJoin.java | 4 ++-- .../TransposeAggSemiJoinProject.java | 4 ++-- .../exploration/TransposeAggSemiJoinTest.java | 18 ++++++++++++++++++ .../TransposeSemiJoinAggProjectTest.java | 15 +++++++++++++++ 4 files changed, 37 insertions(+), 4 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoin.java index e25e1c816a..564fc07513 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoin.java @@ -31,8 +31,8 @@ public class TransposeAggSemiJoin extends OneExplorationRuleFactory { @Override public Rule build() { - return logicalAggregate(logicalJoin()) - .when(agg -> agg.child().getJoinType().isLeftSemiOrAntiJoin()) + return logicalAggregate( + logicalJoin().when(join -> join.getJoinType().isLeftSemiOrAntiJoin() && !join.isMarkJoin())) .then(agg -> { LogicalJoin join = agg.child(); if (!TransposeSemiJoinAgg.canTranspose(agg, join)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinProject.java index f1a7355a19..9beb93b965 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinProject.java @@ -32,8 +32,8 @@ public class TransposeAggSemiJoinProject extends OneExplorationRuleFactory { @Override public Rule build() { - return logicalAggregate(logicalProject(logicalJoin())) - .when(agg -> agg.child().child().getJoinType().isLeftSemiOrAntiJoin()) + return logicalAggregate(logicalProject( + logicalJoin().when(join -> join.getJoinType().isLeftSemiOrAntiJoin() && !join.isMarkJoin()))) .then(agg -> { LogicalProject> project = agg.child(); LogicalJoin join = project.child(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinTest.java index 9c1e19282a..68cac382bc 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinTest.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; class TransposeAggSemiJoinTest implements MemoPatternMatchSupported { @@ -57,4 +58,21 @@ class TransposeAggSemiJoinTest implements MemoPatternMatchSupported { ) ); } + + @Test + void markJoin() { + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .markJoin(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), + ImmutableList.of( + scan1.getOutput().get(0), + new Alias(new Sum(scan1.getOutput().get(1)), "sum") + ) + ) + .build(); + int size = PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyExploration(TransposeAggSemiJoin.INSTANCE.build()) + .getAllPlan().size(); + Assertions.assertEquals(1, size); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/TransposeSemiJoinAggProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/TransposeSemiJoinAggProjectTest.java index ae91e5074e..810ab1e629 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/TransposeSemiJoinAggProjectTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/TransposeSemiJoinAggProjectTest.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.exploration.TransposeAggSemiJoin; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; @@ -28,6 +29,7 @@ import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; class TransposeSemiJoinAggProjectTest implements MemoPatternMatchSupported { @@ -53,4 +55,17 @@ class TransposeSemiJoinAggProjectTest implements MemoPatternMatchSupported { ); } + @Test + void markJoin() { + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .aggAllUsingIndex(ImmutableList.of(0, 1), ImmutableList.of(0, 1)) + .project(ImmutableList.of(0)) + .markJoin(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) + .build(); + int size = PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyExploration(TransposeAggSemiJoin.INSTANCE.build()) + .getAllPlan().size(); + Assertions.assertEquals(1, size); + } + }