From fec94b727813d7acfd34d89344a05805a0a8eb49 Mon Sep 17 00:00:00 2001 From: jakevin Date: Mon, 20 Nov 2023 20:23:24 +0800 Subject: [PATCH] [feature](Nereids): use session variable to enable rule (#27036) --- .../rewrite/PushdownCountThroughJoin.java | 25 ++++- .../rewrite/PushdownDistinctThroughJoin.java | 7 ++ .../rewrite/PushdownMinMaxThroughJoin.java | 25 ++++- .../rules/rewrite/PushdownSumThroughJoin.java | 25 ++++- .../org/apache/doris/qe/SessionVariable.java | 11 ++ .../rewrite/PushdownCountThroughJoinTest.java | 82 +++++++++++++- .../PushdownDistinctThroughJoinTest.java | 19 ++++ .../PushdownMinMaxThroughJoinTest.java | 103 +++++++++++++++++- .../rewrite/PushdownSumThroughJoinTest.java | 30 ++++- 9 files changed, 300 insertions(+), 27 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java index 5bef7842a5..856a67d048 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -71,7 +72,7 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory { public List buildRules() { return ImmutableList.of( logicalAggregate(innerLogicalJoin()) - .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0) + .when(agg -> agg.child().getOtherJoinConjuncts().isEmpty()) .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot)) .when(agg -> { @@ -80,11 +81,19 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory { .allMatch(f -> f instanceof Count && !f.isDistinct() && (((Count) f).isCountStar() || f.child(0) instanceof Slot)); }) - .then(agg -> pushCount(agg, agg.child(), ImmutableList.of())) + .thenApply(ctx -> { + Set enableNereidsRules = ctx.cascadesContext.getConnectContext() + .getSessionVariable().getEnableNereidsRules(); + if (!enableNereidsRules.contains(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type())) { + return null; + } + LogicalAggregate> agg = ctx.root; + return pushCount(agg, agg.child(), ImmutableList.of()); + }) .toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN), logicalAggregate(logicalProject(innerLogicalJoin())) .when(agg -> agg.child().isAllSlots()) - .when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0) + .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot)) .when(agg -> { @@ -93,7 +102,15 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory { .allMatch(f -> f instanceof Count && !f.isDistinct() && (((Count) f).isCountStar() || f.child(0) instanceof Slot)); }) - .then(agg -> pushCount(agg, agg.child().child(), agg.child().getProjects())) + .thenApply(ctx -> { + Set enableNereidsRules = ctx.cascadesContext.getConnectContext() + .getSessionVariable().getEnableNereidsRules(); + if (!enableNereidsRules.contains(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type())) { + return null; + } + LogicalAggregate>> agg = ctx.root; + return pushCount(agg, agg.child().child(), agg.child().getProjects()); + }) .toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN) ); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownDistinctThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownDistinctThroughJoin.java index e32213f265..ac239e8f95 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownDistinctThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownDistinctThroughJoin.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.algebra.Relation; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; @@ -29,6 +30,7 @@ import org.apache.doris.nereids.util.PlanUtils; import com.google.common.collect.ImmutableList; +import java.util.Set; import java.util.function.Function; /** @@ -37,6 +39,11 @@ import java.util.function.Function; public class PushdownDistinctThroughJoin extends DefaultPlanRewriter implements CustomRewriter { @Override public Plan rewriteRoot(Plan plan, JobContext context) { + Set enableNereidsRules = context.getCascadesContext().getConnectContext() + .getSessionVariable().getEnableNereidsRules(); + if (!enableNereidsRules.contains(RuleType.PUSHDOWN_DISTINCT_THROUGH_JOIN.type())) { + return null; + } return plan.accept(this, context); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java index 9b728ad141..07f0bbc81c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java @@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Min; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -65,7 +66,7 @@ public class PushdownMinMaxThroughJoin implements RewriteRuleFactory { public List buildRules() { return ImmutableList.of( logicalAggregate(innerLogicalJoin()) - .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0) + .when(agg -> agg.child().getOtherJoinConjuncts().isEmpty()) .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) .when(agg -> { Set funcs = agg.getAggregateFunctions(); @@ -73,11 +74,19 @@ public class PushdownMinMaxThroughJoin implements RewriteRuleFactory { .allMatch(f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child( 0) instanceof Slot); }) - .then(agg -> pushMinMax(agg, agg.child(), ImmutableList.of())) + .thenApply(ctx -> { + Set enableNereidsRules = ctx.cascadesContext.getConnectContext() + .getSessionVariable().getEnableNereidsRules(); + if (!enableNereidsRules.contains(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN.type())) { + return null; + } + LogicalAggregate> agg = ctx.root; + return pushMinMax(agg, agg.child(), ImmutableList.of()); + }) .toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN), logicalAggregate(logicalProject(innerLogicalJoin())) .when(agg -> agg.child().isAllSlots()) - .when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0) + .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) .when(agg -> { Set funcs = agg.getAggregateFunctions(); @@ -86,7 +95,15 @@ public class PushdownMinMaxThroughJoin implements RewriteRuleFactory { f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child( 0) instanceof Slot); }) - .then(agg -> pushMinMax(agg, agg.child().child(), agg.child().getProjects())) + .thenApply(ctx -> { + Set enableNereidsRules = ctx.cascadesContext.getConnectContext() + .getSessionVariable().getEnableNereidsRules(); + if (!enableNereidsRules.contains(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN.type())) { + return null; + } + LogicalAggregate>> agg = ctx.root; + return pushMinMax(agg, agg.child().child(), agg.child().getProjects()); + }) .toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN) ); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java index 81e655ab85..0ae16a0701 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; @@ -65,25 +66,41 @@ public class PushdownSumThroughJoin implements RewriteRuleFactory { public List buildRules() { return ImmutableList.of( logicalAggregate(innerLogicalJoin()) - .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0) + .when(agg -> agg.child().getOtherJoinConjuncts().isEmpty()) .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) .when(agg -> { Set funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && funcs.stream() .allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot); }) - .then(agg -> pushSum(agg, agg.child(), ImmutableList.of())) + .thenApply(ctx -> { + Set enableNereidsRules = ctx.cascadesContext.getConnectContext() + .getSessionVariable().getEnableNereidsRules(); + if (!enableNereidsRules.contains(RuleType.PUSHDOWN_SUM_THROUGH_JOIN.type())) { + return null; + } + LogicalAggregate> agg = ctx.root; + return pushSum(agg, agg.child(), ImmutableList.of()); + }) .toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN), logicalAggregate(logicalProject(innerLogicalJoin())) .when(agg -> agg.child().isAllSlots()) - .when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0) + .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) .when(agg -> { Set funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && funcs.stream() .allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot); }) - .then(agg -> pushSum(agg, agg.child().child(), agg.child().getProjects())) + .thenApply(ctx -> { + Set enableNereidsRules = ctx.cascadesContext.getConnectContext() + .getSessionVariable().getEnableNereidsRules(); + if (!enableNereidsRules.contains(RuleType.PUSHDOWN_SUM_THROUGH_JOIN.type())) { + return null; + } + LogicalAggregate>> agg = ctx.root; + return pushSum(agg, agg.child().child(), agg.child().getProjects()); + }) .toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN) ); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index 5b0c18cbeb..1fcaa56ddf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -939,6 +939,9 @@ public class SessionVariable implements Serializable, Writable { @VariableMgr.VarAttr(name = DISABLE_NEREIDS_RULES, needForward = true) private String disableNereidsRules = ""; + @VariableMgr.VarAttr(name = "ENABLE_NEREIDS_RULES", needForward = true) + public String enableNereidsRules = ""; + @VariableMgr.VarAttr(name = ENABLE_NEW_COST_MODEL, needForward = true) private boolean enableNewCostModel = false; @@ -2285,6 +2288,14 @@ public class SessionVariable implements Serializable, Writable { .collect(ImmutableSet.toImmutableSet()); } + public Set getEnableNereidsRules() { + return Arrays.stream(enableNereidsRules.split(",[\\s]*")) + .filter(rule -> !rule.isEmpty()) + .map(rule -> rule.toUpperCase(Locale.ROOT)) + .map(rule -> RuleType.valueOf(rule).type()) + .collect(ImmutableSet.toImmutableSet()); + } + public void setEnableNewCostModel(boolean enable) { this.enableNewCostModel = enable; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java index 10f814e9ba..69b9ae7751 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.plans.JoinType; @@ -28,16 +29,28 @@ import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.qe.SessionVariable; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import mockit.Mock; +import mockit.MockUp; import org.junit.jupiter.api.Test; +import java.util.Set; + class PushdownCountThroughJoinTest implements MemoPatternMatchSupported { private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); @Test void testSingleCount() { + new MockUp() { + @Mock + public Set getEnableNereidsRules() { + return ImmutableSet.of(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type()); + } + }; Alias count = new Count(scan1.getOutput().get(0)).alias("count"); LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) @@ -46,11 +59,24 @@ class PushdownCountThroughJoinTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new PushdownCountThroughJoin()) - .printlnTree(); + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalAggregate() + ) + ) + ); } @Test void testMultiCount() { + new MockUp() { + @Mock + public Set getEnableNereidsRules() { + return ImmutableSet.of(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type()); + } + }; Alias leftCnt1 = new Count(scan1.getOutput().get(0)).alias("leftCnt1"); Alias leftCnt2 = new Count(scan1.getOutput().get(1)).alias("leftCnt2"); Alias rightCnt1 = new Count(scan2.getOutput().get(1)).alias("rightCnt1"); @@ -62,11 +88,24 @@ class PushdownCountThroughJoinTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new PushdownCountThroughJoin()) - .printlnTree(); + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalAggregate() + ) + ) + ); } @Test void testSingleCountStar() { + new MockUp() { + @Mock + public Set getEnableNereidsRules() { + return ImmutableSet.of(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type()); + } + }; Alias count = new Count().alias("countStar"); LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) @@ -75,11 +114,24 @@ class PushdownCountThroughJoinTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new PushdownCountThroughJoin()) - .printlnTree(); + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalAggregate() + ) + ) + ); } @Test void testSingleCountStarEmptyGroupBy() { + new MockUp() { + @Mock + public Set getEnableNereidsRules() { + return ImmutableSet.of(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type()); + } + }; Alias count = new Count().alias("countStar"); LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) @@ -89,11 +141,24 @@ class PushdownCountThroughJoinTest implements MemoPatternMatchSupported { // shouldn't rewrite. PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new PushdownCountThroughJoin()) - .printlnTree(); + .matches( + logicalAggregate( + logicalJoin( + logicalOlapScan(), + logicalOlapScan() + ) + ) + ); } @Test void testBothSideCountAndCountStar() { + new MockUp() { + @Mock + public Set getEnableNereidsRules() { + return ImmutableSet.of(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type()); + } + }; Alias leftCnt = new Count(scan1.getOutput().get(0)).alias("leftCnt"); Alias rightCnt = new Count(scan2.getOutput().get(0)).alias("rightCnt"); Alias countStar = new Count().alias("countStar"); @@ -105,6 +170,13 @@ class PushdownCountThroughJoinTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new PushdownCountThroughJoin()) - .printlnTree(); + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalAggregate() + ) + ) + ); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownDistinctThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownDistinctThroughJoinTest.java index 3e3b2560de..feaeb5fa05 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownDistinctThroughJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownDistinctThroughJoinTest.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; @@ -26,10 +27,16 @@ import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.qe.SessionVariable; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import mockit.Mock; +import mockit.MockUp; import org.junit.jupiter.api.Test; +import java.util.Set; + class PushdownDistinctThroughJoinTest implements MemoPatternMatchSupported { private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); @@ -38,6 +45,12 @@ class PushdownDistinctThroughJoinTest implements MemoPatternMatchSupported { @Test void testPushdownJoin() { + new MockUp() { + @Mock + public Set getEnableNereidsRules() { + return ImmutableSet.of(RuleType.PUSHDOWN_DISTINCT_THROUGH_JOIN.type()); + } + }; LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0)) @@ -60,6 +73,12 @@ class PushdownDistinctThroughJoinTest implements MemoPatternMatchSupported { @Test void testPushdownProjectJoin() { + new MockUp() { + @Mock + public Set getEnableNereidsRules() { + return ImmutableSet.of(RuleType.PUSHDOWN_DISTINCT_THROUGH_JOIN.type()); + } + }; LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) .project(ImmutableList.of(0, 2)) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoinTest.java index 83c297f7a0..8c19c0c38c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoinTest.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; @@ -29,10 +30,16 @@ import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.qe.SessionVariable; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import mockit.Mock; +import mockit.MockUp; import org.junit.jupiter.api.Test; +import java.util.Set; + class PushdownMinMaxThroughJoinTest implements MemoPatternMatchSupported { private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); @@ -41,6 +48,12 @@ class PushdownMinMaxThroughJoinTest implements MemoPatternMatchSupported { @Test void testSingleJoin() { + new MockUp() { + @Mock + public Set getEnableNereidsRules() { + return ImmutableSet.of(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN.type()); + } + }; Alias min = new Min(scan1.getOutput().get(0)).alias("min"); LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) @@ -49,11 +62,24 @@ class PushdownMinMaxThroughJoinTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new PushdownMinMaxThroughJoin()) - .printlnTree(); + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalOlapScan() + ) + ) + ); } @Test void testMultiJoin() { + new MockUp() { + @Mock + public Set getEnableNereidsRules() { + return ImmutableSet.of(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN.type()); + } + }; Alias min = new Min(scan1.getOutput().get(0)).alias("min"); LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) @@ -64,11 +90,35 @@ class PushdownMinMaxThroughJoinTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new PushdownMinMaxThroughJoin()) - .printlnTree(); + .printlnTree() + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate( + logicalJoin( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalOlapScan() + ) + ), + logicalOlapScan() + ) + ), + logicalOlapScan() + ) + ) + ); } @Test void testAggNotOutputGroupBy() { + new MockUp() { + @Mock + public Set getEnableNereidsRules() { + return ImmutableSet.of(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN.type()); + } + }; // agg don't output group by Alias min = new Min(scan1.getOutput().get(0)).alias("min"); LogicalPlan plan = new LogicalPlanBuilder(scan1) @@ -79,11 +129,29 @@ class PushdownMinMaxThroughJoinTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new PushdownMinMaxThroughJoin()) - .printlnTree(); + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalOlapScan() + ) + ), + logicalOlapScan() + ) + ) + ); } @Test void testBothSideSingleJoin() { + new MockUp() { + @Mock + public Set getEnableNereidsRules() { + return ImmutableSet.of(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN.type()); + } + }; Alias min = new Min(scan1.getOutput().get(1)).alias("min"); Alias max = new Max(scan2.getOutput().get(1)).alias("max"); LogicalPlan plan = new LogicalPlanBuilder(scan1) @@ -94,11 +162,24 @@ class PushdownMinMaxThroughJoinTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .printlnTree() .applyTopDown(new PushdownMinMaxThroughJoin()) - .printlnTree(); + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalAggregate() + ) + ) + ); } @Test void testBothSide() { + new MockUp() { + @Mock + public Set getEnableNereidsRules() { + return ImmutableSet.of(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN.type()); + } + }; Alias min = new Min(scan1.getOutput().get(1)).alias("min"); Alias max = new Max(scan3.getOutput().get(1)).alias("max"); LogicalPlan plan = new LogicalPlanBuilder(scan1) @@ -109,6 +190,18 @@ class PushdownMinMaxThroughJoinTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new PushdownMinMaxThroughJoin()) - .printlnTree(); + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalOlapScan() + ) + ), + logicalAggregate() + ) + ) + ); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoinTest.java index c6d65e784c..f97c824465 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoinTest.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.plans.JoinType; @@ -28,18 +29,28 @@ import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.qe.SessionVariable; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import mockit.Mock; +import mockit.MockUp; import org.junit.jupiter.api.Test; +import java.util.Set; + class PushdownSumThroughJoinTest implements MemoPatternMatchSupported { private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); - private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); - private static final LogicalOlapScan scan4 = PlanConstructor.newLogicalOlapScan(3, "t4", 0); @Test void testSingleJoinLeftSum() { + new MockUp() { + @Mock + public Set getEnableNereidsRules() { + return ImmutableSet.of(RuleType.PUSHDOWN_SUM_THROUGH_JOIN.type()); + } + }; Alias sum = new Sum(scan1.getOutput().get(1)).alias("sum"); LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) @@ -47,7 +58,6 @@ class PushdownSumThroughJoinTest implements MemoPatternMatchSupported { .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .printlnTree() .applyTopDown(new PushdownSumThroughJoin()) .matches( logicalAggregate( @@ -61,6 +71,12 @@ class PushdownSumThroughJoinTest implements MemoPatternMatchSupported { @Test void testSingleJoinRightSum() { + new MockUp() { + @Mock + public Set getEnableNereidsRules() { + return ImmutableSet.of(RuleType.PUSHDOWN_SUM_THROUGH_JOIN.type()); + } + }; Alias sum = new Sum(scan2.getOutput().get(1)).alias("sum"); LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) @@ -68,7 +84,6 @@ class PushdownSumThroughJoinTest implements MemoPatternMatchSupported { .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .printlnTree() .applyTopDown(new PushdownSumThroughJoin()) .matches( logicalAggregate( @@ -82,6 +97,12 @@ class PushdownSumThroughJoinTest implements MemoPatternMatchSupported { @Test void testAggNotOutputGroupBy() { + new MockUp() { + @Mock + public Set getEnableNereidsRules() { + return ImmutableSet.of(RuleType.PUSHDOWN_SUM_THROUGH_JOIN.type()); + } + }; // agg don't output group by Alias sum = new Sum(scan1.getOutput().get(1)).alias("sum"); LogicalPlan plan = new LogicalPlanBuilder(scan1) @@ -90,7 +111,6 @@ class PushdownSumThroughJoinTest implements MemoPatternMatchSupported { .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .printlnTree() .applyTopDown(new PushdownSumThroughJoin()) .matches( logicalAggregate(