[feature](Nereids): use session variable to enable rule (#27036)
This commit is contained in:
@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -71,7 +72,7 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory {
|
||||
public List<Rule> buildRules() {
|
||||
return ImmutableList.of(
|
||||
logicalAggregate(innerLogicalJoin())
|
||||
.when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
|
||||
.when(agg -> agg.child().getOtherJoinConjuncts().isEmpty())
|
||||
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
|
||||
.when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
|
||||
.when(agg -> {
|
||||
@ -80,11 +81,19 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory {
|
||||
.allMatch(f -> f instanceof Count && !f.isDistinct()
|
||||
&& (((Count) f).isCountStar() || f.child(0) instanceof Slot));
|
||||
})
|
||||
.then(agg -> pushCount(agg, agg.child(), ImmutableList.of()))
|
||||
.thenApply(ctx -> {
|
||||
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
|
||||
.getSessionVariable().getEnableNereidsRules();
|
||||
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type())) {
|
||||
return null;
|
||||
}
|
||||
LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root;
|
||||
return pushCount(agg, agg.child(), ImmutableList.of());
|
||||
})
|
||||
.toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN),
|
||||
logicalAggregate(logicalProject(innerLogicalJoin()))
|
||||
.when(agg -> agg.child().isAllSlots())
|
||||
.when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0)
|
||||
.when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
|
||||
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
|
||||
.when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
|
||||
.when(agg -> {
|
||||
@ -93,7 +102,15 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory {
|
||||
.allMatch(f -> f instanceof Count && !f.isDistinct()
|
||||
&& (((Count) f).isCountStar() || f.child(0) instanceof Slot));
|
||||
})
|
||||
.then(agg -> pushCount(agg, agg.child().child(), agg.child().getProjects()))
|
||||
.thenApply(ctx -> {
|
||||
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
|
||||
.getSessionVariable().getEnableNereidsRules();
|
||||
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type())) {
|
||||
return null;
|
||||
}
|
||||
LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root;
|
||||
return pushCount(agg, agg.child().child(), agg.child().getProjects());
|
||||
})
|
||||
.toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN)
|
||||
);
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.jobs.JobContext;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.Relation;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
@ -29,6 +30,7 @@ import org.apache.doris.nereids.util.PlanUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
|
||||
/**
|
||||
@ -37,6 +39,11 @@ import java.util.function.Function;
|
||||
public class PushdownDistinctThroughJoin extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
|
||||
@Override
|
||||
public Plan rewriteRoot(Plan plan, JobContext context) {
|
||||
Set<Integer> enableNereidsRules = context.getCascadesContext().getConnectContext()
|
||||
.getSessionVariable().getEnableNereidsRules();
|
||||
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_DISTINCT_THROUGH_JOIN.type())) {
|
||||
return null;
|
||||
}
|
||||
return plan.accept(this, context);
|
||||
}
|
||||
|
||||
|
||||
@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -65,7 +66,7 @@ public class PushdownMinMaxThroughJoin implements RewriteRuleFactory {
|
||||
public List<Rule> buildRules() {
|
||||
return ImmutableList.of(
|
||||
logicalAggregate(innerLogicalJoin())
|
||||
.when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
|
||||
.when(agg -> agg.child().getOtherJoinConjuncts().isEmpty())
|
||||
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
|
||||
.when(agg -> {
|
||||
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
|
||||
@ -73,11 +74,19 @@ public class PushdownMinMaxThroughJoin implements RewriteRuleFactory {
|
||||
.allMatch(f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child(
|
||||
0) instanceof Slot);
|
||||
})
|
||||
.then(agg -> pushMinMax(agg, agg.child(), ImmutableList.of()))
|
||||
.thenApply(ctx -> {
|
||||
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
|
||||
.getSessionVariable().getEnableNereidsRules();
|
||||
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN.type())) {
|
||||
return null;
|
||||
}
|
||||
LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root;
|
||||
return pushMinMax(agg, agg.child(), ImmutableList.of());
|
||||
})
|
||||
.toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN),
|
||||
logicalAggregate(logicalProject(innerLogicalJoin()))
|
||||
.when(agg -> agg.child().isAllSlots())
|
||||
.when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0)
|
||||
.when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
|
||||
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
|
||||
.when(agg -> {
|
||||
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
|
||||
@ -86,7 +95,15 @@ public class PushdownMinMaxThroughJoin implements RewriteRuleFactory {
|
||||
f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child(
|
||||
0) instanceof Slot);
|
||||
})
|
||||
.then(agg -> pushMinMax(agg, agg.child().child(), agg.child().getProjects()))
|
||||
.thenApply(ctx -> {
|
||||
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
|
||||
.getSessionVariable().getEnableNereidsRules();
|
||||
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN.type())) {
|
||||
return null;
|
||||
}
|
||||
LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root;
|
||||
return pushMinMax(agg, agg.child().child(), agg.child().getProjects());
|
||||
})
|
||||
.toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN)
|
||||
);
|
||||
}
|
||||
|
||||
@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableList.Builder;
|
||||
@ -65,25 +66,41 @@ public class PushdownSumThroughJoin implements RewriteRuleFactory {
|
||||
public List<Rule> buildRules() {
|
||||
return ImmutableList.of(
|
||||
logicalAggregate(innerLogicalJoin())
|
||||
.when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
|
||||
.when(agg -> agg.child().getOtherJoinConjuncts().isEmpty())
|
||||
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
|
||||
.when(agg -> {
|
||||
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
|
||||
return !funcs.isEmpty() && funcs.stream()
|
||||
.allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot);
|
||||
})
|
||||
.then(agg -> pushSum(agg, agg.child(), ImmutableList.of()))
|
||||
.thenApply(ctx -> {
|
||||
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
|
||||
.getSessionVariable().getEnableNereidsRules();
|
||||
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_SUM_THROUGH_JOIN.type())) {
|
||||
return null;
|
||||
}
|
||||
LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root;
|
||||
return pushSum(agg, agg.child(), ImmutableList.of());
|
||||
})
|
||||
.toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN),
|
||||
logicalAggregate(logicalProject(innerLogicalJoin()))
|
||||
.when(agg -> agg.child().isAllSlots())
|
||||
.when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0)
|
||||
.when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
|
||||
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
|
||||
.when(agg -> {
|
||||
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
|
||||
return !funcs.isEmpty() && funcs.stream()
|
||||
.allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot);
|
||||
})
|
||||
.then(agg -> pushSum(agg, agg.child().child(), agg.child().getProjects()))
|
||||
.thenApply(ctx -> {
|
||||
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
|
||||
.getSessionVariable().getEnableNereidsRules();
|
||||
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_SUM_THROUGH_JOIN.type())) {
|
||||
return null;
|
||||
}
|
||||
LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root;
|
||||
return pushSum(agg, agg.child().child(), agg.child().getProjects());
|
||||
})
|
||||
.toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN)
|
||||
);
|
||||
}
|
||||
|
||||
@ -939,6 +939,9 @@ public class SessionVariable implements Serializable, Writable {
|
||||
@VariableMgr.VarAttr(name = DISABLE_NEREIDS_RULES, needForward = true)
|
||||
private String disableNereidsRules = "";
|
||||
|
||||
@VariableMgr.VarAttr(name = "ENABLE_NEREIDS_RULES", needForward = true)
|
||||
public String enableNereidsRules = "";
|
||||
|
||||
@VariableMgr.VarAttr(name = ENABLE_NEW_COST_MODEL, needForward = true)
|
||||
private boolean enableNewCostModel = false;
|
||||
|
||||
@ -2285,6 +2288,14 @@ public class SessionVariable implements Serializable, Writable {
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
}
|
||||
|
||||
public Set<Integer> getEnableNereidsRules() {
|
||||
return Arrays.stream(enableNereidsRules.split(",[\\s]*"))
|
||||
.filter(rule -> !rule.isEmpty())
|
||||
.map(rule -> rule.toUpperCase(Locale.ROOT))
|
||||
.map(rule -> RuleType.valueOf(rule).type())
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
}
|
||||
|
||||
public void setEnableNewCostModel(boolean enable) {
|
||||
this.enableNewCostModel = enable;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user