[enhancement](Nereids): speed up rewrite() (#22846)
- use Set<Integer> instead of Set<String> to speedup `contains` - remove `getValidRules` and use `if` in `for` to avoid `toImmutableList`
This commit is contained in:
@ -37,12 +37,10 @@ import org.apache.doris.qe.SessionVariable;
|
||||
import org.apache.doris.statistics.Statistics;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
@ -58,7 +56,7 @@ public abstract class Job implements TracerSupplier {
|
||||
protected JobType type;
|
||||
protected JobContext context;
|
||||
protected boolean once;
|
||||
protected final Set<String> disableRules;
|
||||
protected final Set<Integer> disableRules;
|
||||
|
||||
protected Map<CTEId, Statistics> cteIdToStats;
|
||||
|
||||
@ -86,29 +84,6 @@ public abstract class Job implements TracerSupplier {
|
||||
return once;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the rule set of this job. Filter out already applied rules and rules that are not matched on root node.
|
||||
*
|
||||
* @param groupExpression group expression to be applied on
|
||||
* @param candidateRules rules to be applied
|
||||
* @return all rules that can be applied on this group expression
|
||||
*/
|
||||
public List<Rule> getValidRules(GroupExpression groupExpression, List<Rule> candidateRules) {
|
||||
return candidateRules.stream()
|
||||
.filter(rule -> Objects.nonNull(rule)
|
||||
&& !disableRules.contains(rule.getRuleType().name())
|
||||
&& rule.getPattern().matchRoot(groupExpression.getPlan())
|
||||
&& groupExpression.notApplied(rule))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
|
||||
public List<Rule> getValidRules(List<Rule> candidateRules) {
|
||||
return candidateRules.stream()
|
||||
.filter(rule -> Objects.nonNull(rule)
|
||||
&& !disableRules.contains(rule.getRuleType().name()))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
|
||||
public abstract void execute();
|
||||
|
||||
public EventProducer getEventTracer() {
|
||||
@ -149,13 +124,8 @@ public abstract class Job implements TracerSupplier {
|
||||
groupExpression.getOwnerGroup(), groupExpression, groupExpression.getPlan()));
|
||||
}
|
||||
|
||||
public static Set<String> getDisableRules(JobContext context) {
|
||||
public static Set<Integer> getDisableRules(JobContext context) {
|
||||
return context.getCascadesContext().getAndCacheSessionVariable(
|
||||
"disableNereidsRules", ImmutableSet.of(), SessionVariable::getDisableNereidsRules);
|
||||
}
|
||||
|
||||
public static boolean isTraceEnable(JobContext context) {
|
||||
return context.getCascadesContext().getAndCacheSessionVariable(
|
||||
"isTraceEnable", false, SessionVariable::isEnableNereidsTrace);
|
||||
}
|
||||
}
|
||||
|
||||
@ -43,11 +43,17 @@ public class OptimizeGroupExpressionJob extends Job {
|
||||
List<Rule> implementationRules = getRuleSet().getImplementationRules();
|
||||
List<Rule> explorationRules = getExplorationRules();
|
||||
|
||||
for (Rule rule : getValidRules(groupExpression, explorationRules)) {
|
||||
for (Rule rule : explorationRules) {
|
||||
if (rule.isInvalid(disableRules, groupExpression)) {
|
||||
continue;
|
||||
}
|
||||
pushJob(new ApplyRuleJob(groupExpression, rule, context));
|
||||
}
|
||||
|
||||
for (Rule rule : getValidRules(groupExpression, implementationRules)) {
|
||||
for (Rule rule : implementationRules) {
|
||||
if (rule.isInvalid(disableRules, groupExpression)) {
|
||||
continue;
|
||||
}
|
||||
pushJob(new ApplyRuleJob(groupExpression, rule, context));
|
||||
}
|
||||
}
|
||||
|
||||
@ -23,7 +23,6 @@ import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
|
||||
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.function.Supplier;
|
||||
@ -49,8 +48,8 @@ public class CustomRewriteJob implements RewriteJob {
|
||||
|
||||
@Override
|
||||
public void execute(JobContext context) {
|
||||
Set<String> disableRules = Job.getDisableRules(context);
|
||||
if (disableRules.contains(ruleType.name().toUpperCase(Locale.ROOT))) {
|
||||
Set<Integer> disableRules = Job.getDisableRules(context);
|
||||
if (disableRules.contains(ruleType.type())) {
|
||||
return;
|
||||
}
|
||||
Plan root = context.getCascadesContext().getRewritePlan();
|
||||
|
||||
@ -42,8 +42,10 @@ public abstract class PlanTreeRewriteJob extends Job {
|
||||
boolean isRewriteRoot = rewriteJobContext.isRewriteRoot();
|
||||
CascadesContext cascadesContext = context.getCascadesContext();
|
||||
cascadesContext.setIsRewriteRoot(isRewriteRoot);
|
||||
List<Rule> validRules = getValidRules(rules);
|
||||
for (Rule rule : validRules) {
|
||||
for (Rule rule : rules) {
|
||||
if (disableRules.contains(rule.getRuleType().type())) {
|
||||
continue;
|
||||
}
|
||||
Pattern<Plan> pattern = (Pattern<Plan>) rule.getPattern();
|
||||
if (pattern.matchPlanTree(plan)) {
|
||||
List<Plan> newPlans = rule.transform(plan, cascadesContext);
|
||||
|
||||
@ -87,8 +87,10 @@ public class RewriteBottomUpJob extends Job {
|
||||
}
|
||||
|
||||
countJobExecutionTimesOfGroupExpressions(logicalExpression);
|
||||
List<Rule> validRules = getValidRules(logicalExpression, rules);
|
||||
for (Rule rule : validRules) {
|
||||
for (Rule rule : rules) {
|
||||
if (rule.isInvalid(disableRules, logicalExpression)) {
|
||||
continue;
|
||||
}
|
||||
GroupExpressionMatching groupExpressionMatching
|
||||
= new GroupExpressionMatching(rule.getPattern(), logicalExpression);
|
||||
for (Plan before : groupExpressionMatching) {
|
||||
|
||||
@ -84,8 +84,10 @@ public class RewriteTopDownJob extends Job {
|
||||
public void execute() {
|
||||
GroupExpression logicalExpression = group.getLogicalExpression();
|
||||
countJobExecutionTimesOfGroupExpressions(logicalExpression);
|
||||
List<Rule> validRules = getValidRules(logicalExpression, rules);
|
||||
for (Rule rule : validRules) {
|
||||
for (Rule rule : rules) {
|
||||
if (rule.isInvalid(disableRules, logicalExpression)) {
|
||||
continue;
|
||||
}
|
||||
Preconditions.checkArgument(rule.isRewrite(),
|
||||
"rules must be rewritable in top down job");
|
||||
GroupExpressionMatching groupExpressionMatching
|
||||
|
||||
@ -19,11 +19,13 @@ package org.apache.doris.nereids.rules;
|
||||
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.exceptions.TransformException;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.pattern.Pattern;
|
||||
import org.apache.doris.nereids.rules.RuleType.RuleTypeClass;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Abstract class for all rules.
|
||||
@ -73,4 +75,13 @@ public abstract class Rule {
|
||||
public void acceptPlan(Plan plan) {
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Filter out already applied rules and rules that are not matched on root node.
|
||||
*/
|
||||
public boolean isInvalid(Set<Integer> disableRules, GroupExpression groupExpression) {
|
||||
return disableRules.contains(this.getRuleType().type())
|
||||
|| !groupExpression.notApplied(this)
|
||||
|| !this.getPattern().matchRoot(groupExpression.getPlan());
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,7 +64,7 @@ import java.util.Set;
|
||||
* | aggregate: count(*) as cntStar
|
||||
* aggregate: count(x) as cnt
|
||||
* </pre>
|
||||
* Notice: when Count(*) exists, group by mustn't be empty.
|
||||
* Notice: rule can't optimize condition that groupby is empty when Count(*) exists.
|
||||
*/
|
||||
public class PushdownCountThroughJoin implements RewriteRuleFactory {
|
||||
@Override
|
||||
|
||||
@ -28,6 +28,7 @@ import org.apache.doris.common.io.Writable;
|
||||
import org.apache.doris.common.util.TimeUtils;
|
||||
import org.apache.doris.nereids.metrics.Event;
|
||||
import org.apache.doris.nereids.metrics.EventSwitchParser;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.qe.VariableMgr.VarAttr;
|
||||
import org.apache.doris.thrift.TQueryOptions;
|
||||
import org.apache.doris.thrift.TResourceLimit;
|
||||
@ -1944,12 +1945,20 @@ public class SessionVariable implements Serializable, Writable {
|
||||
return nthOptimizedPlan;
|
||||
}
|
||||
|
||||
public Set<String> getDisableNereidsRules() {
|
||||
public Set<String> getDisableNereidsRuleNames() {
|
||||
return Arrays.stream(disableNereidsRules.split(",[\\s]*"))
|
||||
.map(rule -> rule.toUpperCase(Locale.ROOT))
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
}
|
||||
|
||||
public Set<Integer> getDisableNereidsRules() {
|
||||
return Arrays.stream(disableNereidsRules.split(",[\\s]*"))
|
||||
.filter(rule -> !rule.isEmpty())
|
||||
.map(rule -> rule.toUpperCase(Locale.ROOT))
|
||||
.map(rule -> RuleType.valueOf(rule).type())
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
}
|
||||
|
||||
public void setEnableNewCostModel(boolean enable) {
|
||||
this.enableNewCostModel = enable;
|
||||
}
|
||||
|
||||
@ -128,7 +128,7 @@ public class PlanChecker {
|
||||
|
||||
public PlanChecker analyze(Plan plan) {
|
||||
this.cascadesContext = MemoTestUtils.createCascadesContext(connectContext, plan);
|
||||
Set<String> originDisableRules = connectContext.getSessionVariable().getDisableNereidsRules();
|
||||
Set<String> originDisableRules = connectContext.getSessionVariable().getDisableNereidsRuleNames();
|
||||
Set<String> disableRuleWithAuth = Sets.newHashSet(originDisableRules);
|
||||
disableRuleWithAuth.add(RuleType.RELATION_AUTHENTICATION.name());
|
||||
connectContext.getSessionVariable().setDisableNereidsRules(String.join(",", disableRuleWithAuth));
|
||||
|
||||
@ -194,7 +194,7 @@ public abstract class TestWithFeService {
|
||||
}
|
||||
|
||||
public LogicalPlan analyze(String sql) {
|
||||
Set<String> originDisableRules = connectContext.getSessionVariable().getDisableNereidsRules();
|
||||
Set<String> originDisableRules = connectContext.getSessionVariable().getDisableNereidsRuleNames();
|
||||
Set<String> disableRuleWithAuth = Sets.newHashSet(originDisableRules);
|
||||
disableRuleWithAuth.add(RuleType.RELATION_AUTHENTICATION.name());
|
||||
connectContext.getSessionVariable().setDisableNereidsRules(String.join(",", disableRuleWithAuth));
|
||||
|
||||
Reference in New Issue
Block a user