[enhancement](Nereids): speed up rewrite() (#22846)

- use Set<Integer> instead of Set<String> to speedup `contains`
- remove `getValidRules` and use `if` in `for` to avoid `toImmutableList`
This commit is contained in:
jakevin
2023-08-11 13:04:30 +08:00
committed by GitHub
parent caf496a67e
commit 080d613238
11 changed files with 48 additions and 47 deletions

View File

@ -37,12 +37,10 @@ import org.apache.doris.qe.SessionVariable;
import org.apache.doris.statistics.Statistics;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
@ -58,7 +56,7 @@ public abstract class Job implements TracerSupplier {
protected JobType type;
protected JobContext context;
protected boolean once;
protected final Set<String> disableRules;
protected final Set<Integer> disableRules;
protected Map<CTEId, Statistics> cteIdToStats;
@ -86,29 +84,6 @@ public abstract class Job implements TracerSupplier {
return once;
}
/**
* Get the rule set of this job. Filter out already applied rules and rules that are not matched on root node.
*
* @param groupExpression group expression to be applied on
* @param candidateRules rules to be applied
* @return all rules that can be applied on this group expression
*/
public List<Rule> getValidRules(GroupExpression groupExpression, List<Rule> candidateRules) {
return candidateRules.stream()
.filter(rule -> Objects.nonNull(rule)
&& !disableRules.contains(rule.getRuleType().name())
&& rule.getPattern().matchRoot(groupExpression.getPlan())
&& groupExpression.notApplied(rule))
.collect(ImmutableList.toImmutableList());
}
public List<Rule> getValidRules(List<Rule> candidateRules) {
return candidateRules.stream()
.filter(rule -> Objects.nonNull(rule)
&& !disableRules.contains(rule.getRuleType().name()))
.collect(ImmutableList.toImmutableList());
}
public abstract void execute();
public EventProducer getEventTracer() {
@ -149,13 +124,8 @@ public abstract class Job implements TracerSupplier {
groupExpression.getOwnerGroup(), groupExpression, groupExpression.getPlan()));
}
public static Set<String> getDisableRules(JobContext context) {
public static Set<Integer> getDisableRules(JobContext context) {
return context.getCascadesContext().getAndCacheSessionVariable(
"disableNereidsRules", ImmutableSet.of(), SessionVariable::getDisableNereidsRules);
}
public static boolean isTraceEnable(JobContext context) {
return context.getCascadesContext().getAndCacheSessionVariable(
"isTraceEnable", false, SessionVariable::isEnableNereidsTrace);
}
}

View File

@ -43,11 +43,17 @@ public class OptimizeGroupExpressionJob extends Job {
List<Rule> implementationRules = getRuleSet().getImplementationRules();
List<Rule> explorationRules = getExplorationRules();
for (Rule rule : getValidRules(groupExpression, explorationRules)) {
for (Rule rule : explorationRules) {
if (rule.isInvalid(disableRules, groupExpression)) {
continue;
}
pushJob(new ApplyRuleJob(groupExpression, rule, context));
}
for (Rule rule : getValidRules(groupExpression, implementationRules)) {
for (Rule rule : implementationRules) {
if (rule.isInvalid(disableRules, groupExpression)) {
continue;
}
pushJob(new ApplyRuleJob(groupExpression, rule, context));
}
}

View File

@ -23,7 +23,6 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import java.util.Locale;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;
@ -49,8 +48,8 @@ public class CustomRewriteJob implements RewriteJob {
@Override
public void execute(JobContext context) {
Set<String> disableRules = Job.getDisableRules(context);
if (disableRules.contains(ruleType.name().toUpperCase(Locale.ROOT))) {
Set<Integer> disableRules = Job.getDisableRules(context);
if (disableRules.contains(ruleType.type())) {
return;
}
Plan root = context.getCascadesContext().getRewritePlan();

View File

@ -42,8 +42,10 @@ public abstract class PlanTreeRewriteJob extends Job {
boolean isRewriteRoot = rewriteJobContext.isRewriteRoot();
CascadesContext cascadesContext = context.getCascadesContext();
cascadesContext.setIsRewriteRoot(isRewriteRoot);
List<Rule> validRules = getValidRules(rules);
for (Rule rule : validRules) {
for (Rule rule : rules) {
if (disableRules.contains(rule.getRuleType().type())) {
continue;
}
Pattern<Plan> pattern = (Pattern<Plan>) rule.getPattern();
if (pattern.matchPlanTree(plan)) {
List<Plan> newPlans = rule.transform(plan, cascadesContext);

View File

@ -87,8 +87,10 @@ public class RewriteBottomUpJob extends Job {
}
countJobExecutionTimesOfGroupExpressions(logicalExpression);
List<Rule> validRules = getValidRules(logicalExpression, rules);
for (Rule rule : validRules) {
for (Rule rule : rules) {
if (rule.isInvalid(disableRules, logicalExpression)) {
continue;
}
GroupExpressionMatching groupExpressionMatching
= new GroupExpressionMatching(rule.getPattern(), logicalExpression);
for (Plan before : groupExpressionMatching) {

View File

@ -84,8 +84,10 @@ public class RewriteTopDownJob extends Job {
public void execute() {
GroupExpression logicalExpression = group.getLogicalExpression();
countJobExecutionTimesOfGroupExpressions(logicalExpression);
List<Rule> validRules = getValidRules(logicalExpression, rules);
for (Rule rule : validRules) {
for (Rule rule : rules) {
if (rule.isInvalid(disableRules, logicalExpression)) {
continue;
}
Preconditions.checkArgument(rule.isRewrite(),
"rules must be rewritable in top down job");
GroupExpressionMatching groupExpressionMatching

View File

@ -19,11 +19,13 @@ package org.apache.doris.nereids.rules;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.TransformException;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.rules.RuleType.RuleTypeClass;
import org.apache.doris.nereids.trees.plans.Plan;
import java.util.List;
import java.util.Set;
/**
* Abstract class for all rules.
@ -73,4 +75,13 @@ public abstract class Rule {
public void acceptPlan(Plan plan) {
}
/**
* Filter out already applied rules and rules that are not matched on root node.
*/
public boolean isInvalid(Set<Integer> disableRules, GroupExpression groupExpression) {
return disableRules.contains(this.getRuleType().type())
|| !groupExpression.notApplied(this)
|| !this.getPattern().matchRoot(groupExpression.getPlan());
}
}

View File

@ -64,7 +64,7 @@ import java.util.Set;
* | aggregate: count(*) as cntStar
* aggregate: count(x) as cnt
* </pre>
* Notice: when Count(*) exists, group by mustn't be empty.
* Notice: rule can't optimize condition that groupby is empty when Count(*) exists.
*/
public class PushdownCountThroughJoin implements RewriteRuleFactory {
@Override

View File

@ -28,6 +28,7 @@ import org.apache.doris.common.io.Writable;
import org.apache.doris.common.util.TimeUtils;
import org.apache.doris.nereids.metrics.Event;
import org.apache.doris.nereids.metrics.EventSwitchParser;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.qe.VariableMgr.VarAttr;
import org.apache.doris.thrift.TQueryOptions;
import org.apache.doris.thrift.TResourceLimit;
@ -1944,12 +1945,20 @@ public class SessionVariable implements Serializable, Writable {
return nthOptimizedPlan;
}
public Set<String> getDisableNereidsRules() {
public Set<String> getDisableNereidsRuleNames() {
return Arrays.stream(disableNereidsRules.split(",[\\s]*"))
.map(rule -> rule.toUpperCase(Locale.ROOT))
.collect(ImmutableSet.toImmutableSet());
}
public Set<Integer> getDisableNereidsRules() {
return Arrays.stream(disableNereidsRules.split(",[\\s]*"))
.filter(rule -> !rule.isEmpty())
.map(rule -> rule.toUpperCase(Locale.ROOT))
.map(rule -> RuleType.valueOf(rule).type())
.collect(ImmutableSet.toImmutableSet());
}
public void setEnableNewCostModel(boolean enable) {
this.enableNewCostModel = enable;
}

View File

@ -128,7 +128,7 @@ public class PlanChecker {
public PlanChecker analyze(Plan plan) {
this.cascadesContext = MemoTestUtils.createCascadesContext(connectContext, plan);
Set<String> originDisableRules = connectContext.getSessionVariable().getDisableNereidsRules();
Set<String> originDisableRules = connectContext.getSessionVariable().getDisableNereidsRuleNames();
Set<String> disableRuleWithAuth = Sets.newHashSet(originDisableRules);
disableRuleWithAuth.add(RuleType.RELATION_AUTHENTICATION.name());
connectContext.getSessionVariable().setDisableNereidsRules(String.join(",", disableRuleWithAuth));

View File

@ -194,7 +194,7 @@ public abstract class TestWithFeService {
}
public LogicalPlan analyze(String sql) {
Set<String> originDisableRules = connectContext.getSessionVariable().getDisableNereidsRules();
Set<String> originDisableRules = connectContext.getSessionVariable().getDisableNereidsRuleNames();
Set<String> disableRuleWithAuth = Sets.newHashSet(originDisableRules);
disableRuleWithAuth.add(RuleType.RELATION_AUTHENTICATION.name());
connectContext.getSessionVariable().setDisableNereidsRules(String.join(",", disableRuleWithAuth));