[feature](Nereids): use session variable to enable rule (#27036)

This commit is contained in:
jakevin
2023-11-20 20:23:24 +08:00
committed by GitHub
parent 20d7ab061b
commit fec94b7278
9 changed files with 300 additions and 27 deletions

View File

@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -71,7 +72,7 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory {
public List<Rule> buildRules() {
return ImmutableList.of(
logicalAggregate(innerLogicalJoin())
.when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
.when(agg -> agg.child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
.when(agg -> {
@ -80,11 +81,19 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory {
.allMatch(f -> f instanceof Count && !f.isDistinct()
&& (((Count) f).isCountStar() || f.child(0) instanceof Slot));
})
.then(agg -> pushCount(agg, agg.child(), ImmutableList.of()))
.thenApply(ctx -> {
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
.getSessionVariable().getEnableNereidsRules();
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type())) {
return null;
}
LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root;
return pushCount(agg, agg.child(), ImmutableList.of());
})
.toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN),
logicalAggregate(logicalProject(innerLogicalJoin()))
.when(agg -> agg.child().isAllSlots())
.when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0)
.when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
.when(agg -> {
@ -93,7 +102,15 @@ public class PushdownCountThroughJoin implements RewriteRuleFactory {
.allMatch(f -> f instanceof Count && !f.isDistinct()
&& (((Count) f).isCountStar() || f.child(0) instanceof Slot));
})
.then(agg -> pushCount(agg, agg.child().child(), agg.child().getProjects()))
.thenApply(ctx -> {
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
.getSessionVariable().getEnableNereidsRules();
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN.type())) {
return null;
}
LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root;
return pushCount(agg, agg.child().child(), agg.child().getProjects());
})
.toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN)
);
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Relation;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
@ -29,6 +30,7 @@ import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.collect.ImmutableList;
import java.util.Set;
import java.util.function.Function;
/**
@ -37,6 +39,11 @@ import java.util.function.Function;
public class PushdownDistinctThroughJoin extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
@Override
public Plan rewriteRoot(Plan plan, JobContext context) {
Set<Integer> enableNereidsRules = context.getCascadesContext().getConnectContext()
.getSessionVariable().getEnableNereidsRules();
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_DISTINCT_THROUGH_JOIN.type())) {
return null;
}
return plan.accept(this, context);
}

View File

@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -65,7 +66,7 @@ public class PushdownMinMaxThroughJoin implements RewriteRuleFactory {
public List<Rule> buildRules() {
return ImmutableList.of(
logicalAggregate(innerLogicalJoin())
.when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
.when(agg -> agg.child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
@ -73,11 +74,19 @@ public class PushdownMinMaxThroughJoin implements RewriteRuleFactory {
.allMatch(f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child(
0) instanceof Slot);
})
.then(agg -> pushMinMax(agg, agg.child(), ImmutableList.of()))
.thenApply(ctx -> {
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
.getSessionVariable().getEnableNereidsRules();
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN.type())) {
return null;
}
LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root;
return pushMinMax(agg, agg.child(), ImmutableList.of());
})
.toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN),
logicalAggregate(logicalProject(innerLogicalJoin()))
.when(agg -> agg.child().isAllSlots())
.when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0)
.when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
@ -86,7 +95,15 @@ public class PushdownMinMaxThroughJoin implements RewriteRuleFactory {
f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child(
0) instanceof Slot);
})
.then(agg -> pushMinMax(agg, agg.child().child(), agg.child().getProjects()))
.thenApply(ctx -> {
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
.getSessionVariable().getEnableNereidsRules();
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN.type())) {
return null;
}
LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root;
return pushMinMax(agg, agg.child().child(), agg.child().getProjects());
})
.toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN)
);
}

View File

@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
@ -65,25 +66,41 @@ public class PushdownSumThroughJoin implements RewriteRuleFactory {
public List<Rule> buildRules() {
return ImmutableList.of(
logicalAggregate(innerLogicalJoin())
.when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
.when(agg -> agg.child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot);
})
.then(agg -> pushSum(agg, agg.child(), ImmutableList.of()))
.thenApply(ctx -> {
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
.getSessionVariable().getEnableNereidsRules();
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_SUM_THROUGH_JOIN.type())) {
return null;
}
LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root;
return pushSum(agg, agg.child(), ImmutableList.of());
})
.toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN),
logicalAggregate(logicalProject(innerLogicalJoin()))
.when(agg -> agg.child().isAllSlots())
.when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0)
.when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot);
})
.then(agg -> pushSum(agg, agg.child().child(), agg.child().getProjects()))
.thenApply(ctx -> {
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
.getSessionVariable().getEnableNereidsRules();
if (!enableNereidsRules.contains(RuleType.PUSHDOWN_SUM_THROUGH_JOIN.type())) {
return null;
}
LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root;
return pushSum(agg, agg.child().child(), agg.child().getProjects());
})
.toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN)
);
}

View File

@ -939,6 +939,9 @@ public class SessionVariable implements Serializable, Writable {
@VariableMgr.VarAttr(name = DISABLE_NEREIDS_RULES, needForward = true)
private String disableNereidsRules = "";
@VariableMgr.VarAttr(name = "ENABLE_NEREIDS_RULES", needForward = true)
public String enableNereidsRules = "";
@VariableMgr.VarAttr(name = ENABLE_NEW_COST_MODEL, needForward = true)
private boolean enableNewCostModel = false;
@ -2285,6 +2288,14 @@ public class SessionVariable implements Serializable, Writable {
.collect(ImmutableSet.toImmutableSet());
}
public Set<Integer> getEnableNereidsRules() {
return Arrays.stream(enableNereidsRules.split(",[\\s]*"))
.filter(rule -> !rule.isEmpty())
.map(rule -> rule.toUpperCase(Locale.ROOT))
.map(rule -> RuleType.valueOf(rule).type())
.collect(ImmutableSet.toImmutableSet());
}
public void setEnableNewCostModel(boolean enable) {
this.enableNewCostModel = enable;
}