[refactor](Nereids) Refactor rewrite framework to speed up plan (#17126)

This PR refactors the rewrite framework from memo-based to plan-tree-based, and speeds up the analyze/rewrite stage.

Changes:
- abandon the memo in the analyze/rewrite stage, so we can skip some costly actions: creating new GroupExpressions, deduplicating GroupExpressions in the memo (high cost), and updating children to GroupPlan
- change most rules to static rules, so we can skip initializing lots of rules in the Analyzer/Rewriter on every query. Some rules need context (e.g. visitor rules) and are easier to use when created at runtime, so `custom` rules help us create them
- remove the `logger` field from Job. Jobs are generated in large quantities at runtime and do not need a logger, so skipping logger initialization saves a huge amount of time
- skip some rules wherever possible, e.g. `SelectMaterializedIndexWithoutAggregate`: skip materialized-view selection if the table has no rollup
- add caches for frequent operations, such as `Job.getDisableRules` and `Plan.getUnboundExpression`
- add a new bottom-up rewrite job that can keep traversing the multiple new plans returned by rules. This feature depends on `Plan.mutableState`, and adding this mutable field to Plan is necessary: if the plan were fully immutable, we would have to call withXxx to copy the plan and set the state on the copy, which costs more runtime overhead and development effort. Another reason is that we need multiple kinds of mutable state, e.g. whether a rule has been applied, and whether this plan is managed by the rewrite framework. The upside of mutable state is efficiency, but I suggest we avoid using mutable state directly in rules as far as possible; if we need it, wrap it in the framework so it is updated and released correctly. A good example is `AppliedAwareRuleCondition`, which can update and query the state of whether a rule was previously applied to a plan.
- merge some rules, so multiple rules are invoked in one traversal
- refactor `EliminateUnnecessaryProject` into a CustomRewriter, fixing a bug where it eliminated Projects that determine the query output order, e.g. the limit(project) and sort(project) cases
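The mutable-state approach described above can be sketched roughly as follows. This is a minimal, illustrative example only: `PlanNode`, `AppliedAwareCondition`, and the method names are hypothetical stand-ins for the real `Plan.mutableState` and `AppliedAwareRuleCondition`, not the actual Doris API. The idea is that the plan node carries a small mutable map managed by the framework, and a wrapper condition records whether a given rule has already been applied to that node, so the framework can skip re-application without copying the plan via withXxx.

```java
import java.util.HashMap;
import java.util.Map;

// Illustrative sketch only: names do not match the real Doris classes.
class PlanNode {
    // mutable side-channel state, intended to be managed by the rewrite framework
    private final Map<String, Object> mutableState = new HashMap<>();

    @SuppressWarnings("unchecked")
    <T> T getMutableState(String key) {
        return (T) mutableState.get(key);
    }

    void setMutableState(String key, Object value) {
        mutableState.put(key, value);
    }
}

// Wraps the "has this rule been applied?" check so rules never touch
// the mutable state directly (mirrors the role of AppliedAwareRuleCondition).
class AppliedAwareCondition {
    private final String ruleKey;

    AppliedAwareCondition(String ruleName) {
        this.ruleKey = "applied:" + ruleName;
    }

    boolean notAppliedYet(PlanNode plan) {
        Boolean applied = plan.getMutableState(ruleKey);
        return applied == null || !applied;
    }

    void markApplied(PlanNode plan) {
        plan.setMutableState(ruleKey, true);
    }
}
```

With a wrapper like this, a bottom-up job can check `notAppliedYet(plan)` before firing a rule and call `markApplied(plan)` afterwards; because the flag lives on the node itself, the new plans a rule returns can keep being traversed without rebuilding the whole tree.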

TODO: add tracing for the new rewrite framework

benchmark:

legacy optimizer:
```
+-----------+---------------+---------------+---------------+
|  SQL ID   |      avg      |      min      |      max      |
+-----------+---------------+---------------+---------------+
|  SQL 1    |       1.39 ms |          0 ms |          9 ms |
|  SQL 2    |       1.38 ms |          0 ms |         10 ms |
|  SQL 3    |       2.05 ms |          1 ms |         18 ms |
|  SQL 4    |       0.89 ms |          0 ms |          9 ms |
|  SQL 5    |       1.74 ms |          1 ms |         11 ms |
|  SQL 6    |       2.00 ms |          1 ms |         13 ms |
|  SQL 7    |       1.83 ms |          1 ms |         15 ms |
|  SQL 8    |       0.92 ms |          0 ms |          7 ms |
|  SQL 9    |       2.60 ms |          1 ms |         19 ms |
|  SQL 10   |       3.54 ms |          2 ms |         28 ms |
|  SQL 11   |       3.04 ms |          1 ms |         18 ms |
|  SQL 12   |       3.26 ms |          2 ms |         16 ms |
|  SQL 13   |       1.10 ms |          0 ms |         10 ms |
|  SQL 14   |       2.90 ms |          1 ms |         13 ms |
|  SQL 15   |       1.18 ms |          0 ms |          9 ms |
|  SQL 16   |       1.05 ms |          0 ms |         13 ms |
|  SQL 17   |       1.03 ms |          0 ms |          7 ms |
|  SQL 18   |       0.94 ms |          0 ms |          7 ms |
|  SQL 19   |       1.47 ms |          0 ms |         13 ms |
|  SQL 20   |       0.47 ms |          0 ms |          4 ms |
|  SQL 21   |       0.54 ms |          0 ms |          5 ms |
|  SQL 22   |       3.34 ms |          1 ms |         19 ms |
|  SQL 23   |       7.97 ms |          4 ms |         44 ms |
|  SQL 24   |      11.11 ms |          7 ms |         28 ms |
|  SQL 25   |       0.98 ms |          0 ms |          8 ms |
|  SQL 26   |       0.83 ms |          0 ms |          7 ms |
|  SQL 27   |       0.93 ms |          0 ms |         16 ms |
|  SQL 28   |       2.19 ms |          1 ms |         18 ms |
|  SQL 29   |       3.23 ms |          1 ms |         20 ms |
|  SQL 30   |      59.99 ms |         51 ms |         81 ms |
|  SQL 31   |       2.65 ms |          1 ms |         18 ms |
|  SQL 32   |       2.47 ms |          1 ms |         17 ms |
|  SQL 33   |       2.30 ms |          1 ms |         16 ms |
|  SQL 34   |       0.66 ms |          0 ms |          8 ms |
|  SQL 35   |       0.63 ms |          0 ms |          6 ms |
|  SQL 36   |       2.25 ms |          1 ms |         15 ms |
|  SQL 37   |       5.97 ms |          3 ms |         20 ms |
|  SQL 38   |       5.73 ms |          3 ms |         21 ms |
|  SQL 39   |       6.32 ms |          4 ms |         23 ms |
|  SQL 40   |       8.61 ms |          5 ms |         35 ms |
|  SQL 41   |       6.29 ms |          4 ms |         28 ms |
|  SQL 42   |       6.04 ms |          4 ms |         15 ms |
|  SQL 43   |       5.81 ms |          3 ms |         16 ms |
+-----------+---------------+---------------+---------------+
| TOTAL AVG |       4.22 ms |       2.47 ms |      17.05 ms |
| TOTAL SUM |     181.62 ms |        106 ms |        733 ms |
+-----------+---------------+---------------+---------------+
```

Nereids with the memo rewrite framework (old):
```
+-----------+---------------+---------------+---------------+
|  SQL ID   |      avg      |      min      |      max      |
+-----------+---------------+---------------+---------------+
|  SQL 1    |       3.61 ms |          1 ms |         20 ms |
|  SQL 2    |       3.47 ms |          2 ms |         16 ms |
|  SQL 3    |       3.27 ms |          1 ms |         18 ms |
|  SQL 4    |       2.23 ms |          1 ms |         12 ms |
|  SQL 5    |       3.60 ms |          1 ms |         20 ms |
|  SQL 6    |       2.73 ms |          1 ms |         17 ms |
|  SQL 7    |       3.04 ms |          1 ms |         23 ms |
|  SQL 8    |       3.53 ms |          2 ms |         20 ms |
|  SQL 9    |       3.74 ms |          2 ms |         22 ms |
|  SQL 10   |       3.66 ms |          2 ms |         18 ms |
|  SQL 11   |       3.93 ms |          2 ms |         15 ms |
|  SQL 12   |       4.85 ms |          2 ms |         27 ms |
|  SQL 13   |       4.41 ms |          2 ms |         28 ms |
|  SQL 14   |       5.16 ms |          2 ms |         41 ms |
|  SQL 15   |       4.33 ms |          2 ms |         33 ms |
|  SQL 16   |       4.94 ms |          2 ms |         51 ms |
|  SQL 17   |       3.27 ms |          1 ms |         25 ms |
|  SQL 18   |       2.78 ms |          1 ms |         22 ms |
|  SQL 19   |       3.51 ms |          1 ms |         42 ms |
|  SQL 20   |       1.84 ms |          1 ms |         13 ms |
|  SQL 21   |       3.47 ms |          1 ms |         66 ms |
|  SQL 22   |       5.21 ms |          2 ms |         29 ms |
|  SQL 23   |       5.55 ms |          3 ms |         25 ms |
|  SQL 24   |       4.21 ms |          2 ms |         28 ms |
|  SQL 25   |       3.47 ms |          1 ms |         23 ms |
|  SQL 26   |       3.03 ms |          2 ms |         21 ms |
|  SQL 27   |       3.07 ms |          1 ms |         17 ms |
|  SQL 28   |       4.51 ms |          3 ms |         22 ms |
|  SQL 29   |       4.97 ms |          3 ms |         21 ms |
|  SQL 30   |      11.95 ms |          8 ms |         33 ms |
|  SQL 31   |       3.92 ms |          2 ms |         23 ms |
|  SQL 32   |       3.74 ms |          2 ms |         15 ms |
|  SQL 33   |       3.62 ms |          2 ms |         22 ms |
|  SQL 34   |       4.60 ms |          1 ms |         55 ms |
|  SQL 35   |       3.47 ms |          2 ms |         25 ms |
|  SQL 36   |       3.34 ms |          2 ms |         18 ms |
|  SQL 37   |       4.77 ms |          2 ms |         23 ms |
|  SQL 38   |       4.44 ms |          2 ms |         39 ms |
|  SQL 39   |       4.52 ms |          2 ms |         23 ms |
|  SQL 40   |       5.50 ms |          3 ms |         30 ms |
|  SQL 41   |       5.01 ms |          2 ms |         24 ms |
|  SQL 42   |       4.32 ms |          2 ms |         24 ms |
|  SQL 43   |       4.29 ms |          2 ms |         42 ms |
+-----------+---------------+---------------+---------------+
| TOTAL AVG |       4.11 ms |       1.91 ms |      26.30 ms |
| TOTAL SUM |     176.88 ms |         82 ms |       1131 ms |
+-----------+---------------+---------------+---------------+
```

Nereids with the plan tree rewrite framework (new):
```
+-----------+---------------+---------------+---------------+
|  SQL ID   |      avg      |      min      |      max      |
+-----------+---------------+---------------+---------------+
|  SQL 1    |       3.21 ms |          1 ms |         18 ms |
|  SQL 2    |       3.99 ms |          1 ms |         76 ms |
|  SQL 3    |       2.93 ms |          1 ms |         21 ms |
|  SQL 4    |       2.13 ms |          1 ms |         21 ms |
|  SQL 5    |       2.43 ms |          1 ms |         30 ms |
|  SQL 6    |       2.08 ms |          1 ms |         11 ms |
|  SQL 7    |       2.03 ms |          1 ms |         11 ms |
|  SQL 8    |       2.27 ms |          1 ms |         22 ms |
|  SQL 9    |       2.42 ms |          1 ms |         16 ms |
|  SQL 10   |       2.65 ms |          1 ms |         14 ms |
|  SQL 11   |       2.78 ms |          1 ms |         14 ms |
|  SQL 12   |       3.09 ms |          1 ms |         19 ms |
|  SQL 13   |       2.33 ms |          1 ms |         13 ms |
|  SQL 14   |       2.66 ms |          1 ms |         16 ms |
|  SQL 15   |       2.34 ms |          1 ms |         15 ms |
|  SQL 16   |       2.04 ms |          1 ms |         30 ms |
|  SQL 17   |       2.09 ms |          1 ms |         17 ms |
|  SQL 18   |       1.87 ms |          1 ms |         15 ms |
|  SQL 19   |       2.21 ms |          1 ms |         50 ms |
|  SQL 20   |       1.32 ms |          0 ms |         12 ms |
|  SQL 21   |       1.63 ms |          1 ms |         11 ms |
|  SQL 22   |       2.75 ms |          1 ms |         30 ms |
|  SQL 23   |       3.44 ms |          2 ms |         17 ms |
|  SQL 24   |       2.01 ms |          1 ms |         14 ms |
|  SQL 25   |       1.58 ms |          1 ms |         11 ms |
|  SQL 26   |       1.53 ms |          0 ms |         13 ms |
|  SQL 27   |       1.62 ms |          1 ms |         12 ms |
|  SQL 28   |       2.90 ms |          1 ms |         21 ms |
|  SQL 29   |       3.04 ms |          2 ms |         17 ms |
|  SQL 30   |      10.54 ms |          7 ms |         49 ms |
|  SQL 31   |       2.61 ms |          1 ms |         21 ms |
|  SQL 32   |       2.42 ms |          1 ms |         14 ms |
|  SQL 33   |       2.13 ms |          1 ms |         14 ms |
|  SQL 34   |       1.69 ms |          1 ms |         14 ms |
|  SQL 35   |       1.87 ms |          1 ms |         15 ms |
|  SQL 36   |       2.37 ms |          1 ms |         21 ms |
|  SQL 37   |       3.06 ms |          1 ms |         15 ms |
|  SQL 38   |       4.09 ms |          1 ms |         31 ms |
|  SQL 39   |       5.81 ms |          2 ms |         43 ms |
|  SQL 40   |       4.55 ms |          2 ms |         34 ms |
|  SQL 41   |       3.49 ms |          1 ms |         20 ms |
|  SQL 42   |       2.75 ms |          1 ms |         26 ms |
|  SQL 43   |       2.81 ms |          1 ms |         14 ms |
+-----------+---------------+---------------+---------------+
| TOTAL AVG |       2.78 ms |       1.19 ms |      21.35 ms |
| TOTAL SUM |     119.56 ms |         51 ms |        918 ms |
+-----------+---------------+---------------+---------------+
```
This commit is contained in:
924060929
2023-02-28 16:02:09 +08:00
committed by GitHub
parent a1db5c6f52
commit 9db56201a6
202 changed files with 3490 additions and 1652 deletions


@@ -26,11 +26,14 @@ import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.rewrite.CustomRewriteJob;
import org.apache.doris.nereids.jobs.rewrite.RewriteBottomUpJob;
import org.apache.doris.nereids.jobs.rewrite.RewriteTopDownJob;
import org.apache.doris.nereids.jobs.rewrite.RootPlanTreeRewriteJob.RootRewriteJobContext;
import org.apache.doris.nereids.jobs.scheduler.JobPool;
import org.apache.doris.nereids.jobs.scheduler.JobScheduler;
import org.apache.doris.nereids.jobs.scheduler.JobStack;
import org.apache.doris.nereids.jobs.scheduler.ScheduleContext;
import org.apache.doris.nereids.jobs.scheduler.SimpleJobScheduler;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
@@ -38,12 +41,14 @@ import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleFactory;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTE;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
@@ -57,12 +62,20 @@ import java.util.Optional;
import java.util.Set;
import java.util.Stack;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
/**
* Context used in memo.
*/
public class CascadesContext {
private final Memo memo;
public class CascadesContext implements ScheduleContext, PlanSource {
// in analyze/rewrite stage, the plan will storage in this field
private Plan plan;
private Optional<RootRewriteJobContext> currentRootRewriteJobContext;
// in optimize stage, the plan will storage in the memo
private Memo memo;
private final StatementContext statementContext;
private CTEContext cteContext;
@@ -76,8 +89,13 @@ public class CascadesContext {
private List<Table> tables = null;
public CascadesContext(Memo memo, StatementContext statementContext, PhysicalProperties requestProperties) {
this(memo, statementContext, new CTEContext(), requestProperties);
private boolean isRewriteRoot;
private Optional<Scope> outerScope = Optional.empty();
public CascadesContext(Plan plan, Memo memo, StatementContext statementContext,
PhysicalProperties requestProperties) {
this(plan, memo, statementContext, new CTEContext(), requestProperties);
}
/**
@@ -86,8 +104,9 @@
* @param memo {@link Memo} reference
* @param statementContext {@link StatementContext} reference
*/
public CascadesContext(Memo memo, StatementContext statementContext,
public CascadesContext(Plan plan, Memo memo, StatementContext statementContext,
CTEContext cteContext, PhysicalProperties requireProperties) {
this.plan = plan;
this.memo = memo;
this.statementContext = statementContext;
this.ruleSet = new RuleSet();
@@ -99,19 +118,30 @@
this.cteContext = cteContext;
}
public static CascadesContext newContext(StatementContext statementContext,
public static CascadesContext newMemoContext(StatementContext statementContext,
Plan initPlan, PhysicalProperties requireProperties) {
return new CascadesContext(new Memo(initPlan), statementContext, requireProperties);
return new CascadesContext(initPlan, new Memo(initPlan), statementContext, requireProperties);
}
public static CascadesContext newRewriteContext(StatementContext statementContext,
Plan initPlan, PhysicalProperties requireProperties) {
return new CascadesContext(initPlan, null, statementContext, requireProperties);
}
public static CascadesContext newRewriteContext(StatementContext statementContext,
Plan initPlan, CTEContext cteContext) {
return new CascadesContext(initPlan, null, statementContext, cteContext, PhysicalProperties.ANY);
}
public void toMemo() {
this.memo = new Memo(plan);
}
public NereidsAnalyzer newAnalyzer() {
return new NereidsAnalyzer(this);
}
public NereidsAnalyzer newAnalyzer(Optional<Scope> outerScope) {
return new NereidsAnalyzer(this, outerScope);
}
@Override
public void pushJob(Job job) {
jobPool.push(job);
}
@@ -136,6 +166,7 @@ public class CascadesContext {
this.ruleSet = ruleSet;
}
@Override
public JobPool getJobPool() {
return jobPool;
}
@@ -165,6 +196,23 @@
return this;
}
public Plan getRewritePlan() {
return plan;
}
public void setRewritePlan(Plan plan) {
this.plan = plan;
}
public Optional<RootRewriteJobContext> getCurrentRootRewriteJobContext() {
return currentRootRewriteJobContext;
}
public void setCurrentRootRewriteJobContext(
RootRewriteJobContext currentRootRewriteJobContext) {
this.currentRootRewriteJobContext = Optional.ofNullable(currentRootRewriteJobContext);
}
public void setSubqueryExprIsAnalyzed(SubqueryExpr subqueryExpr, boolean isAnalyzed) {
subqueryExprIsAnalyzed.put(subqueryExpr, isAnalyzed);
}
@@ -201,6 +249,13 @@
return execute(new RewriteTopDownJob(memo.getRoot(), rules, currentJobContext));
}
public CascadesContext topDownRewrite(CustomRewriter customRewriter) {
CustomRewriteJob customRewriteJob = new CustomRewriteJob(() -> customRewriter, RuleType.TEST_REWRITE);
customRewriteJob.execute(currentJobContext);
toMemo();
return this;
}
public CTEContext getCteContext() {
return cteContext;
}
@@ -209,6 +264,22 @@
this.cteContext = cteContext;
}
public void setIsRewriteRoot(boolean isRewriteRoot) {
this.isRewriteRoot = isRewriteRoot;
}
public boolean isRewriteRoot() {
return isRewriteRoot;
}
public Optional<Scope> getOuterScope() {
return outerScope;
}
public void setOuterScope(@Nullable Scope outerScope) {
this.outerScope = Optional.ofNullable(outerScope);
}
private CascadesContext execute(Job job) {
pushJob(job);
jobScheduler.executeJobPool(this);


@@ -26,8 +26,8 @@ import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator;
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
import org.apache.doris.nereids.jobs.batch.NereidsRewriteJobExecutor;
import org.apache.doris.nereids.jobs.batch.OptimizeRulesJob;
import org.apache.doris.nereids.jobs.batch.CascadesOptimizer;
import org.apache.doris.nereids.jobs.batch.NereidsRewriter;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob;
import org.apache.doris.nereids.memo.CopyInResult;
@@ -156,7 +156,7 @@ public class NereidsPlanner extends Planner {
// resolve column, table and function
analyze();
if (explainLevel == ExplainLevel.ANALYZED_PLAN || explainLevel == ExplainLevel.ALL_PLAN) {
analyzedPlan = cascadesContext.getMemo().copyOut(false);
analyzedPlan = cascadesContext.getRewritePlan();
if (explainLevel == ExplainLevel.ANALYZED_PLAN) {
return analyzedPlan;
}
@@ -164,11 +164,14 @@ public class NereidsPlanner extends Planner {
// rule-based optimize
rewrite();
if (explainLevel == ExplainLevel.REWRITTEN_PLAN || explainLevel == ExplainLevel.ALL_PLAN) {
rewrittenPlan = cascadesContext.getMemo().copyOut(false);
rewrittenPlan = cascadesContext.getRewritePlan();
if (explainLevel == ExplainLevel.REWRITTEN_PLAN) {
return rewrittenPlan;
}
}
initMemo();
deriveStats();
optimize();
@@ -190,7 +193,7 @@
}
private void initCascadesContext(LogicalPlan plan, PhysicalProperties requireProperties) {
cascadesContext = CascadesContext.newContext(statementContext, plan, requireProperties);
cascadesContext = CascadesContext.newRewriteContext(statementContext, plan, requireProperties);
}
private void analyze() {
@@ -201,7 +204,11 @@
* Logical plan rewrite based on a series of heuristic rules.
*/
private void rewrite() {
new NereidsRewriteJobExecutor(cascadesContext).execute();
new NereidsRewriter(cascadesContext).execute();
}
private void initMemo() {
cascadesContext.toMemo();
}
private void deriveStats() {
@@ -236,7 +243,7 @@
.getSessionVariable().getMaxTableCountUseCascadesJoinReorder()) {
dpHypOptimize();
}
new OptimizeRulesJob(cascadesContext).execute();
new CascadesOptimizer(cascadesContext).execute();
}
private PhysicalPlan postProcess(PhysicalPlan physicalPlan) {


@@ -0,0 +1,22 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids;
/** PlanSource */
public interface PlanSource {
}


@@ -18,41 +18,80 @@
package org.apache.doris.nereids.analyzer;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.batch.AdjustAggregateNullableForEmptySetJob;
import org.apache.doris.nereids.jobs.batch.AnalyzeRulesJob;
import org.apache.doris.nereids.jobs.batch.AnalyzeSubqueryRulesJob;
import org.apache.doris.nereids.jobs.batch.CheckAnalysisJob;
import org.apache.doris.nereids.jobs.RewriteJob;
import org.apache.doris.nereids.jobs.batch.BatchRewriteJob;
import org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
import org.apache.doris.nereids.rules.analysis.BindExpression;
import org.apache.doris.nereids.rules.analysis.BindRelation;
import org.apache.doris.nereids.rules.analysis.CheckAnalysis;
import org.apache.doris.nereids.rules.analysis.CheckPolicy;
import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
import org.apache.doris.nereids.rules.analysis.NormalizeRepeat;
import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate;
import org.apache.doris.nereids.rules.analysis.ProjectWithDistinctToAggregate;
import org.apache.doris.nereids.rules.analysis.RegisterCTE;
import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput;
import org.apache.doris.nereids.rules.analysis.ResolveOrdinalInOrderByAndGroupBy;
import org.apache.doris.nereids.rules.analysis.SubqueryToApply;
import org.apache.doris.nereids.rules.analysis.UserAuthentication;
import org.apache.doris.nereids.rules.rewrite.logical.HideOneRowRelationUnderUnion;
import java.util.Objects;
import java.util.Optional;
import java.util.List;
/**
* Bind symbols according to metadata in the catalog, perform semantic analysis, etc.
* TODO: revisit the interface after subquery analysis is supported.
*/
public class NereidsAnalyzer {
private final CascadesContext cascadesContext;
private final Optional<Scope> outerScope;
public class NereidsAnalyzer extends BatchRewriteJob {
public static final List<RewriteJob> ANALYZE_JOBS = jobs(
topDown(
new RegisterCTE()
),
bottomUp(
new BindRelation(),
new CheckPolicy(),
new UserAuthentication(),
new BindExpression(),
new ProjectToGlobalAggregate(),
// this rule check's the logicalProject node's isDisinct property
// and replace the logicalProject node with a LogicalAggregate node
// so any rule before this, if create a new logicalProject node
// should make sure isDistinct property is correctly passed around.
// please see rule BindSlotReference or BindFunction for example
new ProjectWithDistinctToAggregate(),
new ResolveOrdinalInOrderByAndGroupBy(),
new ReplaceExpressionByChildOutput(),
new HideOneRowRelationUnderUnion()
),
topDown(
new FillUpMissingSlots(),
// We should use NormalizeRepeat to compute nullable properties for LogicalRepeat in the analysis
// stage. NormalizeRepeat will compute nullable property, add virtual slot, LogicalAggregate and
// LogicalProject for normalize. This rule depends on FillUpMissingSlots to fill up slots.
new NormalizeRepeat()
),
bottomUp(new SubqueryToApply()),
bottomUp(new AdjustAggregateNullableForEmptySet()),
bottomUp(new CheckAnalysis())
);
/**
* Execute the analysis job with scope.
* @param cascadesContext planner context for execute job
*/
public NereidsAnalyzer(CascadesContext cascadesContext) {
this(cascadesContext, Optional.empty());
super(cascadesContext);
}
public NereidsAnalyzer(CascadesContext cascadesContext, Optional<Scope> outerScope) {
this.cascadesContext = Objects.requireNonNull(cascadesContext, "cascadesContext cannot be null");
this.outerScope = Objects.requireNonNull(outerScope, "outerScope cannot be null");
@Override
public List<RewriteJob> getJobs() {
return ANALYZE_JOBS;
}
/**
* nereids analyze sql.
*/
public void analyze() {
new AnalyzeRulesJob(cascadesContext, outerScope).execute();
new AnalyzeSubqueryRulesJob(cascadesContext).execute();
new AdjustAggregateNullableForEmptySetJob(cascadesContext).execute();
// check whether analyze result is meaningful
new CheckAnalysisJob(cascadesContext).execute();
execute();
}
}


@@ -22,11 +22,12 @@ import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
/**
* The slot range required for expression analyze.
@@ -57,13 +58,13 @@
private final List<Slot> slots;
private final Optional<SubqueryExpr> ownerSubquery;
private List<Slot> correlatedSlots;
private Set<Slot> correlatedSlots;
public Scope(Optional<Scope> outerScope, List<Slot> slots, Optional<SubqueryExpr> subqueryExpr) {
this.outerScope = outerScope;
this.slots = ImmutableList.copyOf(Objects.requireNonNull(slots, "slots can not be null"));
this.ownerSubquery = subqueryExpr;
this.correlatedSlots = new ArrayList<>();
this.correlatedSlots = Sets.newLinkedHashSet();
}
public Scope(List<Slot> slots) {
@@ -82,7 +83,7 @@
return ownerSubquery;
}
public List<Slot> getCorrelatedSlots() {
public Set<Slot> getCorrelatedSlots() {
return correlatedSlots;
}


@@ -19,7 +19,6 @@ package org.apache.doris.nereids.jobs;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.memo.CopyInResult;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
@@ -39,17 +38,14 @@ import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
* Abstract class for all job using for analyze and optimize query plan in Nereids.
@@ -60,8 +56,6 @@ public abstract class Job implements TracerSupplier {
EventChannel.getDefaultChannel()
.addEnhancers(new AddCounterEventEnhancer())
.addConsumers(new LogConsumer(CounterEvent.class, EventChannel.LOG)));
public final Logger logger = LogManager.getLogger(getClass());
protected JobType type;
protected JobContext context;
protected boolean once;
@@ -76,12 +70,11 @@
this.type = type;
this.context = context;
this.once = once;
this.disableRules = getAndCacheSessionVariable(context, "disableNereidsRules",
ImmutableSet.of(), SessionVariable::getDisableNereidsRules);
this.disableRules = getDisableRules(context);
}
public void pushJob(Job job) {
context.getCascadesContext().pushJob(job);
context.getScheduleContext().pushJob(job);
}
public RuleSet getRuleSet() {
@@ -101,12 +94,21 @@
*/
public List<Rule> getValidRules(GroupExpression groupExpression, List<Rule> candidateRules) {
return candidateRules.stream()
.filter(rule -> !disableRules.contains(rule.getRuleType().name().toUpperCase(Locale.ROOT)))
.filter(rule -> Objects.nonNull(rule) && rule.getPattern().matchRoot(groupExpression.getPlan())
&& groupExpression.notApplied(rule)).collect(Collectors.toList());
.filter(rule -> Objects.nonNull(rule)
&& !disableRules.contains(rule.getRuleType().name())
&& rule.getPattern().matchRoot(groupExpression.getPlan())
&& groupExpression.notApplied(rule))
.collect(ImmutableList.toImmutableList());
}
public abstract void execute() throws AnalysisException;
public List<Rule> getValidRules(List<Rule> candidateRules) {
return candidateRules.stream()
.filter(rule -> Objects.nonNull(rule)
&& !disableRules.contains(rule.getRuleType().name()))
.collect(ImmutableList.toImmutableList());
}
public abstract void execute();
public EventProducer getEventTracer() {
throw new UnsupportedOperationException("get_event_tracer is unsupported");
@@ -146,7 +148,17 @@
groupExpression.getOwnerGroup(), groupExpression, groupExpression.getPlan()));
}
private <T> T getAndCacheSessionVariable(JobContext context, String cacheName,
public static Set<String> getDisableRules(JobContext context) {
return getAndCacheSessionVariable(context, "disableNereidsRules",
ImmutableSet.of(), SessionVariable::getDisableNereidsRules);
}
public static boolean isTraceEnable(JobContext context) {
return getAndCacheSessionVariable(context, "isTraceEnable",
false, SessionVariable::isEnableNereidsTrace);
}
private static <T> T getAndCacheSessionVariable(JobContext context, String cacheName,
T defaultValue, Function<SessionVariable, T> variableSupplier) {
CascadesContext cascadesContext = context.getCascadesContext();
ConnectContext connectContext = cascadesContext.getConnectContext();


@@ -18,6 +18,7 @@
package org.apache.doris.nereids.jobs;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.scheduler.ScheduleContext;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.RuleType;
@@ -29,21 +30,25 @@ import java.util.Map;
* Context for one job in Nereids' cascades framework.
*/
public class JobContext {
protected final CascadesContext cascadesContext;
protected final ScheduleContext scheduleContext;
protected final PhysicalProperties requiredProperties;
protected double costUpperBound;
protected boolean rewritten = false;
protected Map<RuleType, Integer> ruleInvokeTimes = Maps.newLinkedHashMap();
public JobContext(CascadesContext cascadesContext, PhysicalProperties requiredProperties, double costUpperBound) {
this.cascadesContext = cascadesContext;
public JobContext(ScheduleContext scheduleContext, PhysicalProperties requiredProperties, double costUpperBound) {
this.scheduleContext = scheduleContext;
this.requiredProperties = requiredProperties;
this.costUpperBound = costUpperBound;
}
public ScheduleContext getScheduleContext() {
return scheduleContext;
}
public CascadesContext getCascadesContext() {
return cascadesContext;
return (CascadesContext) scheduleContext;
}
public PhysicalProperties getRequiredProperties() {


@@ -31,5 +31,6 @@ public enum JobType {
TOP_DOWN_REWRITE,
VISITOR_REWRITE,
BOTTOM_UP_REWRITE,
JOIN_ORDER;
JOIN_ORDER,
LINK_PLAN;
}


@@ -0,0 +1,25 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs;
/** RewriteJob */
public interface RewriteJob {
void execute(JobContext jobContext);
boolean isOnce();
}


@@ -0,0 +1,52 @@
package org.apache.doris.nereids.jobs;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.stream.Stream;
/** TopicRewriteJob */
public class TopicRewriteJob implements RewriteJob {
public final String topicName;
public final List<RewriteJob> jobs;
/** constructor */
public TopicRewriteJob(String topicName, List<RewriteJob> jobs) {
this.topicName = topicName;
this.jobs = jobs.stream()
.flatMap(job -> job instanceof TopicRewriteJob
? ((TopicRewriteJob) job).jobs.stream()
: Stream.of(job)
)
.collect(ImmutableList.toImmutableList());
}
@Override
public void execute(JobContext jobContext) {
for (RewriteJob job : jobs) {
job.execute(jobContext);
}
}
@Override
public boolean isOnce() {
return true;
}
}
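The constructor above inlines nested topics into one flat job list. A minimal standalone sketch of that flattening, using hypothetical `Topic`/`Leaf` stand-ins rather than the real Nereids types:

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins for RewriteJob/TopicRewriteJob; they show only the flattening.
interface Job {}

class Leaf implements Job {
    final String name;
    Leaf(String name) { this.name = name; }
}

class Topic implements Job {
    final List<Job> jobs = new ArrayList<>();

    Topic(Job... children) {
        // Nested topics are inlined so execution later sees one flat list.
        for (Job j : children) {
            if (j instanceof Topic) {
                jobs.addAll(((Topic) j).jobs);
            } else {
                jobs.add(j);
            }
        }
    }
}

public class FlattenDemo {
    public static void main(String[] args) {
        Topic inner = new Topic(new Leaf("a"), new Leaf("b"));
        Topic outer = new Topic(inner, new Leaf("c"));
        System.out.println(outer.jobs.size()); // 3: the nested topic disappeared
    }
}
```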


@@ -1,82 +0,0 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount;
import org.apache.doris.nereids.rules.analysis.BindExpression;
import org.apache.doris.nereids.rules.analysis.BindRelation;
import org.apache.doris.nereids.rules.analysis.CheckPolicy;
import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
import org.apache.doris.nereids.rules.analysis.NormalizeRepeat;
import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate;
import org.apache.doris.nereids.rules.analysis.ProjectWithDistinctToAggregate;
import org.apache.doris.nereids.rules.analysis.RegisterCTE;
import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput;
import org.apache.doris.nereids.rules.analysis.ResolveOrdinalInOrderByAndGroupBy;
import org.apache.doris.nereids.rules.analysis.UserAuthentication;
import org.apache.doris.nereids.rules.rewrite.logical.HideOneRowRelationUnderUnion;
import com.google.common.collect.ImmutableList;
import java.util.Optional;
/**
* Execute the analysis rules.
*/
public class AnalyzeRulesJob extends BatchRulesJob {
/**
* Execute the analysis job with scope.
* @param cascadesContext planner context for execute job
* @param scope the scope used to resolve field symbols
*/
public AnalyzeRulesJob(CascadesContext cascadesContext, Optional<Scope> scope) {
super(cascadesContext);
rulesJob.addAll(ImmutableList.of(
bottomUpBatch(
new RegisterCTE()
),
bottomUpBatch(
new BindRelation(),
new CheckPolicy(),
new UserAuthentication(),
new BindExpression(scope),
new ProjectToGlobalAggregate(),
// this rule checks the logicalProject node's isDistinct property
// and replaces the logicalProject node with a LogicalAggregate node,
// so any rule before this that creates a new logicalProject node
// should make sure the isDistinct property is correctly passed around.
// please see rule BindSlotReference or BindFunction for an example
new ProjectWithDistinctToAggregate(),
new AvgDistinctToSumDivCount(),
new ResolveOrdinalInOrderByAndGroupBy(),
new ReplaceExpressionByChildOutput(),
new HideOneRowRelationUnderUnion()
),
topDownBatch(
new FillUpMissingSlots(),
// We should use NormalizeRepeat to compute nullable properties for LogicalRepeat in the analysis
// stage. NormalizeRepeat will compute nullable property, add virtual slot, LogicalAggregate and
// LogicalProject for normalize. This rule depends on FillUpMissingSlots to fill up slots.
new NormalizeRepeat()
)
));
}
}


@@ -17,27 +17,28 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.RuleFactory;
import org.apache.doris.nereids.rules.rewrite.BatchRewriteRuleFactory;
import org.apache.doris.nereids.rules.rewrite.logical.ExistsApplyToJoin;
import org.apache.doris.nereids.rules.rewrite.logical.InApplyToJoin;
import org.apache.doris.nereids.rules.rewrite.logical.ScalarApplyToJoin;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* Convert a logicalApply without correlated columns to a logicalJoin.
*/
public class ConvertApplyToJoinJob extends BatchRulesJob {
/**
* Constructor.
*/
public ConvertApplyToJoinJob(CascadesContext cascadesContext) {
super(cascadesContext);
rulesJob.addAll(ImmutableList.of(
topDownBatch(ImmutableList.of(
new ScalarApplyToJoin(),
new InApplyToJoin(),
new ExistsApplyToJoin())
)));
public class ApplyToJoin implements BatchRewriteRuleFactory {
public static final List<RuleFactory> RULES = ImmutableList.of(
new ScalarApplyToJoin(),
new InApplyToJoin(),
new ExistsApplyToJoin()
);
@Override
public List<RuleFactory> getRuleFactories() {
return RULES;
}
}
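A rule factory class like this keeps its `RULES` list in a `static final` field, so the factories are built once at class-load time and shared by every query, instead of being re-instantiated in each Analyzer/Rewriter. A toy sketch of that pattern, using a hypothetical `Factory` interface rather than the Doris API:

```java
import java.util.Arrays;
import java.util.List;

public class StaticRulesDemo {
    // Stand-in for RuleFactory; only here to keep the sketch self-contained.
    public interface Factory {
        String name();
    }

    // Built once when the class loads, then shared by all queries.
    public static final List<Factory> RULES = Arrays.asList(
            () -> "ScalarApplyToJoin",
            () -> "InApplyToJoin",
            () -> "ExistsApplyToJoin");

    public static void main(String[] args) {
        // Every caller observes the same pre-built list.
        System.out.println(RULES.size()); // 3
    }
}
```

Because the list is immutable and stateless, sharing it across queries is safe; only context-dependent rules still need per-query construction (the `custom` helper in the new framework).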


@@ -0,0 +1,116 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.RewriteJob;
import org.apache.doris.nereids.jobs.TopicRewriteJob;
import org.apache.doris.nereids.jobs.rewrite.CustomRewriteJob;
import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteBottomUpJob;
import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteTopDownJob;
import org.apache.doris.nereids.jobs.rewrite.RootPlanTreeRewriteJob;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleFactory;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
import java.util.stream.Stream;
/**
* Base class for executing all jobs.
*
* Each batch of rules will be uniformly executed.
*/
public abstract class BatchRewriteJob {
protected CascadesContext cascadesContext;
public BatchRewriteJob(CascadesContext cascadesContext) {
this.cascadesContext = Objects.requireNonNull(cascadesContext, "cascadesContext can not be null");
}
public static List<RewriteJob> jobs(RewriteJob... jobs) {
return Arrays.stream(jobs)
.flatMap(job -> job instanceof TopicRewriteJob
? ((TopicRewriteJob) job).jobs.stream()
: Stream.of(job)
).collect(ImmutableList.toImmutableList());
}
public static TopicRewriteJob topic(String topicName, RewriteJob... jobs) {
return new TopicRewriteJob(topicName, Arrays.asList(jobs));
}
public static RewriteJob bottomUp(String batchName, RuleFactory... ruleFactories) {
return bottomUp(Arrays.asList(ruleFactories));
}
public static RewriteJob bottomUp(RuleFactory... ruleFactories) {
return bottomUp(Arrays.asList(ruleFactories));
}
public static RewriteJob bottomUp(List<RuleFactory> ruleFactories) {
List<Rule> rules = new ArrayList<>();
for (RuleFactory ruleFactory : ruleFactories) {
rules.addAll(ruleFactory.buildRules());
}
return new RootPlanTreeRewriteJob(rules, PlanTreeRewriteBottomUpJob::new, true);
}
public static RewriteJob topDown(RuleFactory... ruleFactories) {
return topDown(Arrays.asList(ruleFactories));
}
public static RewriteJob topDown(List<RuleFactory> ruleFactories) {
return topDown(ruleFactories, true);
}
public static RewriteJob topDown(List<RuleFactory> ruleFactories, boolean once) {
List<Rule> rules = new ArrayList<>();
for (RuleFactory ruleFactory : ruleFactories) {
rules.addAll(ruleFactory.buildRules());
}
return new RootPlanTreeRewriteJob(rules, PlanTreeRewriteTopDownJob::new, once);
}
public static RewriteJob custom(RuleType ruleType, Supplier<CustomRewriter> planRewriter) {
return new CustomRewriteJob(planRewriter, ruleType);
}
/**
* execute.
*/
public void execute() {
for (RewriteJob job : getJobs()) {
JobContext jobContext = cascadesContext.getCurrentJobContext();
do {
jobContext.setRewritten(false);
job.execute(jobContext);
} while (!job.isOnce() && jobContext.isRewritten());
}
}
public abstract List<RewriteJob> getJobs();
}
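The `execute()` loop above reruns each non-once job until a pass completes without setting the rewritten flag. A toy fixpoint loop following the same protocol, with a hypothetical `ToyJob` type rather than the Nereids classes:

```java
public class FixpointDemo {
    // Mirrors the RewriteJob contract: execute() may flag that it rewrote something.
    public interface ToyJob {
        void execute(boolean[] rewritten);
        boolean isOnce();
    }

    public static int runToFixpoint(ToyJob job) {
        boolean[] rewritten = new boolean[1];
        int passes = 0;
        do {
            rewritten[0] = false;          // same reset as jobContext.setRewritten(false)
            job.execute(rewritten);
            passes++;
        } while (!job.isOnce() && rewritten[0]);
        return passes;
    }

    public static void main(String[] args) {
        int[] counter = {3};
        ToyJob job = new ToyJob() {
            @Override public void execute(boolean[] rewritten) {
                // "Rewrites" until the counter hits zero, then goes quiet.
                if (counter[0] > 0) { counter[0]--; rewritten[0] = true; }
            }
            @Override public boolean isOnce() { return false; }
        };
        System.out.println(runToFixpoint(job)); // 4: three rewriting passes + one quiescent pass
    }
}
```

An `isOnce()` job gets exactly one pass regardless of the flag, which is how the framework caps jobs that would otherwise ping-pong.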


@@ -1,109 +0,0 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.cascades.OptimizeGroupJob;
import org.apache.doris.nereids.jobs.rewrite.RewriteBottomUpJob;
import org.apache.doris.nereids.jobs.rewrite.RewriteTopDownJob;
import org.apache.doris.nereids.jobs.rewrite.VisitorRewriteJob;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleFactory;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
/**
* Base class for executing all jobs.
*
* Each batch of rules will be uniformly executed.
*/
public abstract class BatchRulesJob {
protected CascadesContext cascadesContext;
protected List<Job> rulesJob = new ArrayList<>();
BatchRulesJob(CascadesContext cascadesContext) {
this.cascadesContext = Objects.requireNonNull(cascadesContext, "cascadesContext can not be null");
}
protected Job bottomUpBatch(RuleFactory... ruleFactories) {
return bottomUpBatch(Arrays.asList(ruleFactories));
}
protected Job bottomUpBatch(List<RuleFactory> ruleFactories) {
List<Rule> rules = new ArrayList<>();
for (RuleFactory ruleFactory : ruleFactories) {
rules.addAll(ruleFactory.buildRules());
}
return new RewriteBottomUpJob(
cascadesContext.getMemo().getRoot(),
rules,
cascadesContext.getCurrentJobContext());
}
protected Job topDownBatch(RuleFactory... ruleFactories) {
return topDownBatch(Arrays.asList(ruleFactories));
}
protected Job topDownBatch(List<RuleFactory> ruleFactories) {
List<Rule> rules = new ArrayList<>();
for (RuleFactory ruleFactory : ruleFactories) {
rules.addAll(ruleFactory.buildRules());
}
return new RewriteTopDownJob(cascadesContext.getMemo().getRoot(), rules,
cascadesContext.getCurrentJobContext());
}
protected Job topDownBatch(List<RuleFactory> ruleFactories, boolean once) {
List<Rule> rules = new ArrayList<>();
for (RuleFactory ruleFactory : ruleFactories) {
rules.addAll(ruleFactory.buildRules());
}
return new RewriteTopDownJob(cascadesContext.getMemo().getRoot(), rules,
cascadesContext.getCurrentJobContext(), once);
}
protected Job visitorJob(RuleType ruleType, DefaultPlanRewriter<JobContext> planRewriter) {
return new VisitorRewriteJob(cascadesContext, planRewriter, ruleType);
}
protected Job optimize() {
return new OptimizeGroupJob(
cascadesContext.getMemo().getRoot(),
cascadesContext.getCurrentJobContext());
}
/**
* execute.
*/
public void execute() {
for (Job job : rulesJob) {
do {
cascadesContext.getCurrentJobContext().setRewritten(false);
cascadesContext.pushJob(job);
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
} while (!job.isOnce() && cascadesContext.getCurrentJobContext().isRewritten());
}
}
}


@@ -18,17 +18,25 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
import org.apache.doris.nereids.jobs.cascades.OptimizeGroupJob;
import com.google.common.collect.ImmutableList;
import java.util.Objects;
/**
* Analyze subquery.
* Cascades optimizer.
*/
public class AdjustAggregateNullableForEmptySetJob extends BatchRulesJob {
public AdjustAggregateNullableForEmptySetJob(CascadesContext cascadesContext) {
super(cascadesContext);
rulesJob.addAll(ImmutableList.of(
bottomUpBatch(ImmutableList.of(new AdjustAggregateNullableForEmptySet()))));
public class CascadesOptimizer {
private CascadesContext cascadesContext;
public CascadesOptimizer(CascadesContext cascadesContext) {
this.cascadesContext = Objects.requireNonNull(cascadesContext, "cascadesContext cannot be null");
}
public void execute() {
cascadesContext.pushJob(new OptimizeGroupJob(
cascadesContext.getMemo().getRoot(),
cascadesContext.getCurrentJobContext())
);
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
}
}


@@ -17,15 +17,18 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnAgg;
import org.apache.doris.nereids.rules.rewrite.logical.ApplyPullFilterOnProjectUnderAgg;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateFilterUnderApplyProject;
import org.apache.doris.nereids.rules.rewrite.logical.PushApplyUnderFilter;
import org.apache.doris.nereids.rules.rewrite.logical.PushApplyUnderProject;
import org.apache.doris.nereids.rules.RuleFactory;
import org.apache.doris.nereids.rules.rewrite.BatchRewriteRuleFactory;
import org.apache.doris.nereids.rules.rewrite.logical.PullUpCorrelatedFilterUnderApplyAggregateProject;
import org.apache.doris.nereids.rules.rewrite.logical.PullUpProjectUnderApply;
import org.apache.doris.nereids.rules.rewrite.logical.UnCorrelatedApplyAggregateFilter;
import org.apache.doris.nereids.rules.rewrite.logical.UnCorrelatedApplyFilter;
import org.apache.doris.nereids.rules.rewrite.logical.UnCorrelatedApplyProjectFilter;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* Adjust the plan in logicalApply so that there are no correlated columns in the subquery.
* Adjust the positions of apply and subquery nodes and apply,
@@ -34,19 +37,17 @@ import com.google.common.collect.ImmutableList;
* For the project and filter on AGG, try to adjust them to apply.
* For the project and filter under AGG, bring the filter under AGG and merge it with agg.
*/
public class AdjustApplyFromCorrelateToUnCorrelateJob extends BatchRulesJob {
/**
* Constructor.
*/
public AdjustApplyFromCorrelateToUnCorrelateJob(CascadesContext cascadesContext) {
super(cascadesContext);
rulesJob.addAll(ImmutableList.of(
topDownBatch(ImmutableList.of(
new PushApplyUnderProject(),
new PushApplyUnderFilter(),
new EliminateFilterUnderApplyProject(),
new ApplyPullFilterOnAgg(),
new ApplyPullFilterOnProjectUnderAgg()
))));
public class CorrelateApplyToUnCorrelateApply implements BatchRewriteRuleFactory {
public static final List<RuleFactory> RULES = ImmutableList.of(
new PullUpProjectUnderApply(),
new UnCorrelatedApplyFilter(),
new UnCorrelatedApplyProjectFilter(),
new UnCorrelatedApplyAggregateFilter(),
new PullUpCorrelatedFilterUnderApplyAggregateProject()
);
@Override
public List<RuleFactory> getRuleFactories() {
return RULES;
}
}


@@ -17,26 +17,27 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.RuleFactory;
import org.apache.doris.nereids.rules.rewrite.BatchRewriteRuleFactory;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimitUnderApply;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateSortUnderApply;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* Eliminate useless operators in the subquery, including limit and sort.
* For compatibility with the old optimizer, sort and limit inside a subquery take no effect, so they are deleted directly.
*/
public class EliminateSpecificPlanUnderApplyJob extends BatchRulesJob {
/**
* Constructor.
*/
public EliminateSpecificPlanUnderApplyJob(CascadesContext cascadesContext) {
super(cascadesContext);
rulesJob.addAll(ImmutableList.of(
topDownBatch(ImmutableList.of(
new EliminateLimitUnderApply(),
new EliminateSortUnderApply()
))));
public class EliminateUselessPlanUnderApply implements BatchRewriteRuleFactory {
public static final List<RuleFactory> RULES = ImmutableList.of(
new EliminateLimitUnderApply(),
new EliminateSortUnderApply()
);
@Override
public List<RuleFactory> getRuleFactories() {
return RULES;
}
}


@@ -1,150 +0,0 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite;
import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionOptimization;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewrite;
import org.apache.doris.nereids.rules.mv.SelectMaterializedIndexWithAggregate;
import org.apache.doris.nereids.rules.mv.SelectMaterializedIndexWithoutAggregate;
import org.apache.doris.nereids.rules.rewrite.logical.AdjustNullable;
import org.apache.doris.nereids.rules.rewrite.logical.BuildAggForUnion;
import org.apache.doris.nereids.rules.rewrite.logical.CheckAndStandardizeWindowFunctionAndFrame;
import org.apache.doris.nereids.rules.rewrite.logical.ColumnPruning;
import org.apache.doris.nereids.rules.rewrite.logical.CountDistinctRewrite;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateAggregate;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateDedupJoinCondition;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateFilter;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateNotNull;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateNullAwareLeftAntiJoin;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateOrderByConstant;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateUnnecessaryProject;
import org.apache.doris.nereids.rules.rewrite.logical.ExtractAndNormalizeWindowExpression;
import org.apache.doris.nereids.rules.rewrite.logical.ExtractFilterFromCrossJoin;
import org.apache.doris.nereids.rules.rewrite.logical.ExtractSingleTableExpressionFromDisjunction;
import org.apache.doris.nereids.rules.rewrite.logical.FindHashConditionForJoin;
import org.apache.doris.nereids.rules.rewrite.logical.InferFilterNotNull;
import org.apache.doris.nereids.rules.rewrite.logical.InferJoinNotNull;
import org.apache.doris.nereids.rules.rewrite.logical.InferPredicates;
import org.apache.doris.nereids.rules.rewrite.logical.InnerToCrossJoin;
import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown;
import org.apache.doris.nereids.rules.rewrite.logical.MergeFilters;
import org.apache.doris.nereids.rules.rewrite.logical.MergeProjects;
import org.apache.doris.nereids.rules.rewrite.logical.MergeSetOperations;
import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate;
import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanPartition;
import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanTablet;
import org.apache.doris.nereids.rules.rewrite.logical.PushFilterInsideJoin;
import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
import com.google.common.collect.ImmutableList;
/**
* Apply rules to optimize logical plan.
*/
public class NereidsRewriteJobExecutor extends BatchRulesJob {
/**
* Constructor.
*
* @param cascadesContext context for applying rules.
*/
public NereidsRewriteJobExecutor(CascadesContext cascadesContext) {
super(cascadesContext);
ImmutableList<Job> jobs = new ImmutableList.Builder<Job>()
.addAll(new EliminateSpecificPlanUnderApplyJob(cascadesContext).rulesJob)
// MergeProjects depends on this rule
.add(bottomUpBatch(ImmutableList.of(new LogicalSubQueryAliasToLogicalProject())))
// AdjustApplyFromCorrelateToUnCorrelateJob and ConvertApplyToJoinJob
// and SelectMaterializedIndexWithAggregate depend on this rule
.add(topDownBatch(ImmutableList.of(new MergeProjects())))
.add(topDownBatch(ImmutableList.of(new ExpressionNormalization(cascadesContext.getConnectContext()))))
.add(topDownBatch(ImmutableList.of(new ExpressionOptimization())))
.add(topDownBatch(ImmutableList.of(new ExtractSingleTableExpressionFromDisjunction())))
/*
* Subquery unnesting.
* 1. Adjust the plan in correlated logicalApply
* so that there are no correlated columns in the subquery.
* 2. Convert logicalApply to a logicalJoin.
* TODO: group these rules to make sure the result plan is what we expected.
*/
.addAll(new AdjustApplyFromCorrelateToUnCorrelateJob(cascadesContext).rulesJob)
.addAll(new ConvertApplyToJoinJob(cascadesContext).rulesJob)
.add(bottomUpBatch(ImmutableList.of(new AdjustAggregateNullableForEmptySet())))
.add(topDownBatch(ImmutableList.of(new EliminateGroupByConstant())))
.add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))
.add(topDownBatch(ImmutableList.of(new ExtractAndNormalizeWindowExpression())))
// execute NormalizeAggregate() again to resolve nested AggregateFunctions in WindowExpression,
// e.g. sum(sum(c1)) over(partition by avg(c1))
.add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))
.add(topDownBatch(ImmutableList.of(new CheckAndStandardizeWindowFunctionAndFrame())))
.add(topDownBatch(ImmutableList.of(new InferFilterNotNull())))
.add(topDownBatch(ImmutableList.of(new InferJoinNotNull())))
.add(topDownBatch(RuleSet.PUSH_DOWN_FILTERS, false))
.add(visitorJob(RuleType.INFER_PREDICATES, new InferPredicates()))
.add(topDownBatch(ImmutableList.of(new ExtractFilterFromCrossJoin())))
.add(topDownBatch(ImmutableList.of(new MergeFilters())))
.add(topDownBatch(ImmutableList.of(new ReorderJoin())))
.add(topDownBatch(ImmutableList.of(new EliminateDedupJoinCondition())))
.add(topDownBatch(ImmutableList.of(new ColumnPruning())))
.add(topDownBatch(RuleSet.PUSH_DOWN_FILTERS, false))
.add(visitorJob(RuleType.INFER_PREDICATES, new InferPredicates()))
.add(topDownBatch(RuleSet.PUSH_DOWN_FILTERS, false))
.add(visitorJob(RuleType.INFER_PREDICATES, new InferPredicates()))
.add(topDownBatch(RuleSet.PUSH_DOWN_FILTERS, false))
.add(topDownBatch(ImmutableList.of(new PushFilterInsideJoin())))
.add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin())))
.add(topDownBatch(RuleSet.PUSH_DOWN_FILTERS, false))
.add(topDownBatch(ImmutableList.of(new InnerToCrossJoin())))
.add(topDownBatch(ImmutableList.of(new EliminateNotNull())))
.add(topDownBatch(ImmutableList.of(new EliminateLimit())))
.add(topDownBatch(ImmutableList.of(new EliminateFilter())))
.add(topDownBatch(ImmutableList.of(new PruneOlapScanPartition())))
.add(topDownBatch(ImmutableList.of(new CountDistinctRewrite())))
// we need to execute this rule at the end of rewrite
// to avoid two consecutive identical projects appearing when we do optimization.
.add(topDownBatch(ImmutableList.of(new EliminateOrderByConstant())))
.add(topDownBatch(ImmutableList.of(new EliminateUnnecessaryProject())))
.add(topDownBatch(ImmutableList.of(new SelectMaterializedIndexWithAggregate())))
.add(topDownBatch(ImmutableList.of(new SelectMaterializedIndexWithoutAggregate())))
.add(topDownBatch(ImmutableList.of(new PruneOlapScanTablet())))
.add(topDownBatch(ImmutableList.of(new EliminateAggregate())))
.add(bottomUpBatch(ImmutableList.of(new MergeSetOperations())))
.add(topDownBatch(ImmutableList.of(new LimitPushDown())))
.add(topDownBatch(ImmutableList.of(new BuildAggForUnion())))
.add(topDownBatch(ImmutableList.of(new EliminateNullAwareLeftAntiJoin())))
// this rule batch must stay at the end of rewrite to do the final plan checks
.add(bottomUpBatch(ImmutableList.of(
new AdjustNullable(),
new ExpressionRewrite(CheckLegalityAfterRewrite.INSTANCE),
new CheckAfterRewrite()))
)
.build();
rulesJob.addAll(jobs);
}
}


@@ -0,0 +1,214 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.RewriteJob;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount;
import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite;
import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionOptimization;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewrite;
import org.apache.doris.nereids.rules.mv.SelectMaterializedIndexWithAggregate;
import org.apache.doris.nereids.rules.mv.SelectMaterializedIndexWithoutAggregate;
import org.apache.doris.nereids.rules.rewrite.logical.AdjustNullable;
import org.apache.doris.nereids.rules.rewrite.logical.BuildAggForUnion;
import org.apache.doris.nereids.rules.rewrite.logical.CheckAndStandardizeWindowFunctionAndFrame;
import org.apache.doris.nereids.rules.rewrite.logical.ColumnPruning;
import org.apache.doris.nereids.rules.rewrite.logical.CountDistinctRewrite;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateAggregate;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateDedupJoinCondition;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateFilter;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateNotNull;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateNullAwareLeftAntiJoin;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateOrderByConstant;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateUnnecessaryProject;
import org.apache.doris.nereids.rules.rewrite.logical.ExtractAndNormalizeWindowExpression;
import org.apache.doris.nereids.rules.rewrite.logical.ExtractFilterFromCrossJoin;
import org.apache.doris.nereids.rules.rewrite.logical.ExtractSingleTableExpressionFromDisjunction;
import org.apache.doris.nereids.rules.rewrite.logical.FindHashConditionForJoin;
import org.apache.doris.nereids.rules.rewrite.logical.InferFilterNotNull;
import org.apache.doris.nereids.rules.rewrite.logical.InferJoinNotNull;
import org.apache.doris.nereids.rules.rewrite.logical.InferPredicates;
import org.apache.doris.nereids.rules.rewrite.logical.InnerToCrossJoin;
import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown;
import org.apache.doris.nereids.rules.rewrite.logical.MergeFilters;
import org.apache.doris.nereids.rules.rewrite.logical.MergeProjects;
import org.apache.doris.nereids.rules.rewrite.logical.MergeSetOperations;
import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate;
import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanPartition;
import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanTablet;
import org.apache.doris.nereids.rules.rewrite.logical.PushFilterInsideJoin;
import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
import java.util.List;
/**
* Apply rules to optimize logical plan.
*/
public class NereidsRewriter extends BatchRewriteJob {
private static final List<RewriteJob> REWRITE_JOBS = jobs(
topic("Normalization",
topDown(
new EliminateOrderByConstant(),
new EliminateGroupByConstant(),
// MergeProjects depends on this rule
new LogicalSubQueryAliasToLogicalProject(),
// rewrite expressions, no depends
new ExpressionNormalization(),
new ExpressionOptimization(),
new AvgDistinctToSumDivCount(),
new CountDistinctRewrite(),
new NormalizeAggregate(),
new ExtractFilterFromCrossJoin()
),
// ExtractSingleTableExpressionFromDisjunction conflicts with InPredicateToEqualToRule
// in ExpressionNormalization, so it must be invoked in a separate job; otherwise the
// two rules trigger each other in an endless loop
topDown(
new ExtractSingleTableExpressionFromDisjunction()
)
),
topic("Subquery unnesting",
bottomUp(
new EliminateUselessPlanUnderApply(),
// CorrelateApplyToUnCorrelateApply, ApplyToJoin,
// and SelectMaterializedIndexWithAggregate depend on this rule
new MergeProjects(),
/*
* Subquery unnesting.
* 1. Adjust the plan in correlated logicalApply
* so that there are no correlated columns in the subquery.
* 2. Convert logicalApply to a logicalJoin.
* TODO: group these rules to make sure the result plan is what we expected.
*/
new CorrelateApplyToUnCorrelateApply(),
new ApplyToJoin()
)
),
topDown(
new AdjustAggregateNullableForEmptySet()
),
topic("Window analysis",
topDown(
new ExtractAndNormalizeWindowExpression(),
// execute NormalizeAggregate() again to resolve nested AggregateFunctions in WindowExpression,
// e.g. sum(sum(c1)) over(partition by avg(c1))
new NormalizeAggregate(),
new CheckAndStandardizeWindowFunctionAndFrame()
)
),
topic("Rewrite join",
// infer not-null filters, then push down filters, and then reorder joins (cross join to inner join)
topDown(
new InferFilterNotNull(),
new InferJoinNotNull()
),
// ReorderJoin depends on PUSH_DOWN_FILTERS, and PUSH_DOWN_FILTERS depends on lots of
// rules, e.g. merge projects, eliminate outer joins. Sometimes transforming the bottom
// plan enables rules that apply to the top plan, but a top-down traversal cannot cover
// this case in one iteration, so bottom-up is more efficient: it can find the new plans
// and apply transforms wherever they appear
bottomUp(RuleSet.PUSH_DOWN_FILTERS),
topDown(
new MergeFilters(),
new ReorderJoin(),
new PushFilterInsideJoin(),
new FindHashConditionForJoin(),
new InnerToCrossJoin(),
new EliminateNullAwareLeftAntiJoin()
),
topDown(
new EliminateDedupJoinCondition()
)
),
topic("Column pruning and infer predicate",
topDown(new ColumnPruning()),
custom(RuleType.INFER_PREDICATES, () -> new InferPredicates()),
// column pruning creates new projects, so we should use PUSH_DOWN_FILTERS
// to change filter-project into project-filter
bottomUp(RuleSet.PUSH_DOWN_FILTERS),
// after outer joins are eliminated by PUSH_DOWN_FILTERS, we can infer more predicates and push them down
custom(RuleType.INFER_PREDICATES, () -> new InferPredicates()),
bottomUp(RuleSet.PUSH_DOWN_FILTERS),
// after eliminating outer joins, we can move some filters to join.otherJoinConjuncts,
// which helps translate the plan to the backend
topDown(
new PushFilterInsideJoin()
)
),
// this rule should be invoked after ColumnPruning
custom(RuleType.ELIMINATE_UNNECESSARY_PROJECT, () -> new EliminateUnnecessaryProject()),
// we need to execute this rule at the end of rewrite
// to avoid two consecutive identical projects appearing during optimization.
topic("Others optimization", topDown(
new EliminateNotNull(),
new EliminateLimit(),
new EliminateFilter(),
new PruneOlapScanPartition(),
new SelectMaterializedIndexWithAggregate(),
new SelectMaterializedIndexWithoutAggregate(),
new PruneOlapScanTablet(),
new EliminateAggregate(),
new MergeSetOperations(),
new LimitPushDown(),
new BuildAggForUnion()
)),
// this rule batch must stay at the end of rewrite to run final plan checks
topic("Final rewrite and check", bottomUp(
new AdjustNullable(),
new ExpressionRewrite(CheckLegalityAfterRewrite.INSTANCE),
new CheckAfterRewrite()
))
);
public NereidsRewriter(CascadesContext cascadesContext) {
super(cascadesContext);
}
@Override
public List<RewriteJob> getJobs() {
return REWRITE_JOBS;
}
}
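The static `REWRITE_JOBS` list above is built once and shared by every query, while the `custom(...)` entries defer construction through a `Supplier` so that context-dependent rules are created fresh per invocation. A minimal sketch of that split, with hypothetical rule names standing in for the real classes:

```java
import java.util.List;
import java.util.function.Supplier;

public class RuleListDemo {
    public interface DemoRule {
        String name();
    }

    // stateless rule: built once and shared by every query, like the static rules above
    public static final DemoRule ELIMINATE_LIMIT = () -> "EliminateLimit";

    public static final List<DemoRule> STATIC_RULES = List.of(ELIMINATE_LIMIT);

    // context-dependent rule: deferred construction, one fresh instance per get(),
    // mirroring custom(RuleType.INFER_PREDICATES, () -> new InferPredicates())
    public static final Supplier<DemoRule> INFER_PREDICATES =
            () -> new DemoRule() {
                @Override
                public String name() {
                    return "InferPredicates";
                }
            };

    public static void main(String[] args) {
        // the shared rule keeps object identity across lookups
        System.out.println(STATIC_RULES.get(0) == STATIC_RULES.get(0)); // true
        // the custom rule is rebuilt on each call
        System.out.println(INFER_PREDICATES.get() == INFER_PREDICATES.get()); // false
    }
}
```

This is the pattern the PR description calls "static rules plus `custom` rules": the expensive per-query initialization is paid only for rules that actually need runtime context.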

@@ -0,0 +1,78 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.rewrite;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.RewriteJob;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import java.util.Locale;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;
/**
* Custom rewrite the plan.
*/
public class CustomRewriteJob implements RewriteJob {
private final RuleType ruleType;
private final Supplier<CustomRewriter> customRewriter;
/**
* Constructor.
*/
public CustomRewriteJob(Supplier<CustomRewriter> rewriter, RuleType ruleType) {
this.ruleType = Objects.requireNonNull(ruleType, "ruleType cannot be null");
this.customRewriter = Objects.requireNonNull(rewriter, "customRewriter cannot be null");
}
@Override
public void execute(JobContext context) {
Set<String> disableRules = Job.getDisableRules(context);
if (disableRules.contains(ruleType.name().toUpperCase(Locale.ROOT))) {
return;
}
Plan root = context.getCascadesContext().getRewritePlan();
// COUNTER_TRACER.log(CounterEvent.of(Memo.getStateId(), CounterType.JOB_EXECUTION, group, logicalExpression,
// root));
Plan rewrittenRoot = customRewriter.get().rewriteRoot(root, context);
// don't remove this comment, it can help us trace bugs when developing.
// if (!root.deepEquals(rewrittenRoot)) {
// String traceBefore = root.treeString();
// String traceAfter = rewrittenRoot.treeString();
// printTraceLog(ruleType, traceBefore, traceAfter);
// }
context.getCascadesContext().setRewritePlan(rewrittenRoot);
}
@Override
public boolean isOnce() {
return false;
}
private void printTraceLog(RuleType ruleType, String traceBefore, String traceAfter) {
System.out.println("========== " + getClass().getSimpleName() + " " + ruleType
+ " ==========\nbefore:\n" + traceBefore + "\n\nafter:\n" + traceAfter + "\n");
}
}
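`CustomRewriteJob.execute` skips itself when its `RuleType` name appears in the session's disabled-rule set, normalizing case before the lookup. A minimal sketch of that guard (names hypothetical):

```java
import java.util.Locale;
import java.util.Set;

public class DisableRuleDemo {
    /** Disabled rules are stored upper-cased, so normalize before checking. */
    public static boolean shouldSkip(String ruleName, Set<String> disableRules) {
        return disableRules.contains(ruleName.toUpperCase(Locale.ROOT));
    }

    public static void main(String[] args) {
        Set<String> disabled = Set.of("INFER_PREDICATES");
        System.out.println(shouldSkip("infer_predicates", disabled)); // true
        System.out.println(shouldSkip("eliminate_limit", disabled));  // false
    }
}
```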

@@ -0,0 +1,128 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.rewrite;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.Plan;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
/** PlanTreeRewriteBottomUpJob */
public class PlanTreeRewriteBottomUpJob extends PlanTreeRewriteJob {
private static final String REWRITE_STATE_KEY = "rewrite_state";
private final RewriteJobContext rewriteJobContext;
private final List<Rule> rules;
enum RewriteState {
ENSURE_CHILDREN_REWRITTEN, REWRITE_THIS, REWRITTEN
}
public PlanTreeRewriteBottomUpJob(RewriteJobContext rewriteJobContext, JobContext context, List<Rule> rules) {
super(JobType.BOTTOM_UP_REWRITE, context);
this.rewriteJobContext = Objects.requireNonNull(rewriteJobContext, "rewriteContext cannot be null");
this.rules = Objects.requireNonNull(rules, "rules cannot be null");
}
@Override
public void execute() {
// use childrenVisited to judge whether we need to clear state left over from the previous batch
boolean clearStatePhase = !rewriteJobContext.childrenVisited;
if (clearStatePhase) {
traverseClearState();
return;
}
Plan plan = rewriteJobContext.plan;
RewriteState state = getState(plan);
switch (state) {
case REWRITE_THIS:
rewriteThis();
return;
case ENSURE_CHILDREN_REWRITTEN:
ensureChildrenRewritten();
return;
case REWRITTEN:
rewriteJobContext.result = plan;
return;
default:
throw new IllegalStateException("Unknown rewrite state: " + state);
}
}
private void traverseClearState() {
RewriteJobContext clearedStateContext = rewriteJobContext.withChildrenVisited(true);
setState(clearedStateContext.plan, RewriteState.REWRITE_THIS);
pushJob(new PlanTreeRewriteBottomUpJob(clearedStateContext, context, rules));
List<Plan> children = clearedStateContext.plan.children();
for (int i = children.size() - 1; i >= 0; i--) {
Plan child = children.get(i);
RewriteJobContext childRewriteJobContext = new RewriteJobContext(
child, clearedStateContext, i, false);
pushJob(new PlanTreeRewriteBottomUpJob(childRewriteJobContext, context, rules));
}
}
private void rewriteThis() {
Plan plan = linkChildren(rewriteJobContext.plan, rewriteJobContext.childrenContext);
RewriteResult rewriteResult = rewrite(plan, rules, rewriteJobContext);
if (rewriteResult.hasNewPlan) {
RewriteJobContext newJobContext = rewriteJobContext.withPlan(rewriteResult.plan);
RewriteState state = getState(rewriteResult.plan);
// some eliminate rules can return an already-rewritten plan
if (state == RewriteState.REWRITTEN) {
newJobContext.setResult(rewriteResult.plan);
return;
}
pushJob(new PlanTreeRewriteBottomUpJob(newJobContext, context, rules));
setState(rewriteResult.plan, RewriteState.ENSURE_CHILDREN_REWRITTEN);
} else {
setState(rewriteResult.plan, RewriteState.REWRITTEN);
rewriteJobContext.setResult(rewriteResult.plan);
}
}
private void ensureChildrenRewritten() {
Plan plan = rewriteJobContext.plan;
setState(plan, RewriteState.REWRITE_THIS);
pushJob(new PlanTreeRewriteBottomUpJob(rewriteJobContext, context, rules));
List<Plan> children = plan.children();
for (int i = children.size() - 1; i >= 0; i--) {
Plan child = children.get(i);
// some rules return a new plan tree containing more than one new plan node,
// so we should transform those new plan nodes too.
RewriteJobContext childRewriteJobContext = new RewriteJobContext(
child, rewriteJobContext, i, false);
pushJob(new PlanTreeRewriteBottomUpJob(childRewriteJobContext, context, rules));
}
}
private static RewriteState getState(Plan plan) {
Optional<RewriteState> state = plan.getMutableState(REWRITE_STATE_KEY);
return state.orElse(RewriteState.ENSURE_CHILDREN_REWRITTEN);
}
private static void setState(Plan plan, RewriteState state) {
plan.setMutableState(REWRITE_STATE_KEY, state);
}
}
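The job above drives a three-state machine (ENSURE_CHILDREN_REWRITTEN, REWRITE_THIS, REWRITTEN) over a mutable-state field on each plan, re-traversing any new plan a rule returns. The control flow can be sketched recursively on a toy tree; this is a simplification under stated assumptions: the real implementation uses an explicit job stack and per-plan mutable state instead of recursion, and `Node`, `applyRules`, and the merge-project rule here are hypothetical stand-ins:

```java
import java.util.ArrayList;
import java.util.List;

public class BottomUpDemo {
    /** toy plan node: a name plus children (not the real Plan class) */
    public static final class Node {
        public final String name;
        public final List<Node> children;

        public Node(String name, List<Node> children) {
            this.name = name;
            this.children = children;
        }
    }

    // one toy rule: merge nested projects, project(project(x)) -> project(x)
    static Node applyRules(Node n) {
        if (n.name.equals("project") && n.children.size() == 1
                && n.children.get(0).name.equals("project")) {
            return new Node("project", n.children.get(0).children);
        }
        return n;
    }

    // bottom-up: rewrite children first (ENSURE_CHILDREN_REWRITTEN), then this
    // node (REWRITE_THIS); when a rule produces a new plan, traverse it again
    public static Node rewrite(Node n) {
        List<Node> newChildren = new ArrayList<>();
        for (Node child : n.children) {
            newChildren.add(rewrite(child));
        }
        Node linked = new Node(n.name, newChildren);
        Node rewritten = applyRules(linked);
        return rewritten == linked ? linked : rewrite(rewritten);
    }

    public static String show(Node n) {
        return n.children.isEmpty() ? n.name
                : n.name + "(" + show(n.children.get(0)) + ")";
    }

    public static void main(String[] args) {
        Node plan = new Node("project", List.of(
                new Node("project", List.of(
                        new Node("scan", List.of())))));
        System.out.println(show(rewrite(plan))); // project(scan)
    }
}
```

Re-traversing the rule's output is what lets the framework "keep traversing multiple new plans returned by rules", as the PR description puts it.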

@@ -0,0 +1,117 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.rewrite;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.Plan;
import com.google.common.base.Preconditions;
import java.util.List;
/** PlanTreeRewriteJob */
public abstract class PlanTreeRewriteJob extends Job {
public PlanTreeRewriteJob(JobType type, JobContext context) {
super(type, context);
}
protected RewriteResult rewrite(Plan plan, List<Rule> rules, RewriteJobContext rewriteJobContext) {
// boolean traceEnable = isTraceEnable(context);
boolean isRewriteRoot = rewriteJobContext.isRewriteRoot();
CascadesContext cascadesContext = context.getCascadesContext();
cascadesContext.setIsRewriteRoot(isRewriteRoot);
List<Rule> validRules = getValidRules(rules);
for (Rule rule : validRules) {
Pattern<Plan> pattern = (Pattern<Plan>) rule.getPattern();
if (pattern.matchPlanTree(plan)) {
List<Plan> newPlans = rule.transform(plan, cascadesContext);
Preconditions.checkState(newPlans.size() == 1,
"Rewrite rule should generate one plan: " + rule.getRuleType());
Plan newPlan = newPlans.get(0);
if (!newPlan.deepEquals(plan)) {
// don't remove this comment, it can help us trace bugs when developing.
// String traceBefore = null;
// if (traceEnable) {
// traceBefore = getCurrentPlanTreeString();
// }
rewriteJobContext.result = newPlan;
context.setRewritten(true);
rule.acceptPlan(newPlan);
// if (traceEnable) {
// String traceAfter = getCurrentPlanTreeString();
// printTraceLog(rule, traceBefore, traceAfter);
// }
return new RewriteResult(true, newPlan);
}
}
}
return new RewriteResult(false, plan);
}
protected Plan linkChildrenAndParent(Plan plan, RewriteJobContext rewriteJobContext) {
Plan newPlan = linkChildren(plan, rewriteJobContext.childrenContext);
rewriteJobContext.setResult(newPlan);
return newPlan;
}
protected Plan linkChildren(Plan plan, RewriteJobContext[] childrenContext) {
boolean changed = false;
Plan[] newChildren = new Plan[childrenContext.length];
for (int i = 0; i < childrenContext.length; ++i) {
Plan result = childrenContext[i].result;
Plan oldChild = plan.child(i);
if (result != null && result != oldChild) {
newChildren[i] = result;
changed = true;
} else {
newChildren[i] = oldChild;
}
}
return changed ? plan.withChildren(newChildren) : plan;
}
private String getCurrentPlanTreeString() {
return context.getCascadesContext()
.getCurrentRootRewriteJobContext().get()
.getNewestPlan()
.treeString();
}
private void printTraceLog(Rule rule, String traceBefore, String traceAfter) {
System.out.println("========== " + getClass().getSimpleName() + " " + rule.getRuleType()
+ " ==========\nbefore:\n" + traceBefore + "\n\nafter:\n" + traceAfter + "\n");
// LOGGER.info("========== {} {} ==========\nbefore:\n{}\n\nafter:\n{}\n",
// getClass().getSimpleName(), rule.getRuleType(), traceBefore, traceAfter);
}
static class RewriteResult {
final boolean hasNewPlan;
final Plan plan;
public RewriteResult(boolean hasNewPlan, Plan plan) {
this.hasNewPlan = hasNewPlan;
this.plan = plan;
}
}
}
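`linkChildren` above is a copy-on-write step: it only calls `withChildren` when some child's rewrite result actually differs, so an unchanged parent keeps its object identity and no new node is allocated. A minimal sketch of that check on a hypothetical toy `Node` type:

```java
public class LinkChildrenDemo {
    /** toy immutable plan node (hypothetical, not the real Plan) */
    public static final class Node {
        public final String name;
        public final Node[] children;

        public Node(String name, Node... children) {
            this.name = name;
            this.children = children;
        }

        Node withChildren(Node[] newChildren) {
            return new Node(name, newChildren);
        }
    }

    /** results[i] == null means "child i produced no rewrite result". */
    public static Node linkChildren(Node plan, Node[] results) {
        boolean changed = false;
        Node[] newChildren = new Node[results.length];
        for (int i = 0; i < results.length; i++) {
            Node result = results[i];
            Node oldChild = plan.children[i];
            if (result != null && result != oldChild) {
                newChildren[i] = result;
                changed = true;
            } else {
                newChildren[i] = oldChild;
            }
        }
        // avoid allocating a new parent when nothing changed underneath
        return changed ? plan.withChildren(newChildren) : plan;
    }

    public static void main(String[] args) {
        Node scan = new Node("scan");
        Node filter = new Node("filter", scan);
        // no child changed: the same parent instance is returned
        System.out.println(linkChildren(filter, new Node[]{null}) == filter); // true
        // a rewritten child forces a fresh parent
        System.out.println(linkChildren(filter, new Node[]{new Node("scan2")}) == filter); // false
    }
}
```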

@@ -0,0 +1,66 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.rewrite;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.Plan;
import java.util.List;
import java.util.Objects;
/** PlanTreeRewriteTopDownJob */
public class PlanTreeRewriteTopDownJob extends PlanTreeRewriteJob {
private final RewriteJobContext rewriteJobContext;
private final List<Rule> rules;
public PlanTreeRewriteTopDownJob(RewriteJobContext rewriteJobContext, JobContext context, List<Rule> rules) {
super(JobType.TOP_DOWN_REWRITE, context);
this.rewriteJobContext = Objects.requireNonNull(rewriteJobContext, "rewriteContext cannot be null");
this.rules = Objects.requireNonNull(rules, "rules cannot be null");
}
@Override
public void execute() {
if (!rewriteJobContext.childrenVisited) {
RewriteResult rewriteResult = rewrite(rewriteJobContext.plan, rules, rewriteJobContext);
if (rewriteResult.hasNewPlan) {
RewriteJobContext newContext = rewriteJobContext
.withPlanAndChildrenVisited(rewriteResult.plan, false);
pushJob(new PlanTreeRewriteTopDownJob(newContext, context, rules));
return;
}
RewriteJobContext newRewriteJobContext = rewriteJobContext.withChildrenVisited(true);
pushJob(new PlanTreeRewriteTopDownJob(newRewriteJobContext, context, rules));
List<Plan> children = newRewriteJobContext.plan.children();
for (int i = children.size() - 1; i >= 0; i--) {
RewriteJobContext childRewriteJobContext = new RewriteJobContext(
children.get(i), newRewriteJobContext, i, false);
pushJob(new PlanTreeRewriteTopDownJob(childRewriteJobContext, context, rules));
}
} else {
Plan result = linkChildrenAndParent(rewriteJobContext.plan, rewriteJobContext);
if (rewriteJobContext.parentContext == null) {
context.getCascadesContext().setRewritePlan(result);
}
}
}
}

@@ -0,0 +1,65 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.rewrite;
import org.apache.doris.nereids.trees.plans.Plan;
import javax.annotation.Nullable;
/** RewriteJobContext */
public class RewriteJobContext {
final boolean childrenVisited;
final RewriteJobContext parentContext;
final int childIndexInParentContext;
final Plan plan;
final RewriteJobContext[] childrenContext;
Plan result;
/** RewriteJobContext */
public RewriteJobContext(Plan plan, @Nullable RewriteJobContext parentContext, int childIndexInParentContext,
boolean childrenVisited) {
this.plan = plan;
this.parentContext = parentContext;
this.childIndexInParentContext = childIndexInParentContext;
this.childrenVisited = childrenVisited;
this.childrenContext = new RewriteJobContext[plan.arity()];
if (parentContext != null) {
parentContext.childrenContext[childIndexInParentContext] = this;
}
}
public void setResult(Plan result) {
this.result = result;
}
public RewriteJobContext withChildrenVisited(boolean childrenVisited) {
return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited);
}
public RewriteJobContext withPlan(Plan plan) {
return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited);
}
public RewriteJobContext withPlanAndChildrenVisited(Plan plan, boolean childrenVisited) {
return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited);
}
public boolean isRewriteRoot() {
return false;
}
}

@@ -0,0 +1,171 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.rewrite;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.jobs.RewriteJob;
import org.apache.doris.nereids.jobs.scheduler.JobStack;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.Plan;
import java.util.List;
import java.util.Objects;
/** RootPlanTreeRewriteJob */
public class RootPlanTreeRewriteJob implements RewriteJob {
private final List<Rule> rules;
private final RewriteJobBuilder rewriteJobBuilder;
private final boolean once;
public RootPlanTreeRewriteJob(List<Rule> rules, RewriteJobBuilder rewriteJobBuilder, boolean once) {
this.rules = Objects.requireNonNull(rules, "rules cannot be null");
this.rewriteJobBuilder = Objects.requireNonNull(rewriteJobBuilder, "rewriteJobBuilder cannot be null");
this.once = once;
}
@Override
public void execute(JobContext context) {
CascadesContext cascadesContext = context.getCascadesContext();
// get plan from the cascades context
Plan root = cascadesContext.getRewritePlan();
// write rewritten root plan to cascades context by the RootRewriteJobContext
RootRewriteJobContext rewriteJobContext = new RootRewriteJobContext(root, false, context);
Job rewriteJob = rewriteJobBuilder.build(rewriteJobContext, context, rules);
context.getScheduleContext().pushJob(rewriteJob);
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
cascadesContext.setCurrentRootRewriteJobContext(null);
}
@Override
public boolean isOnce() {
return once;
}
/** RewriteJobBuilder */
public interface RewriteJobBuilder {
Job build(RewriteJobContext rewriteJobContext, JobContext jobContext, List<Rule> rules);
}
/** RootRewriteJobContext */
public static class RootRewriteJobContext extends RewriteJobContext {
private JobContext jobContext;
RootRewriteJobContext(Plan plan, boolean childrenVisited, JobContext jobContext) {
super(plan, null, -1, childrenVisited);
this.jobContext = Objects.requireNonNull(jobContext, "jobContext cannot be null");
jobContext.getCascadesContext().setCurrentRootRewriteJobContext(this);
}
@Override
public boolean isRewriteRoot() {
return true;
}
@Override
public void setResult(Plan result) {
jobContext.getCascadesContext().setRewritePlan(result);
}
@Override
public RewriteJobContext withChildrenVisited(boolean childrenVisited) {
return new RootRewriteJobContext(plan, childrenVisited, jobContext);
}
@Override
public RewriteJobContext withPlan(Plan plan) {
return new RootRewriteJobContext(plan, childrenVisited, jobContext);
}
@Override
public RewriteJobContext withPlanAndChildrenVisited(Plan plan, boolean childrenVisited) {
return new RootRewriteJobContext(plan, childrenVisited, jobContext);
}
/** linkChildren */
public Plan getNewestPlan() {
JobStack jobStack = new JobStack();
LinkPlanJob linkPlanJob = new LinkPlanJob(
jobContext, this, null, false, jobStack);
jobStack.push(linkPlanJob);
while (!jobStack.isEmpty()) {
Job job = jobStack.pop();
job.execute();
}
return linkPlanJob.result;
}
}
/** use to assemble the rewriting plan */
private static class LinkPlanJob extends Job {
LinkPlanJob parentJob;
RewriteJobContext rewriteJobContext;
Plan[] childrenResult;
Plan result;
boolean linked;
JobStack jobStack;
private LinkPlanJob(JobContext context, RewriteJobContext rewriteJobContext,
LinkPlanJob parentJob, boolean linked, JobStack jobStack) {
super(JobType.LINK_PLAN, context);
this.rewriteJobContext = rewriteJobContext;
this.parentJob = parentJob;
this.linked = linked;
this.childrenResult = new Plan[rewriteJobContext.plan.arity()];
this.jobStack = jobStack;
}
@Override
public void execute() {
if (!linked) {
linked = true;
jobStack.push(this);
for (int i = rewriteJobContext.childrenContext.length - 1; i >= 0; i--) {
RewriteJobContext childContext = rewriteJobContext.childrenContext[i];
if (childContext != null) {
jobStack.push(new LinkPlanJob(context, childContext, this, false, jobStack));
}
}
} else if (rewriteJobContext.result != null) {
linkResult(rewriteJobContext.result);
} else {
Plan[] newChildren = new Plan[childrenResult.length];
for (int i = 0; i < newChildren.length; i++) {
Plan childResult = childrenResult[i];
if (childResult == null) {
childResult = rewriteJobContext.plan.child(i);
}
newChildren[i] = childResult;
}
linkResult(rewriteJobContext.plan.withChildren(newChildren));
}
}
private void linkResult(Plan result) {
if (parentJob != null) {
parentJob.childrenResult[rewriteJobContext.childIndexInParentContext] = result;
} else {
this.result = result;
}
}
}
}

@@ -1,69 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.rewrite;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.metrics.CounterType;
import org.apache.doris.nereids.metrics.event.CounterEvent;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import java.util.Locale;
import java.util.Objects;
/**
* Use visitor to rewrite the plan.
*/
public class VisitorRewriteJob extends Job {
private final RuleType ruleType;
private final Group group;
private final DefaultPlanRewriter<JobContext> planRewriter;
/**
* Constructor.
*/
public VisitorRewriteJob(CascadesContext cascadesContext,
DefaultPlanRewriter<JobContext> rewriter, RuleType ruleType) {
super(JobType.VISITOR_REWRITE, cascadesContext.getCurrentJobContext(), true);
this.ruleType = Objects.requireNonNull(ruleType, "ruleType cannot be null");
this.group = Objects.requireNonNull(cascadesContext.getMemo().getRoot(), "group cannot be null");
this.planRewriter = Objects.requireNonNull(rewriter, "planRewriter cannot be null");
}
@Override
public void execute() {
if (disableRules.contains(ruleType.name().toUpperCase(Locale.ROOT))) {
return;
}
GroupExpression logicalExpression = group.getLogicalExpression();
Plan root = context.getCascadesContext().getMemo().copyOut(logicalExpression, true);
COUNTER_TRACER.log(CounterEvent.of(Memo.getStateId(), CounterType.JOB_EXECUTION, group, logicalExpression,
root));
Plan rewrittenRoot = root.accept(planRewriter, context);
context.getCascadesContext().getMemo().copyIn(rewrittenRoot, group, true);
}
}

@@ -17,15 +17,9 @@
package org.apache.doris.nereids.jobs.scheduler;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.jobs.Job;
/**
* Scheduler to schedule jobs in Nereids.
*/
public interface JobScheduler {
void executeJob(Job job, CascadesContext context);
void executeJobPool(CascadesContext cascadesContext) throws AnalysisException;
void executeJobPool(ScheduleContext scheduleContext);
}

@@ -15,20 +15,14 @@
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.batch;
package org.apache.doris.nereids.jobs.scheduler;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.analysis.AnalyzeSubquery;
import org.apache.doris.nereids.jobs.Job;
import com.google.common.collect.ImmutableList;
/** ScheduleContext */
public interface ScheduleContext {
/**
* Analyze subquery.
*/
public class AnalyzeSubqueryRulesJob extends BatchRulesJob {
public AnalyzeSubqueryRulesJob(CascadesContext cascadesContext) {
super(cascadesContext);
rulesJob.addAll(ImmutableList.of(
bottomUpBatch(ImmutableList.of(new AnalyzeSubquery()))));
}
JobPool getJobPool();
void pushJob(Job job);
}

@@ -17,8 +17,6 @@
package org.apache.doris.nereids.jobs.scheduler;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.jobs.Job;
/**
@@ -26,13 +24,8 @@ import org.apache.doris.nereids.jobs.Job;
*/
public class SimpleJobScheduler implements JobScheduler {
@Override
public void executeJob(Job job, CascadesContext context) {
}
@Override
public void executeJobPool(CascadesContext cascadesContext) throws AnalysisException {
JobPool pool = cascadesContext.getJobPool();
public void executeJobPool(ScheduleContext scheduleContext) {
JobPool pool = scheduleContext.getJobPool();
while (!pool.isEmpty()) {
Job job = pool.pop();
job.execute();

@@ -393,7 +393,16 @@ public class Group {
@Override
public String toString() {
return "Group[" + groupId + "]";
StringBuilder str = new StringBuilder("Group[" + groupId + "]\n");
str.append("logical expressions:\n");
for (GroupExpression logicalExpression : logicalExpressions) {
str.append(" ").append(logicalExpression).append("\n");
}
str.append("physical expressions:\n");
for (GroupExpression physicalExpression : physicalExpressions) {
str.append(" ").append(physicalExpression).append("\n");
}
return str.toString();
}
/**
@@ -404,7 +413,7 @@ public class Group {
public String treeString() {
Function<Object, String> toString = obj -> {
if (obj instanceof Group) {
return obj.toString();
return "Group[" + ((Group) obj).groupId + "]";
} else if (obj instanceof GroupExpression) {
return ((GroupExpression) obj).getPlan().toString();
} else if (obj instanceof Pair) {


@@ -216,11 +216,11 @@ public class Memo {
* Utility function to create a new {@link CascadesContext} with this Memo.
*/
public CascadesContext newCascadesContext(StatementContext statementContext) {
return new CascadesContext(this, statementContext, PhysicalProperties.ANY);
return new CascadesContext(null, this, statementContext, PhysicalProperties.ANY);
}
public CascadesContext newCascadesContext(StatementContext statementContext, CTEContext cteContext) {
return new CascadesContext(this, statementContext, cteContext, PhysicalProperties.ANY);
return new CascadesContext(null, this, statementContext, cteContext, PhysicalProperties.ANY);
}
/**


@@ -49,4 +49,8 @@ public class MatchingContext<TYPE extends Plan> {
this.connectContext = cascadesContext.getConnectContext();
this.cteContext = cascadesContext.getCteContext();
}
public boolean isRewriteRoot() {
return cascadesContext.isRewriteRoot();
}
}


@@ -0,0 +1,322 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.pattern;
import org.apache.doris.nereids.trees.plans.BinaryPlan;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.LeafPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.UnaryPlan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalBinary;
import org.apache.doris.nereids.trees.plans.logical.LogicalExcept;
import org.apache.doris.nereids.trees.plans.logical.LogicalIntersect;
import org.apache.doris.nereids.trees.plans.logical.LogicalLeaf;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnary;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.physical.PhysicalBinary;
import org.apache.doris.nereids.trees.plans.physical.PhysicalLeaf;
import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation;
import org.apache.doris.nereids.trees.plans.physical.PhysicalUnary;
import java.util.Arrays;
/** MemoPatterns */
public interface MemoPatterns extends Patterns {
default PatternDescriptor<GroupPlan> group() {
return new PatternDescriptor<>(Pattern.GROUP, defaultPromise());
}
default PatternDescriptor<GroupPlan> multiGroup() {
return new PatternDescriptor<>(Pattern.MULTI_GROUP, defaultPromise());
}
/* abstract plan operator patterns */
/**
* create a leafPlan pattern.
*/
default PatternDescriptor<LeafPlan> leafPlan() {
return new PatternDescriptor(new TypePattern(LeafPlan.class), defaultPromise());
}
/**
* create a unaryPlan pattern.
*/
default PatternDescriptor<UnaryPlan<GroupPlan>> unaryPlan() {
return new PatternDescriptor(new TypePattern(UnaryPlan.class, Pattern.GROUP), defaultPromise());
}
/**
* create a unaryPlan pattern.
*/
default <C extends Plan> PatternDescriptor<UnaryPlan<C>>
unaryPlan(PatternDescriptor<C> child) {
return new PatternDescriptor(new TypePattern(UnaryPlan.class, child.pattern), defaultPromise());
}
/**
* create a binaryPlan pattern.
*/
default PatternDescriptor<BinaryPlan<GroupPlan, GroupPlan>> binaryPlan() {
return new PatternDescriptor(
new TypePattern(BinaryPlan.class, Pattern.GROUP, Pattern.GROUP),
defaultPromise()
);
}
/**
* create a binaryPlan pattern.
*/
default <LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Plan>
PatternDescriptor<BinaryPlan<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE>> binaryPlan(
PatternDescriptor<LEFT_CHILD_TYPE> leftChild,
PatternDescriptor<RIGHT_CHILD_TYPE> rightChild) {
return new PatternDescriptor(
new TypePattern(BinaryPlan.class, leftChild.pattern, rightChild.pattern),
defaultPromise()
);
}
/* abstract logical plan patterns */
/**
* create a logicalPlan pattern.
*/
default PatternDescriptor<LogicalPlan> logicalPlan() {
return new PatternDescriptor(new TypePattern(LogicalPlan.class, multiGroup().pattern), defaultPromise());
}
/**
* create a logicalLeaf pattern.
*/
default PatternDescriptor<LogicalLeaf> logicalLeaf() {
return new PatternDescriptor(new TypePattern(LogicalLeaf.class), defaultPromise());
}
/**
* create a logicalUnary pattern.
*/
default PatternDescriptor<LogicalUnary<GroupPlan>> logicalUnary() {
return new PatternDescriptor(new TypePattern(LogicalUnary.class, Pattern.GROUP), defaultPromise());
}
/**
* create a logicalUnary pattern.
*/
default <C extends Plan> PatternDescriptor<LogicalUnary<C>>
logicalUnary(PatternDescriptor<C> child) {
return new PatternDescriptor(new TypePattern(LogicalUnary.class, child.pattern), defaultPromise());
}
/**
* create a logicalBinary pattern.
*/
default PatternDescriptor<LogicalBinary<GroupPlan, GroupPlan>> logicalBinary() {
return new PatternDescriptor(
new TypePattern(LogicalBinary.class, Pattern.GROUP, Pattern.GROUP),
defaultPromise()
);
}
/**
* create a logicalBinary pattern.
*/
default <LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Plan>
PatternDescriptor<LogicalBinary<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE>>
logicalBinary(
PatternDescriptor<LEFT_CHILD_TYPE> leftChild,
PatternDescriptor<RIGHT_CHILD_TYPE> rightChild) {
return new PatternDescriptor(
new TypePattern(LogicalBinary.class, leftChild.pattern, rightChild.pattern),
defaultPromise()
);
}
/**
* create a logicalRelation pattern.
*/
default PatternDescriptor<LogicalRelation> logicalRelation() {
return new PatternDescriptor(new TypePattern(LogicalRelation.class), defaultPromise());
}
/**
* create a logicalSetOperation pattern.
*/
default PatternDescriptor<LogicalSetOperation>
logicalSetOperation(
PatternDescriptor... children) {
return new PatternDescriptor(
new TypePattern(LogicalSetOperation.class,
Arrays.stream(children)
.map(PatternDescriptor::getPattern)
.toArray(Pattern[]::new)),
defaultPromise());
}
/**
* create a logicalSetOperation group.
*/
default PatternDescriptor<LogicalSetOperation> logicalSetOperation() {
return new PatternDescriptor(
new TypePattern(LogicalSetOperation.class, multiGroup().pattern),
defaultPromise());
}
/**
* create a logicalUnion pattern.
*/
default PatternDescriptor<LogicalUnion>
logicalUnion(
PatternDescriptor... children) {
return new PatternDescriptor(
new TypePattern(LogicalUnion.class,
Arrays.stream(children)
.map(PatternDescriptor::getPattern)
.toArray(Pattern[]::new)),
defaultPromise());
}
/**
* create a logicalUnion group.
*/
default PatternDescriptor<LogicalUnion> logicalUnion() {
return new PatternDescriptor(
new TypePattern(LogicalUnion.class, multiGroup().pattern),
defaultPromise());
}
/**
* create a logicalExcept pattern.
*/
default PatternDescriptor<LogicalExcept>
logicalExcept(
PatternDescriptor... children) {
return new PatternDescriptor(
new TypePattern(LogicalExcept.class,
Arrays.stream(children)
.map(PatternDescriptor::getPattern)
.toArray(Pattern[]::new)),
defaultPromise());
}
/**
* create a logicalExcept group.
*/
default PatternDescriptor<LogicalExcept> logicalExcept() {
return new PatternDescriptor(
new TypePattern(LogicalExcept.class, multiGroup().pattern),
defaultPromise());
}
/**
* create a logicalIntersect pattern.
*/
default PatternDescriptor<LogicalIntersect>
logicalIntersect(
PatternDescriptor... children) {
return new PatternDescriptor(
new TypePattern(LogicalIntersect.class,
Arrays.stream(children)
.map(PatternDescriptor::getPattern)
.toArray(Pattern[]::new)),
defaultPromise());
}
/**
* create a logicalIntersect group.
*/
default PatternDescriptor<LogicalIntersect> logicalIntersect() {
return new PatternDescriptor(
new TypePattern(LogicalIntersect.class, multiGroup().pattern),
defaultPromise());
}
/* abstract physical plan patterns */
/**
* create a physicalLeaf pattern.
*/
default PatternDescriptor<PhysicalLeaf> physicalLeaf() {
return new PatternDescriptor(new TypePattern(PhysicalLeaf.class), defaultPromise());
}
/**
* create a physicalUnary pattern.
*/
default PatternDescriptor<PhysicalUnary<GroupPlan>> physicalUnary() {
return new PatternDescriptor(new TypePattern(PhysicalUnary.class, Pattern.GROUP), defaultPromise());
}
/**
* create a physicalUnary pattern.
*/
default <C extends Plan> PatternDescriptor<PhysicalUnary<C>>
physicalUnary(PatternDescriptor<C> child) {
return new PatternDescriptor(new TypePattern(PhysicalUnary.class, child.pattern), defaultPromise());
}
/**
* create a physicalBinary pattern.
*/
default PatternDescriptor<PhysicalBinary<GroupPlan, GroupPlan>> physicalBinary() {
return new PatternDescriptor(
new TypePattern(PhysicalBinary.class, Pattern.GROUP, Pattern.GROUP),
defaultPromise()
);
}
/**
* create a physicalBinary pattern.
*/
default <LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Plan>
PatternDescriptor<PhysicalBinary<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE>>
physicalBinary(
PatternDescriptor<LEFT_CHILD_TYPE> leftChild,
PatternDescriptor<RIGHT_CHILD_TYPE> rightChild) {
return new PatternDescriptor(
new TypePattern(PhysicalBinary.class, leftChild.pattern, rightChild.pattern),
defaultPromise()
);
}
/**
* create a physicalRelation pattern.
*/
default PatternDescriptor<PhysicalRelation> physicalRelation() {
return new PatternDescriptor(new TypePattern(PhysicalRelation.class), defaultPromise());
}
/**
* create an aggregate pattern.
*/
default PatternDescriptor<Aggregate<GroupPlan>> aggregate() {
return new PatternDescriptor(new TypePattern(Aggregate.class, Pattern.GROUP), defaultPromise());
}
/**
* create an aggregate pattern.
*/
default <C extends Plan> PatternDescriptor<Aggregate<C>> aggregate(PatternDescriptor<C> child) {
return new PatternDescriptor(new TypePattern(Aggregate.class, child.pattern), defaultPromise());
}
}


@@ -75,7 +75,7 @@ public class Pattern<TYPE extends Plan>
* @param predicates custom matching predicate
* @param children sub pattern
*/
private Pattern(PatternType patternType, PlanType planType,
protected Pattern(PatternType patternType, PlanType planType,
List<Predicate<TYPE>> predicates, Pattern... children) {
super(children);
this.patternType = patternType;
@@ -134,6 +134,35 @@
return patternType == PatternType.MULTI;
}
/** matchPlanTree */
public boolean matchPlanTree(Plan plan) {
if (!matchRoot(plan)) {
return false;
}
int childPatternNum = arity();
if (childPatternNum != plan.arity() && childPatternNum > 0 && child(childPatternNum - 1) != MULTI) {
return false;
}
switch (patternType) {
case ANY:
case MULTI:
return matchPredicates((TYPE) plan);
default:
}
if (this instanceof SubTreePattern) {
return matchPredicates((TYPE) plan);
}
List<Plan> childrenPlan = plan.children();
for (int i = 0; i < childrenPlan.size(); i++) {
Plan child = childrenPlan.get(i);
Pattern childPattern = child(Math.min(i, childPatternNum - 1));
if (!childPattern.matchPlanTree(child)) {
return false;
}
}
return matchPredicates((TYPE) plan);
}
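The new `matchPlanTree` recurses over the plan tree directly. When the pattern has fewer child patterns than the plan has children, the arity check only passes if the last child pattern is `MULTI`, and `Math.min(i, childPatternNum - 1)` makes that trailing pattern absorb all extra children. A simplified standalone model of this arity rule (all names illustrative, not the real `Pattern` class):

```java
import java.util.Arrays;
import java.util.List;

// Simplified model of recursive plan-tree matching where a trailing
// "multi" child pattern absorbs any number of extra plan children.
// Node and Pattern are toy stand-ins for illustration only.
public class MatchSketch {
    public static final class Node {
        final String type;
        final List<Node> children;

        public Node(String type, Node... children) {
            this.type = type;
            this.children = Arrays.asList(children);
        }
    }

    public static final class Pattern {
        final String type;        // "*" matches any node type
        final boolean multi;      // a multi pattern matches any subtree
        final Pattern[] children;

        public Pattern(String type, boolean multi, Pattern... children) {
            this.type = type;
            this.multi = multi;
            this.children = children;
        }

        public boolean matchTree(Node node) {
            if (multi) {
                return true; // absorbs the whole subtree
            }
            if (!type.equals("*") && !type.equals(node.type)) {
                return false;
            }
            int n = children.length;
            // arity must agree unless the trailing child pattern is multi
            if (n != node.children.size() && (n == 0 || !children[n - 1].multi)) {
                return false;
            }
            for (int i = 0; i < node.children.size(); i++) {
                // extra plan children all reuse the trailing pattern
                Pattern child = children[Math.min(i, n - 1)];
                if (!child.matchTree(node.children.get(i))) {
                    return false;
                }
            }
            return true;
        }
    }
}
```

So a pattern like `filter(multi())` matches a filter with one child or with many, which is what lets one pattern cover set operations of any arity.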
/**
* Return true if the current Pattern matches the Plan in params.
*
@@ -161,7 +190,13 @@
* @return true if all predicates matched
*/
public boolean matchPredicates(TYPE root) {
return predicates.stream().allMatch(predicate -> predicate.test(root));
// use a plain loop instead of a stream to speed up this hot path
for (Predicate<TYPE> predicate : predicates) {
if (!predicate.test(root)) {
return false;
}
}
return true;
}
@Override


@@ -18,27 +18,7 @@
package org.apache.doris.nereids.pattern;
import org.apache.doris.nereids.rules.RulePromise;
import org.apache.doris.nereids.trees.plans.BinaryPlan;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.LeafPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.UnaryPlan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalBinary;
import org.apache.doris.nereids.trees.plans.logical.LogicalExcept;
import org.apache.doris.nereids.trees.plans.logical.LogicalIntersect;
import org.apache.doris.nereids.trees.plans.logical.LogicalLeaf;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnary;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.physical.PhysicalBinary;
import org.apache.doris.nereids.trees.plans.physical.PhysicalLeaf;
import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation;
import org.apache.doris.nereids.trees.plans.physical.PhysicalUnary;
import java.util.Arrays;
/**
* An interface provided some PatternDescriptor.
@@ -59,285 +39,7 @@ public interface Patterns {
return new PatternDescriptor<>(Pattern.MULTI, defaultPromise());
}
default PatternDescriptor<GroupPlan> group() {
return new PatternDescriptor<>(Pattern.GROUP, defaultPromise());
}
default PatternDescriptor<GroupPlan> multiGroup() {
return new PatternDescriptor<>(Pattern.MULTI_GROUP, defaultPromise());
}
default <T extends Plan> PatternDescriptor<T> subTree(Class<? extends Plan>... subTreeNodeTypes) {
return new PatternDescriptor<>(new SubTreePattern(subTreeNodeTypes), defaultPromise());
}
/* abstract plan operator patterns */
/**
* create a leafPlan pattern.
*/
default PatternDescriptor<LeafPlan> leafPlan() {
return new PatternDescriptor(new TypePattern(LeafPlan.class), defaultPromise());
}
/**
* create a unaryPlan pattern.
*/
default PatternDescriptor<UnaryPlan<GroupPlan>> unaryPlan() {
return new PatternDescriptor(new TypePattern(UnaryPlan.class, Pattern.GROUP), defaultPromise());
}
/**
* create a unaryPlan pattern.
*/
default <C extends Plan> PatternDescriptor<UnaryPlan<C>>
unaryPlan(PatternDescriptor<C> child) {
return new PatternDescriptor(new TypePattern(UnaryPlan.class, child.pattern), defaultPromise());
}
/**
* create a binaryPlan pattern.
*/
default PatternDescriptor<BinaryPlan<GroupPlan, GroupPlan>> binaryPlan() {
return new PatternDescriptor(
new TypePattern(BinaryPlan.class, Pattern.GROUP, Pattern.GROUP),
defaultPromise()
);
}
/**
* create a binaryPlan pattern.
*/
default <LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Plan>
PatternDescriptor<BinaryPlan<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE>> binaryPlan(
PatternDescriptor<LEFT_CHILD_TYPE> leftChild,
PatternDescriptor<RIGHT_CHILD_TYPE> rightChild) {
return new PatternDescriptor(
new TypePattern(BinaryPlan.class, leftChild.pattern, rightChild.pattern),
defaultPromise()
);
}
/* abstract logical plan patterns */
/**
* create a logicalPlan pattern.
*/
default PatternDescriptor<LogicalPlan> logicalPlan() {
return new PatternDescriptor(new TypePattern(LogicalPlan.class, multiGroup().pattern), defaultPromise());
}
/**
* create a logicalLeaf pattern.
*/
default PatternDescriptor<LogicalLeaf> logicalLeaf() {
return new PatternDescriptor(new TypePattern(LogicalLeaf.class), defaultPromise());
}
/**
* create a logicalUnary pattern.
*/
default PatternDescriptor<LogicalUnary<GroupPlan>> logicalUnary() {
return new PatternDescriptor(new TypePattern(LogicalUnary.class, Pattern.GROUP), defaultPromise());
}
/**
* create a logicalUnary pattern.
*/
default <C extends Plan> PatternDescriptor<LogicalUnary<C>>
logicalUnary(PatternDescriptor<C> child) {
return new PatternDescriptor(new TypePattern(LogicalUnary.class, child.pattern), defaultPromise());
}
/**
* create a logicalBinary pattern.
*/
default PatternDescriptor<LogicalBinary<GroupPlan, GroupPlan>> logicalBinary() {
return new PatternDescriptor(
new TypePattern(LogicalBinary.class, Pattern.GROUP, Pattern.GROUP),
defaultPromise()
);
}
/**
* create a logicalBinary pattern.
*/
default <LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Plan>
PatternDescriptor<LogicalBinary<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE>>
logicalBinary(
PatternDescriptor<LEFT_CHILD_TYPE> leftChild,
PatternDescriptor<RIGHT_CHILD_TYPE> rightChild) {
return new PatternDescriptor(
new TypePattern(LogicalBinary.class, leftChild.pattern, rightChild.pattern),
defaultPromise()
);
}
/**
* create a logicalRelation pattern.
*/
default PatternDescriptor<LogicalRelation> logicalRelation() {
return new PatternDescriptor(new TypePattern(LogicalRelation.class), defaultPromise());
}
/**
* create a logicalSetOperation pattern.
*/
default PatternDescriptor<LogicalSetOperation>
logicalSetOperation(
PatternDescriptor... children) {
return new PatternDescriptor(
new TypePattern(LogicalSetOperation.class,
Arrays.stream(children)
.map(PatternDescriptor::getPattern)
.toArray(Pattern[]::new)),
defaultPromise());
}
/**
* create a logicalSetOperation group.
*/
default PatternDescriptor<LogicalSetOperation> logicalSetOperation() {
return new PatternDescriptor(
new TypePattern(LogicalSetOperation.class, multiGroup().pattern),
defaultPromise());
}
/**
* create a logicalUnion pattern.
*/
default PatternDescriptor<LogicalUnion>
logicalUnion(
PatternDescriptor... children) {
return new PatternDescriptor(
new TypePattern(LogicalUnion.class,
Arrays.stream(children)
.map(PatternDescriptor::getPattern)
.toArray(Pattern[]::new)),
defaultPromise());
}
/**
* create a logicalUnion group.
*/
default PatternDescriptor<LogicalUnion> logicalUnion() {
return new PatternDescriptor(
new TypePattern(LogicalUnion.class, multiGroup().pattern),
defaultPromise());
}
/**
* create a logicalExcept pattern.
*/
default PatternDescriptor<LogicalExcept>
logicalExcept(
PatternDescriptor... children) {
return new PatternDescriptor(
new TypePattern(LogicalExcept.class,
Arrays.stream(children)
.map(PatternDescriptor::getPattern)
.toArray(Pattern[]::new)),
defaultPromise());
}
/**
* create a logicalExcept group.
*/
default PatternDescriptor<LogicalExcept> logicalExcept() {
return new PatternDescriptor(
new TypePattern(LogicalExcept.class, multiGroup().pattern),
defaultPromise());
}
/**
* create a logicalUnion pattern.
*/
default PatternDescriptor<LogicalIntersect>
logicalIntersect(
PatternDescriptor... children) {
return new PatternDescriptor(
new TypePattern(LogicalIntersect.class,
Arrays.stream(children)
.map(PatternDescriptor::getPattern)
.toArray(Pattern[]::new)),
defaultPromise());
}
/**
* create a logicalUnion group.
*/
default PatternDescriptor<LogicalIntersect> logicalIntersect() {
return new PatternDescriptor(
new TypePattern(LogicalIntersect.class, multiGroup().pattern),
defaultPromise());
}
/* abstract physical plan patterns */
/**
* create a physicalLeaf pattern.
*/
default PatternDescriptor<PhysicalLeaf> physicalLeaf() {
return new PatternDescriptor(new TypePattern(PhysicalLeaf.class), defaultPromise());
}
/**
* create a physicalUnary pattern.
*/
default PatternDescriptor<PhysicalUnary<GroupPlan>> physicalUnary() {
return new PatternDescriptor(new TypePattern(PhysicalUnary.class, Pattern.GROUP), defaultPromise());
}
/**
* create a physicalUnary pattern.
*/
default <C extends Plan> PatternDescriptor<PhysicalUnary<C>>
physicalUnary(PatternDescriptor<C> child) {
return new PatternDescriptor(new TypePattern(PhysicalUnary.class, child.pattern), defaultPromise());
}
/**
* create a physicalBinary pattern.
*/
default PatternDescriptor<PhysicalBinary<GroupPlan, GroupPlan>> physicalBinary() {
return new PatternDescriptor(
new TypePattern(PhysicalBinary.class, Pattern.GROUP, Pattern.GROUP),
defaultPromise()
);
}
/**
* create a physicalBinary pattern.
*/
default <LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Plan>
PatternDescriptor<PhysicalBinary<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE>>
physicalBinary(
PatternDescriptor<LEFT_CHILD_TYPE> leftChild,
PatternDescriptor<RIGHT_CHILD_TYPE> rightChild) {
return new PatternDescriptor(
new TypePattern(PhysicalBinary.class, leftChild.pattern, rightChild.pattern),
defaultPromise()
);
}
/**
* create a physicalRelation pattern.
*/
default PatternDescriptor<PhysicalRelation> physicalRelation() {
return new PatternDescriptor(new TypePattern(PhysicalRelation.class), defaultPromise());
}
/**
* create a aggregate pattern.
*/
default PatternDescriptor<Aggregate<GroupPlan>> aggregate() {
return new PatternDescriptor(new TypePattern(Aggregate.class, Pattern.GROUP), defaultPromise());
}
/**
* create a aggregate pattern.
*/
default <C extends Plan> PatternDescriptor<Aggregate<C>> aggregate(PatternDescriptor<C> child) {
return new PatternDescriptor(new TypePattern(Aggregate.class, child.pattern), defaultPromise());
}
}


@@ -0,0 +1,314 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.pattern;
import org.apache.doris.nereids.trees.plans.BinaryPlan;
import org.apache.doris.nereids.trees.plans.LeafPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.UnaryPlan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalBinary;
import org.apache.doris.nereids.trees.plans.logical.LogicalExcept;
import org.apache.doris.nereids.trees.plans.logical.LogicalIntersect;
import org.apache.doris.nereids.trees.plans.logical.LogicalLeaf;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnary;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.physical.PhysicalBinary;
import org.apache.doris.nereids.trees.plans.physical.PhysicalLeaf;
import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation;
import org.apache.doris.nereids.trees.plans.physical.PhysicalUnary;
import java.util.Arrays;
/** PlanPatterns */
public interface PlanPatterns extends Patterns {
/* abstract plan operator patterns */
/**
* create a leafPlan pattern.
*/
default PatternDescriptor<LeafPlan> leafPlan() {
return new PatternDescriptor(new TypePattern(LeafPlan.class), defaultPromise());
}
/**
* create a unaryPlan pattern.
*/
default PatternDescriptor<UnaryPlan<Plan>> unaryPlan() {
return new PatternDescriptor(new TypePattern(UnaryPlan.class, Pattern.ANY), defaultPromise());
}
/**
* create a unaryPlan pattern.
*/
default <C extends Plan> PatternDescriptor<UnaryPlan<C>>
unaryPlan(PatternDescriptor<C> child) {
return new PatternDescriptor(new TypePattern(UnaryPlan.class, child.pattern), defaultPromise());
}
/**
* create a binaryPlan pattern.
*/
default PatternDescriptor<BinaryPlan<Plan, Plan>> binaryPlan() {
return new PatternDescriptor(
new TypePattern(BinaryPlan.class, Pattern.ANY, Pattern.ANY),
defaultPromise()
);
}
/**
* create a binaryPlan pattern.
*/
default <LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Plan>
PatternDescriptor<BinaryPlan<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE>> binaryPlan(
PatternDescriptor<LEFT_CHILD_TYPE> leftChild,
PatternDescriptor<RIGHT_CHILD_TYPE> rightChild) {
return new PatternDescriptor(
new TypePattern(BinaryPlan.class, leftChild.pattern, rightChild.pattern),
defaultPromise()
);
}
/* abstract logical plan patterns */
/**
* create a logicalPlan pattern.
*/
default PatternDescriptor<LogicalPlan> logicalPlan() {
return new PatternDescriptor(new TypePattern(LogicalPlan.class, multi().pattern), defaultPromise());
}
/**
* create a logicalLeaf pattern.
*/
default PatternDescriptor<LogicalLeaf> logicalLeaf() {
return new PatternDescriptor(new TypePattern(LogicalLeaf.class), defaultPromise());
}
/**
* create a logicalUnary pattern.
*/
default PatternDescriptor<LogicalUnary<Plan>> logicalUnary() {
return new PatternDescriptor(new TypePattern(LogicalUnary.class, Pattern.ANY), defaultPromise());
}
/**
* create a logicalUnary pattern.
*/
default <C extends Plan> PatternDescriptor<LogicalUnary<C>>
logicalUnary(PatternDescriptor<C> child) {
return new PatternDescriptor(new TypePattern(LogicalUnary.class, child.pattern), defaultPromise());
}
/**
* create a logicalBinary pattern.
*/
default PatternDescriptor<LogicalBinary<Plan, Plan>> logicalBinary() {
return new PatternDescriptor(
new TypePattern(LogicalBinary.class, Pattern.ANY, Pattern.ANY),
defaultPromise()
);
}
/**
* create a logicalBinary pattern.
*/
default <LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Plan>
PatternDescriptor<LogicalBinary<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE>>
logicalBinary(
PatternDescriptor<LEFT_CHILD_TYPE> leftChild,
PatternDescriptor<RIGHT_CHILD_TYPE> rightChild) {
return new PatternDescriptor(
new TypePattern(LogicalBinary.class, leftChild.pattern, rightChild.pattern),
defaultPromise()
);
}
/**
* create a logicalRelation pattern.
*/
default PatternDescriptor<LogicalRelation> logicalRelation() {
return new PatternDescriptor(new TypePattern(LogicalRelation.class), defaultPromise());
}
/**
* create a logicalSetOperation pattern.
*/
default PatternDescriptor<LogicalSetOperation>
logicalSetOperation(
PatternDescriptor... children) {
return new PatternDescriptor(
new TypePattern(LogicalSetOperation.class,
Arrays.stream(children)
.map(PatternDescriptor::getPattern)
.toArray(Pattern[]::new)),
defaultPromise());
}
/**
* create a logicalSetOperation multi.
*/
default PatternDescriptor<LogicalSetOperation> logicalSetOperation() {
return new PatternDescriptor(
new TypePattern(LogicalSetOperation.class, multi().pattern),
defaultPromise());
}
/**
* create a logicalUnion pattern.
*/
default PatternDescriptor<LogicalUnion>
logicalUnion(
PatternDescriptor... children) {
return new PatternDescriptor(
new TypePattern(LogicalUnion.class,
Arrays.stream(children)
.map(PatternDescriptor::getPattern)
.toArray(Pattern[]::new)),
defaultPromise());
}
/**
* create a logicalUnion multi.
*/
default PatternDescriptor<LogicalUnion> logicalUnion() {
return new PatternDescriptor(
new TypePattern(LogicalUnion.class, multi().pattern),
defaultPromise());
}
/**
* create a logicalExcept pattern.
*/
default PatternDescriptor<LogicalExcept>
logicalExcept(
PatternDescriptor... children) {
return new PatternDescriptor(
new TypePattern(LogicalExcept.class,
Arrays.stream(children)
.map(PatternDescriptor::getPattern)
.toArray(Pattern[]::new)),
defaultPromise());
}
/**
* create a logicalExcept multi.
*/
default PatternDescriptor<LogicalExcept> logicalExcept() {
return new PatternDescriptor(
new TypePattern(LogicalExcept.class, multi().pattern),
defaultPromise());
}
/**
* create a logicalIntersect pattern.
*/
default PatternDescriptor<LogicalIntersect>
logicalIntersect(
PatternDescriptor... children) {
return new PatternDescriptor(
new TypePattern(LogicalIntersect.class,
Arrays.stream(children)
.map(PatternDescriptor::getPattern)
.toArray(Pattern[]::new)),
defaultPromise());
}
/**
* create a logicalIntersect multi.
*/
default PatternDescriptor<LogicalIntersect> logicalIntersect() {
return new PatternDescriptor(
new TypePattern(LogicalIntersect.class, multi().pattern),
defaultPromise());
}
/* abstract physical plan patterns */
/**
* create a physicalLeaf pattern.
*/
default PatternDescriptor<PhysicalLeaf> physicalLeaf() {
return new PatternDescriptor(new TypePattern(PhysicalLeaf.class), defaultPromise());
}
/**
* create a physicalUnary pattern.
*/
default PatternDescriptor<PhysicalUnary<Plan>> physicalUnary() {
return new PatternDescriptor(new TypePattern(PhysicalUnary.class, Pattern.ANY), defaultPromise());
}
/**
* create a physicalUnary pattern.
*/
default <C extends Plan> PatternDescriptor<PhysicalUnary<C>>
physicalUnary(PatternDescriptor<C> child) {
return new PatternDescriptor(new TypePattern(PhysicalUnary.class, child.pattern), defaultPromise());
}
/**
* create a physicalBinary pattern.
*/
default PatternDescriptor<PhysicalBinary<Plan, Plan>> physicalBinary() {
return new PatternDescriptor(
new TypePattern(PhysicalBinary.class, Pattern.ANY, Pattern.ANY),
defaultPromise()
);
}
/**
* create a physicalBinary pattern.
*/
default <LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Plan>
PatternDescriptor<PhysicalBinary<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE>>
physicalBinary(
PatternDescriptor<LEFT_CHILD_TYPE> leftChild,
PatternDescriptor<RIGHT_CHILD_TYPE> rightChild) {
return new PatternDescriptor(
new TypePattern(PhysicalBinary.class, leftChild.pattern, rightChild.pattern),
defaultPromise()
);
}
/**
* create a physicalRelation pattern.
*/
default PatternDescriptor<PhysicalRelation> physicalRelation() {
return new PatternDescriptor(new TypePattern(PhysicalRelation.class), defaultPromise());
}
/**
* create an aggregate pattern.
*/
default PatternDescriptor<Aggregate<Plan>> aggregate() {
return new PatternDescriptor(new TypePattern(Aggregate.class, Pattern.ANY), defaultPromise());
}
/**
* create an aggregate pattern.
*/
default <C extends Plan> PatternDescriptor<Aggregate<C>> aggregate(PatternDescriptor<C> child) {
return new PatternDescriptor(new TypePattern(Aggregate.class, child.pattern), defaultPromise());
}
}
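`PlanPatterns` mirrors the `MemoPatterns` factories; the one systematic difference is the child placeholder: `Pattern.ANY` (a concrete child plan) instead of `Pattern.GROUP` (a memo group), so the same rule bodies can match directly on the plan tree during analyze/rewrite. A toy sketch of splitting one factory interface into the two flavors (names are illustrative only, not the Doris API):

```java
// Toy illustration of specializing one pattern-factory interface for
// memo-based vs plan-tree-based matching. All names are stand-ins.
public class PatternFlavors {
    public enum ChildKind { GROUP, ANY }

    public interface Patterns {
        ChildKind childPlaceholder();

        // Shared factory body; only the child placeholder differs.
        default String unaryPlan() {
            return "UnaryPlan(" + childPlaceholder() + ")";
        }
    }

    public static final class MemoPatterns implements Patterns {
        public ChildKind childPlaceholder() { return ChildKind.GROUP; }
    }

    public static final class PlanPatterns implements Patterns {
        public ChildKind childPlaceholder() { return ChildKind.ANY; }
    }
}
```

Keeping both flavors behind the shared `Patterns` supertype is what lets the same rule definitions run in either the memo-based optimizer or the new plan-tree rewriter.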


@@ -0,0 +1,45 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.pattern;
import org.apache.doris.nereids.trees.plans.Plan;
/** ProxyPattern */
public class ProxyPattern<TYPE extends Plan> extends Pattern<TYPE> {
protected final Pattern pattern;
public ProxyPattern(Pattern pattern) {
super(pattern.getPlanType(), pattern.children());
this.pattern = pattern;
}
@Override
public boolean matchPlanTree(Plan plan) {
return pattern.matchPlanTree(plan);
}
@Override
public boolean matchRoot(Plan plan) {
return pattern.matchRoot(plan);
}
@Override
public boolean matchPredicates(TYPE root) {
return pattern.matchPredicates(root);
}
}
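`ProxyPattern` forwards every match call to a wrapped pattern, giving subclasses a hook to layer extra conditions on top without touching the wrapped pattern, the shape that `AppliedAwareRuleCondition` (mentioned in the PR description) builds on. A minimal sketch of that decoration idea (simplified stand-ins, not the Doris API):

```java
import java.util.function.Predicate;

// Sketch of the proxy/decorator idea behind ProxyPattern: delegate the
// match to a wrapped pattern and add an extra condition on top.
// Pattern and withCondition are illustrative stand-ins only.
public class ProxySketch {
    public interface Pattern {
        boolean match(String plan);
    }

    // Decorate an existing pattern with an additional predicate,
    // leaving the wrapped pattern itself untouched.
    public static Pattern withCondition(Pattern inner, Predicate<String> extra) {
        return plan -> inner.match(plan) && extra.test(plan);
    }

    public static void main(String[] args) {
        Pattern base = p -> p.startsWith("Logical");
        Pattern guarded = withCondition(base, p -> p.endsWith("Join"));
        System.out.println(guarded.match("LogicalJoin"));  // true
        System.out.println(guarded.match("LogicalScan"));  // false
    }
}
```

The same shape can carry mutable state in the wrapper, e.g. a "was this rule already applied here" check, while rules keep seeing a plain pattern.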


@@ -26,13 +26,13 @@ import java.util.TreeSet;
public class LogicalBinaryPatternGenerator extends PatternGenerator {
public LogicalBinaryPatternGenerator(PatternGeneratorAnalyzer analyzer,
ClassDeclaration opType, Set<String> parentClass) {
super(analyzer, opType, parentClass);
ClassDeclaration opType, Set<String> parentClass, boolean isMemoPattern) {
super(analyzer, opType, parentClass, isMemoPattern);
}
@Override
public String genericType() {
return "<" + opType.name + "<GroupPlan, GroupPlan>>";
return "<" + opType.name + "<" + childType() + ", " + childType() + ">>";
}
@Override
@ -44,7 +44,9 @@ public class LogicalBinaryPatternGenerator extends PatternGenerator {
public Set<String> getImports() {
Set<String> imports = new TreeSet<>();
imports.add(opType.getFullQualifiedName());
imports.add("org.apache.doris.nereids.trees.plans.GroupPlan");
if (isMemoPattern) {
imports.add("org.apache.doris.nereids.trees.plans.GroupPlan");
}
imports.add("org.apache.doris.nereids.trees.plans.Plan");
enumFieldPatternInfos.stream()
.map(info -> info.enumFullName)


@ -26,8 +26,8 @@ import java.util.TreeSet;
public class LogicalLeafPatternGenerator extends PatternGenerator {
public LogicalLeafPatternGenerator(PatternGeneratorAnalyzer analyzer,
ClassDeclaration opType, Set<String> parentClass) {
super(analyzer, opType, parentClass);
ClassDeclaration opType, Set<String> parentClass, boolean isMemoPattern) {
super(analyzer, opType, parentClass, isMemoPattern);
}
@Override


@ -26,13 +26,13 @@ import java.util.TreeSet;
public class LogicalUnaryPatternGenerator extends PatternGenerator {
public LogicalUnaryPatternGenerator(PatternGeneratorAnalyzer analyzer,
ClassDeclaration opType, Set<String> parentClass) {
super(analyzer, opType, parentClass);
ClassDeclaration opType, Set<String> parentClass, boolean isMemoPattern) {
super(analyzer, opType, parentClass, isMemoPattern);
}
@Override
public String genericType() {
return "<" + opType.name + "<GroupPlan>>";
return "<" + opType.name + "<" + childType() + ">>";
}
@Override
@ -44,7 +44,9 @@ public class LogicalUnaryPatternGenerator extends PatternGenerator {
public Set<String> getImports() {
Set<String> imports = new TreeSet<>();
imports.add(opType.getFullQualifiedName());
imports.add("org.apache.doris.nereids.trees.plans.GroupPlan");
if (isMemoPattern) {
imports.add("org.apache.doris.nereids.trees.plans.GroupPlan");
}
imports.add("org.apache.doris.nereids.trees.plans.Plan");
enumFieldPatternInfos.stream()
.map(info -> info.enumFullName)


@ -86,21 +86,9 @@ public class PatternDescribableProcessor extends AbstractProcessor {
List<TypeDeclaration> asts = parseJavaFile(file);
patternGeneratorAnalyzer.addAsts(asts);
}
String generatePatternCode = patternGeneratorAnalyzer.generatePatterns();
File generatePatternFile = new File(processingEnv.getFiler()
.getResource(StandardLocation.SOURCE_OUTPUT, "org.apache.doris.nereids.pattern",
"GeneratedPatterns.java").toUri());
if (generatePatternFile.exists()) {
generatePatternFile.delete();
}
if (!generatePatternFile.getParentFile().exists()) {
generatePatternFile.getParentFile().mkdirs();
}
// bypass processingEnv.getFiler() when creating the file; the generated patterns are compiled in the next compilation round
try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(generatePatternFile))) {
bufferedWriter.write(generatePatternCode);
}
doGenerate("GeneratedMemoPatterns", "MemoPatterns", true, patternGeneratorAnalyzer);
doGenerate("GeneratedPlanPatterns", "PlanPatterns", false, patternGeneratorAnalyzer);
} catch (Throwable t) {
String exceptionMsg = Throwables.getStackTraceAsString(t);
processingEnv.getMessager().printMessage(Kind.ERROR,
@ -109,6 +97,26 @@ public class PatternDescribableProcessor extends AbstractProcessor {
return false;
}
private void doGenerate(String className, String parentClassName, boolean isMemoPattern,
PatternGeneratorAnalyzer patternGeneratorAnalyzer) throws IOException {
String generatePatternCode = patternGeneratorAnalyzer.generatePatterns(
className, parentClassName, isMemoPattern);
File generatePatternFile = new File(processingEnv.getFiler()
.getResource(StandardLocation.SOURCE_OUTPUT, "org.apache.doris.nereids.pattern",
className + ".java").toUri());
if (generatePatternFile.exists()) {
generatePatternFile.delete();
}
if (!generatePatternFile.getParentFile().exists()) {
generatePatternFile.getParentFile().mkdirs();
}
// bypass processingEnv.getFiler() when creating the file; the generated patterns are compiled in the next compilation round
try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(generatePatternFile))) {
bufferedWriter.write(generatePatternCode);
}
}
private List<File> findJavaFiles(List<File> dirs) {
List<File> files = new ArrayList<>();
for (File dir : dirs) {

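The extracted `doGenerate` method writes each generated interface with a plain `FileWriter` (bypassing the annotation-processing `Filer`), deleting any stale output and creating parent directories first. A self-contained sketch of that file-writing step; the paths and names here are illustrative, not the real build layout:

```java
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.UncheckedIOException;

public class GeneratedFileWriterSketch {
    // Mirrors doGenerate(): delete a stale generated file, ensure the parent
    // directory exists, then write the new source with a plain FileWriter.
    static void writeGenerated(File target, String code) {
        if (target.exists()) {
            target.delete();
        }
        if (!target.getParentFile().exists()) {
            target.getParentFile().mkdirs();
        }
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(target))) {
            writer.write(code);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        File target = new File(System.getProperty("java.io.tmpdir"),
                "nereids-sketch/pattern/GeneratedPlanPatterns.java");
        writeGenerated(target, "public interface GeneratedPlanPatterns {}\n");
        System.out.println(target.exists()); // true
    }
}
```

Writing the file directly means the Filer never registers it, which is why the comment in the processor notes the generated patterns are only compiled in the next round.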

@ -50,13 +50,16 @@ public abstract class PatternGenerator {
protected final Set<String> parentClass;
protected final List<EnumFieldPatternInfo> enumFieldPatternInfos;
protected final List<String> generatePatterns = new ArrayList<>();
protected final boolean isMemoPattern;
/** constructor. */
public PatternGenerator(PatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set<String> parentClass) {
public PatternGenerator(PatternGeneratorAnalyzer analyzer, ClassDeclaration opType,
Set<String> parentClass, boolean isMemoPattern) {
this.analyzer = analyzer;
this.opType = opType;
this.parentClass = parentClass;
this.enumFieldPatternInfos = getEnumFieldPatternInfos();
this.isMemoPattern = isMemoPattern;
}
public abstract String genericType();
@ -74,7 +77,8 @@ public abstract class PatternGenerator {
}
/** generate code by generators and analyzer. */
public static String generateCode(List<PatternGenerator> generators, PatternGeneratorAnalyzer analyzer) {
public static String generateCode(String className, String parentClassName, List<PatternGenerator> generators,
PatternGeneratorAnalyzer analyzer, boolean isMemoPattern) {
String generateCode
= "// Licensed to the Apache Software Foundation (ASF) under one\n"
+ "// or more contributor license agreements. See the NOTICE file\n"
@ -97,11 +101,10 @@ public abstract class PatternGenerator {
+ "\n"
+ generateImports(generators)
+ "\n";
generateCode += "public interface GeneratedPatterns extends Patterns {\n";
generateCode += "public interface " + className + " extends " + parentClassName + " {\n";
generateCode += generators.stream()
.map(generator -> {
String patternMethods = generator.generate();
String patternMethods = generator.generate(isMemoPattern);
// add indent
return Arrays.stream(patternMethods.split("\n"))
.map(line -> " " + line + "\n")
@ -199,21 +202,25 @@ public abstract class PatternGenerator {
return parts;
}
protected String childType() {
return isMemoPattern ? "GroupPlan" : "Plan";
}
/** create generator by plan's type. */
public static Optional<PatternGenerator> create(PatternGeneratorAnalyzer analyzer,
ClassDeclaration opType, Set<String> parentClass) {
ClassDeclaration opType, Set<String> parentClass, boolean isMemoPattern) {
if (parentClass.contains("org.apache.doris.nereids.trees.plans.logical.LogicalLeaf")) {
return Optional.of(new LogicalLeafPatternGenerator(analyzer, opType, parentClass));
return Optional.of(new LogicalLeafPatternGenerator(analyzer, opType, parentClass, isMemoPattern));
} else if (parentClass.contains("org.apache.doris.nereids.trees.plans.logical.LogicalUnary")) {
return Optional.of(new LogicalUnaryPatternGenerator(analyzer, opType, parentClass));
return Optional.of(new LogicalUnaryPatternGenerator(analyzer, opType, parentClass, isMemoPattern));
} else if (parentClass.contains("org.apache.doris.nereids.trees.plans.logical.LogicalBinary")) {
return Optional.of(new LogicalBinaryPatternGenerator(analyzer, opType, parentClass));
return Optional.of(new LogicalBinaryPatternGenerator(analyzer, opType, parentClass, isMemoPattern));
} else if (parentClass.contains("org.apache.doris.nereids.trees.plans.physical.PhysicalLeaf")) {
return Optional.of(new PhysicalLeafPatternGenerator(analyzer, opType, parentClass));
return Optional.of(new PhysicalLeafPatternGenerator(analyzer, opType, parentClass, isMemoPattern));
} else if (parentClass.contains("org.apache.doris.nereids.trees.plans.physical.PhysicalUnary")) {
return Optional.of(new PhysicalUnaryPatternGenerator(analyzer, opType, parentClass));
return Optional.of(new PhysicalUnaryPatternGenerator(analyzer, opType, parentClass, isMemoPattern));
} else if (parentClass.contains("org.apache.doris.nereids.trees.plans.physical.PhysicalBinary")) {
return Optional.of(new PhysicalBinaryPatternGenerator(analyzer, opType, parentClass));
return Optional.of(new PhysicalBinaryPatternGenerator(analyzer, opType, parentClass, isMemoPattern));
} else {
return Optional.empty();
}
@ -233,21 +240,24 @@ public abstract class PatternGenerator {
}
/** generate some pattern method code. */
public String generate() {
public String generate(boolean isMemoPattern) {
String opClassName = opType.name;
String methodName = getPatternMethodName();
generateTypePattern(methodName, opClassName, genericType(), "", false);
generateTypePattern(methodName, opClassName, genericType(), "", false, isMemoPattern);
if (childrenNum() > 0) {
generateTypePattern(methodName, opClassName, genericTypeWithChildren(), "", true);
generateTypePattern(methodName, opClassName, genericTypeWithChildren(),
"", true, isMemoPattern);
}
for (EnumFieldPatternInfo info : enumFieldPatternInfos) {
String predicate = ".when(p -> p." + info.enumInstanceGetter + "() == "
+ info.enumType + "." + info.enumInstance + ")";
generateTypePattern(info.patternName, opClassName, genericType(), predicate, false);
generateTypePattern(info.patternName, opClassName, genericType(),
predicate, false, isMemoPattern);
if (childrenNum() > 0) {
generateTypePattern(info.patternName, opClassName, genericTypeWithChildren(), predicate, true);
generateTypePattern(info.patternName, opClassName, genericTypeWithChildren(),
predicate, true, isMemoPattern);
}
}
return generatePatterns();
@ -255,7 +265,7 @@ public abstract class PatternGenerator {
/** generate a pattern method code. */
public String generateTypePattern(String patterName, String className,
String genericParam, String predicate, boolean specifyChildren) {
String genericParam, String predicate, boolean specifyChildren, boolean isMemoPattern) {
int childrenNum = childrenNum();
@ -286,7 +296,8 @@ public abstract class PatternGenerator {
generatePatterns.add(pattern);
return pattern;
} else {
String childrenPattern = StringUtils.repeat("Pattern.GROUP", ", ", childrenNum);
String childrenPattern = StringUtils.repeat(
isMemoPattern ? "Pattern.GROUP" : "Pattern.ANY", ", ", childrenNum);
if (childrenNum > 0) {
childrenPattern = ", " + childrenPattern;
}

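The generators now plug either `GroupPlan` (memo patterns) or `Plan` (plan-tree patterns) into the generated generic parameters via `childType()`, and emit `Pattern.GROUP` or `Pattern.ANY` children accordingly. A small sketch of how that single flag drives the emitted code; method names are simplified from the real generators:

```java
public class PatternCodeGenSketch {
    // Mirrors PatternGenerator.childType(): memo patterns see GroupPlan children,
    // plan-tree patterns see plain Plan children.
    static String childType(boolean isMemoPattern) {
        return isMemoPattern ? "GroupPlan" : "Plan";
    }

    // Mirrors LogicalUnaryPatternGenerator.genericType() for a unary plan type.
    static String unaryGenericType(String opName, boolean isMemoPattern) {
        return "<" + opName + "<" + childType(isMemoPattern) + ">>";
    }

    // Mirrors the children-pattern choice in generateTypePattern().
    static String childPatternToken(boolean isMemoPattern) {
        return isMemoPattern ? "Pattern.GROUP" : "Pattern.ANY";
    }

    public static void main(String[] args) {
        System.out.println(unaryGenericType("LogicalProject", true));  // <LogicalProject<GroupPlan>>
        System.out.println(unaryGenericType("LogicalProject", false)); // <LogicalProject<Plan>>
    }
}
```

Running the same generator twice with the flag flipped is what produces the two interfaces, `GeneratedMemoPatterns` and `GeneratedPlanPatterns`.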

@ -57,10 +57,10 @@ public class PatternGeneratorAnalyzer {
}
/** generate pattern methods. */
public String generatePatterns() {
public String generatePatterns(String className, String parentClassName, boolean isMemoPattern) {
analyzeImport();
analyzeParentClass();
return doGenerate();
return doGenerate(className, parentClassName, isMemoPattern);
}
Optional<TypeDeclaration> getType(TypeDeclaration typeDeclaration, TypeType type) {
@ -73,7 +73,7 @@ public class PatternGeneratorAnalyzer {
return Optional.empty();
}
private String doGenerate() {
private String doGenerate(String className, String parentClassName, boolean isMemoPattern) {
Map<ClassDeclaration, Set<String>> planClassMap = parentClassMap.entrySet().stream()
.filter(kv -> kv.getValue().contains("org.apache.doris.nereids.trees.plans.Plan"))
.filter(kv -> !kv.getKey().name.equals("GroupPlan"))
@ -83,7 +83,7 @@ public class PatternGeneratorAnalyzer {
List<PatternGenerator> generators = planClassMap.entrySet()
.stream()
.map(kv -> PatternGenerator.create(this, kv.getKey(), kv.getValue()))
.map(kv -> PatternGenerator.create(this, kv.getKey(), kv.getValue(), isMemoPattern))
.filter(Optional::isPresent)
.map(Optional::get)
.sorted((g1, g2) -> {
@ -100,7 +100,7 @@ public class PatternGeneratorAnalyzer {
})
.collect(Collectors.toList());
return PatternGenerator.generateCode(generators, this);
return PatternGenerator.generateCode(className, parentClassName, generators, this, isMemoPattern);
}
private void analyzeImport() {


@ -26,13 +26,13 @@ import java.util.TreeSet;
public class PhysicalBinaryPatternGenerator extends PatternGenerator {
public PhysicalBinaryPatternGenerator(PatternGeneratorAnalyzer analyzer,
ClassDeclaration opType, Set<String> parentClass) {
super(analyzer, opType, parentClass);
ClassDeclaration opType, Set<String> parentClass, boolean isMemoPattern) {
super(analyzer, opType, parentClass, isMemoPattern);
}
@Override
public String genericType() {
return "<" + opType.name + "<GroupPlan, GroupPlan>>";
return "<" + opType.name + "<" + childType() + ", " + childType() + ">>";
}
@Override
@ -44,7 +44,9 @@ public class PhysicalBinaryPatternGenerator extends PatternGenerator {
public Set<String> getImports() {
Set<String> imports = new TreeSet<>();
imports.add(opType.getFullQualifiedName());
imports.add("org.apache.doris.nereids.trees.plans.GroupPlan");
if (isMemoPattern) {
imports.add("org.apache.doris.nereids.trees.plans.GroupPlan");
}
imports.add("org.apache.doris.nereids.trees.plans.Plan");
enumFieldPatternInfos.stream()
.map(info -> info.enumFullName)


@ -26,8 +26,8 @@ import java.util.TreeSet;
public class PhysicalLeafPatternGenerator extends PatternGenerator {
public PhysicalLeafPatternGenerator(PatternGeneratorAnalyzer analyzer,
ClassDeclaration opType, Set<String> parentClass) {
super(analyzer, opType, parentClass);
ClassDeclaration opType, Set<String> parentClass, boolean isMemoPattern) {
super(analyzer, opType, parentClass, isMemoPattern);
}
@Override


@ -26,13 +26,13 @@ import java.util.TreeSet;
public class PhysicalUnaryPatternGenerator extends PatternGenerator {
public PhysicalUnaryPatternGenerator(PatternGeneratorAnalyzer analyzer,
ClassDeclaration opType, Set<String> parentClass) {
super(analyzer, opType, parentClass);
ClassDeclaration opType, Set<String> parentClass, boolean isMemoPattern) {
super(analyzer, opType, parentClass, isMemoPattern);
}
@Override
public String genericType() {
return "<" + opType.name + "<GroupPlan>>";
return "<" + opType.name + "<" + childType() + ">>";
}
@Override
@ -44,7 +44,9 @@ public class PhysicalUnaryPatternGenerator extends PatternGenerator {
public Set<String> getImports() {
Set<String> imports = new TreeSet<>();
imports.add(opType.getFullQualifiedName());
imports.add("org.apache.doris.nereids.trees.plans.GroupPlan");
if (isMemoPattern) {
imports.add("org.apache.doris.nereids.trees.plans.GroupPlan");
}
imports.add("org.apache.doris.nereids.trees.plans.Plan");
enumFieldPatternInfos.stream()
.map(info -> info.enumFullName)


@ -18,6 +18,7 @@
package org.apache.doris.nereids.processor.post;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
@ -56,7 +57,10 @@ public class Validator extends PlanPostProcessor {
Plan child = filter.child();
// Forbidden filter-project, we must make filter-project -> project-filter.
Preconditions.checkState(!(child instanceof PhysicalProject));
if (child instanceof PhysicalProject) {
throw new AnalysisException(
"Nereids generated a filter-project plan, but the backend does not support it:\n" + filter.treeString());
}
// Check filter is from child output.
Set<Slot> childOutputSet = child.getOutputSet();

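The Validator change replaces a bare `Preconditions.checkState` (which throws an uninformative `IllegalStateException`) with an exception whose message carries the offending plan tree. A simplified, self-contained sketch of that check; the types are stand-ins, while the real code inspects `PhysicalFilter`/`PhysicalProject` nodes:

```java
// Stand-in for AnalysisException; illustrative only.
class PlanValidationException extends RuntimeException {
    PlanValidationException(String msg) {
        super(msg);
    }
}

public class ValidatorSketch {
    // Like the Validator change: reject a filter whose child is a project,
    // and include the plan tree in the error so the failure is diagnosable.
    static void checkNoFilterOverProject(String childNodeName, String planTree) {
        if ("PhysicalProject".equals(childNodeName)) {
            throw new PlanValidationException(
                    "Nereids generated a filter-project plan, but the backend does not support it:\n"
                            + planTree);
        }
    }

    public static void main(String[] args) {
        checkNoFilterOverProject("PhysicalOlapScan", "Filter(Scan)"); // passes silently
        try {
            checkNoFilterOverProject("PhysicalProject", "Filter(Project(Scan))");
        } catch (PlanValidationException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```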

@ -0,0 +1,111 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.TransformException;
import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.pattern.ProxyPattern;
import org.apache.doris.nereids.trees.plans.Plan;
import java.util.BitSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiPredicate;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
/** AppliedAwareRule */
public class AppliedAwareRule extends Rule {
private static final String APPLIED_RULES_KEY = "applied_rules";
private static final Supplier<BitSet> CREATE_APPLIED_RULES = () -> new BitSet(RuleType.SENTINEL.ordinal());
protected final Rule rule;
protected final RuleType ruleType;
protected int ruleTypeIndex;
private AppliedAwareRule(Rule rule, BiPredicate<Rule, Plan> matchRootPredicate) {
super(rule.getRuleType(),
new ExtendPattern(rule.getPattern(), (Predicate<Plan>) (plan -> matchRootPredicate.test(rule, plan))),
rule.getRulePromise());
this.rule = rule;
this.ruleType = rule.getRuleType();
this.ruleTypeIndex = rule.getRuleType().ordinal();
}
@Override
public List<Plan> transform(Plan plan, CascadesContext context) throws TransformException {
return rule.transform(plan, context);
}
@Override
public void acceptPlan(Plan plan) {
BitSet appliedRules = plan.getOrInitMutableState(APPLIED_RULES_KEY, CREATE_APPLIED_RULES);
appliedRules.set(ruleTypeIndex);
}
/**
* AppliedAwareRuleCondition: converts a rule to an AppliedAwareRule, so that the rule can add
* conditions that depend on whether the rule has already been applied to a plan
*/
public static class AppliedAwareRuleCondition implements Function<Rule, Rule> {
public Rule apply(Rule rule) {
return new AppliedAwareRule(rule, this::condition);
}
/** provide this method so that a child class can read the applied state */
public final boolean isAppliedRule(Rule rule, Plan plan) {
Optional<BitSet> appliedRules = plan.getMutableState(APPLIED_RULES_KEY);
if (!appliedRules.isPresent()) {
return false;
}
return appliedRules.get().get(rule.getRuleType().ordinal());
}
/**
* the default condition is whether this rule has already been applied to a plan,
* which means a rule is applied to each plan only once. A child class can
* override this method.
*/
protected boolean condition(Rule rule, Plan plan) {
return isAppliedRule(rule, plan);
}
}
private static class ExtendPattern<TYPE extends Plan> extends ProxyPattern<TYPE> {
private final Predicate<Plan> matchRootPredicate;
public ExtendPattern(Pattern pattern, Predicate<Plan> matchRootPredicate) {
super(pattern);
this.matchRootPredicate = Objects.requireNonNull(matchRootPredicate, "matchRootPredicate cannot be null");
}
@Override
public boolean matchPlanTree(Plan plan) {
return matchRootPredicate.test(plan) && super.matchPlanTree(plan);
}
@Override
public boolean matchRoot(Plan plan) {
return matchRootPredicate.test(plan) && super.matchRoot(plan);
}
}
}
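`AppliedAwareRule` records which rules have touched a plan in a `BitSet` stored in the plan's mutable state, keyed by the rule type's ordinal, and `AppliedAwareRuleCondition` reads the same state back. A self-contained sketch of that bookkeeping, with simplified stand-ins for `Plan` and `RuleType`:

```java
import java.util.BitSet;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

// Simplified plan node carrying mutable state, standing in for Plan.mutableState.
class SketchPlan {
    private final Map<String, Object> mutableState = new HashMap<>();

    @SuppressWarnings("unchecked")
    <T> T getOrInitMutableState(String key, Supplier<T> init) {
        return (T) mutableState.computeIfAbsent(key, k -> init.get());
    }
}

public class AppliedAwareSketch {
    private static final String APPLIED_RULES_KEY = "applied_rules";
    private static final int RULE_COUNT = 8; // stands in for RuleType.SENTINEL.ordinal()

    // Like AppliedAwareRule.acceptPlan(): mark the rule's ordinal on the plan.
    static void markApplied(SketchPlan plan, int ruleOrdinal) {
        BitSet applied = plan.getOrInitMutableState(APPLIED_RULES_KEY, () -> new BitSet(RULE_COUNT));
        applied.set(ruleOrdinal);
    }

    // Like AppliedAwareRuleCondition.isAppliedRule(): read the same BitSet back.
    static boolean isApplied(SketchPlan plan, int ruleOrdinal) {
        BitSet applied = plan.getOrInitMutableState(APPLIED_RULES_KEY, () -> new BitSet(RULE_COUNT));
        return applied.get(ruleOrdinal);
    }

    public static void main(String[] args) {
        SketchPlan plan = new SketchPlan();
        System.out.println(isApplied(plan, 3)); // false: rule 3 not applied yet
        markApplied(plan, 3);
        System.out.println(isApplied(plan, 3)); // true: framework can skip re-applying rule 3
    }
}
```

Keeping the `BitSet` in mutable state avoids rebuilding immutable plan nodes with `withXxx` just to record "already applied", which is the efficiency trade-off the PR description argues for, at the cost of asking rules to go through the framework (as `AppliedAwareRuleCondition` does) rather than touching the state directly.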


@ -0,0 +1,45 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.TransformException;
import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.trees.plans.Plan;
import java.util.List;
import java.util.Objects;
/** ProxyRule */
public class ProxyRule extends Rule {
protected final Rule rule;
public ProxyRule(Rule rule) {
this(rule, rule.getRuleType(), rule.getPattern(), rule.getRulePromise());
}
public ProxyRule(Rule rule, RuleType ruleType, Pattern pattern, RulePromise rulePromise) {
super(ruleType, pattern, rulePromise);
this.rule = Objects.requireNonNull(rule, "rule cannot be null");
}
@Override
public List<Plan> transform(Plan node, CascadesContext context) throws TransformException {
return rule.transform(node, context);
}
}


@ -68,4 +68,9 @@ public abstract class Rule {
}
public abstract List<Plan> transform(Plan node, CascadesContext context) throws TransformException;
/** called back when the traverse framework accepts a new plan produced by this rule */
public void acceptPlan(Plan plan) {
}
}


@ -17,14 +17,14 @@
package org.apache.doris.nereids.rules;
import org.apache.doris.nereids.pattern.GeneratedPatterns;
import org.apache.doris.nereids.pattern.Patterns;
import java.util.List;
/**
* interface for all rule factories for build some rules.
*/
public interface RuleFactory extends GeneratedPatterns {
public interface RuleFactory extends Patterns {
// need implement
List<Rule> buildRules();


@ -93,7 +93,6 @@ public enum RuleType {
CHECK_AND_STANDARDIZE_WINDOW_FUNCTION_AND_FRAME(RuleTypeClass.REWRITE),
AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE),
DISTINCT_AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE),
MARK_NECESSARY_PROJECT(RuleTypeClass.REWRITE),
LOGICAL_SUB_QUERY_ALIAS_TO_LOGICAL_PROJECT(RuleTypeClass.REWRITE),
ELIMINATE_GROUP_BY_CONSTANT(RuleTypeClass.REWRITE),
ELIMINATE_ORDER_BY_CONSTANT(RuleTypeClass.REWRITE),
@ -101,16 +100,17 @@ public enum RuleType {
INFER_FILTER_NOT_NULL(RuleTypeClass.REWRITE),
INFER_JOIN_NOT_NULL(RuleTypeClass.REWRITE),
// subquery analyze
ANALYZE_FILTER_SUBQUERY(RuleTypeClass.REWRITE),
ANALYZE_PROJECT_SUBQUERY(RuleTypeClass.REWRITE),
FILTER_SUBQUERY_TO_APPLY(RuleTypeClass.REWRITE),
PROJECT_SUBQUERY_TO_APPLY(RuleTypeClass.REWRITE),
// subquery rewrite rule
ELIMINATE_LIMIT_UNDER_APPLY(RuleTypeClass.REWRITE),
ELIMINATE_SORT_UNDER_APPLY(RuleTypeClass.REWRITE),
PUSH_APPLY_UNDER_PROJECT(RuleTypeClass.REWRITE),
PUSH_APPLY_UNDER_FILTER(RuleTypeClass.REWRITE),
ELIMINATE_FILTER_UNDER_APPLY_PROJECT(RuleTypeClass.REWRITE),
APPLY_PULL_FILTER_ON_AGG(RuleTypeClass.REWRITE),
APPLY_PULL_FILTER_ON_PROJECT_UNDER_AGG(RuleTypeClass.REWRITE),
ELIMINATE_SORT_UNDER_APPLY_PROJECT(RuleTypeClass.REWRITE),
PULL_UP_PROJECT_UNDER_APPLY(RuleTypeClass.REWRITE),
UN_CORRELATED_APPLY_FILTER(RuleTypeClass.REWRITE),
UN_CORRELATED_APPLY_PROJECT_FILTER(RuleTypeClass.REWRITE),
UN_CORRELATED_APPLY_AGGREGATE_FILTER(RuleTypeClass.REWRITE),
PULL_UP_CORRELATED_FILTER_UNDER_APPLY_AGGREGATE_PROJECT(RuleTypeClass.REWRITE),
SCALAR_APPLY_TO_JOIN(RuleTypeClass.REWRITE),
IN_APPLY_TO_JOIN(RuleTypeClass.REWRITE),
EXISTS_APPLY_TO_JOIN(RuleTypeClass.REWRITE),
@ -124,6 +124,7 @@ public enum RuleType {
PUSHDOWN_FILTER_THROUGH_LEFT_SEMI_JOIN(RuleTypeClass.REWRITE),
PUSH_FILTER_INSIDE_JOIN(RuleTypeClass.REWRITE),
PUSHDOWN_FILTER_THROUGH_PROJECT(RuleTypeClass.REWRITE),
PUSHDOWN_FILTER_THROUGH_PROJECT_UNDER_LIMIT(RuleTypeClass.REWRITE),
PUSHDOWN_PROJECT_THROUGH_LIMIT(RuleTypeClass.REWRITE),
PUSHDOWN_FILTER_THROUGH_SET_OPERATION(RuleTypeClass.REWRITE),
// column prune rules,


@ -17,13 +17,14 @@
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.pattern.GeneratedPlanPatterns;
import org.apache.doris.nereids.rules.PlanRuleFactory;
import org.apache.doris.nereids.rules.RulePromise;
/**
* interface for all rule factories used in analysis stage.
*/
public interface AnalysisRuleFactory extends PlanRuleFactory {
public interface AnalysisRuleFactory extends PlanRuleFactory, GeneratedPlanPatterns {
@Override
default RulePromise defaultPromise() {
return RulePromise.ANALYSIS;


@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@ -40,7 +41,7 @@ import java.util.Map;
*
* change avg( distinct a ) into sum( distinct a ) / count( distinct a ) if there is more than one distinct argument
*/
public class AvgDistinctToSumDivCount extends OneAnalysisRuleFactory {
public class AvgDistinctToSumDivCount extends OneRewriteRuleFactory {
@Override
public Rule build() {
return RuleType.AVG_DISTINCT_TO_SUM_DIV_COUNT.build(


@ -28,6 +28,7 @@ import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.analyzer.UnboundTVFRelation;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.AppliedAwareRule.AppliedAwareRuleCondition;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
@ -47,9 +48,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.LeafPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
@ -61,7 +60,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalIntersect;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
@ -80,7 +78,6 @@ import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
@ -93,13 +90,8 @@ import java.util.stream.Stream;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class BindExpression implements AnalysisRuleFactory {
private final Optional<Scope> outerScope;
public BindExpression(Optional<Scope> outerScope) {
this.outerScope = Objects.requireNonNull(outerScope, "outerScope cannot be null");
}
private Scope toScope(List<Slot> slots) {
private Scope toScope(CascadesContext cascadesContext, List<Slot> slots) {
Optional<Scope> outerScope = cascadesContext.getOuterScope();
if (outerScope.isPresent()) {
return new Scope(outerScope, slots, outerScope.get().getSubquery());
} else {
@ -109,10 +101,29 @@ public class BindExpression implements AnalysisRuleFactory {
@Override
public List<Rule> buildRules() {
/*
* some rules depend on more than the condition Plan::canBind. for example,
* BINDING_FILTER_SLOT needs to transform 'filter(unix_timestamp() > 100)' into
* 'filter(unix_timestamp() > cast(100 as int))', but there is no unbound expression
* in the filter, so Plan::canBind returns false.
*
* we need `isAppliedRule` to judge whether a rule has been applied to a plan, so we convert
* the normal rule to an `AppliedAwareRule` to read and write the mutable state.
*/
AppliedAwareRuleCondition ruleCondition = new AppliedAwareRuleCondition() {
@Override
protected boolean condition(Rule rule, Plan plan) {
if (!rule.getPattern().matchRoot(plan)) {
return false;
}
return plan.canBind() || (plan.bound() && !isAppliedRule(rule, plan));
}
};
return ImmutableList.of(
RuleType.BINDING_PROJECT_SLOT.build(
logicalProject().when(Plan::canBind).thenApply(ctx -> {
LogicalProject<GroupPlan> project = ctx.root;
logicalProject().thenApply(ctx -> {
LogicalProject<Plan> project = ctx.root;
List<NamedExpression> boundProjections =
bindSlot(project.getProjects(), project.children(), ctx.cascadesContext);
List<NamedExpression> boundExceptions = bindSlot(project.getExcepts(), project.children(),
@ -125,8 +136,8 @@ public class BindExpression implements AnalysisRuleFactory {
})
),
RuleType.BINDING_FILTER_SLOT.build(
logicalFilter().when(Plan::canBind).thenApply(ctx -> {
LogicalFilter<GroupPlan> filter = ctx.root;
logicalFilter().thenApply(ctx -> {
LogicalFilter<Plan> filter = ctx.root;
Set<Expression> boundConjuncts = filter.getConjuncts().stream()
.map(expr -> bindSlot(expr, filter.children(), ctx.cascadesContext))
.map(expr -> bindFunction(expr, ctx.cascadesContext))
@ -137,7 +148,7 @@ public class BindExpression implements AnalysisRuleFactory {
RuleType.BINDING_USING_JOIN_SLOT.build(
usingJoin().thenApply(ctx -> {
UsingJoin<GroupPlan, GroupPlan> using = ctx.root;
UsingJoin<Plan, Plan> using = ctx.root;
LogicalJoin<Plan, Plan> lj = new LogicalJoin<>(using.getJoinType() == JoinType.CROSS_JOIN
? JoinType.INNER_JOIN : using.getJoinType(),
using.getHashJoinConjuncts(),
@ -151,7 +162,7 @@ public class BindExpression implements AnalysisRuleFactory {
// the most right slot is matched with priority.
Collections.reverse(leftOutput);
List<Expression> leftSlots = new ArrayList<>();
Scope scope = toScope(leftOutput.stream()
Scope scope = toScope(ctx.cascadesContext, leftOutput.stream()
.filter(s -> !slotNames.contains(s.getName()))
.peek(s -> slotNames.add(s.getName()))
.collect(Collectors.toList()));
@ -160,7 +171,7 @@ public class BindExpression implements AnalysisRuleFactory {
leftSlots.add(expression);
}
slotNames.clear();
scope = toScope(lj.right().getOutput().stream()
scope = toScope(ctx.cascadesContext, lj.right().getOutput().stream()
.filter(s -> !slotNames.contains(s.getName()))
.peek(s -> slotNames.add(s.getName()))
.collect(Collectors.toList()));
@ -178,8 +189,8 @@ public class BindExpression implements AnalysisRuleFactory {
})
),
RuleType.BINDING_JOIN_SLOT.build(
logicalJoin().when(Plan::canBind).thenApply(ctx -> {
LogicalJoin<GroupPlan, GroupPlan> join = ctx.root;
logicalJoin().thenApply(ctx -> {
LogicalJoin<Plan, Plan> join = ctx.root;
List<Expression> cond = join.getOtherJoinConjuncts().stream()
.map(expr -> bindSlot(expr, join.children(), ctx.cascadesContext))
.map(expr -> bindFunction(expr, ctx.cascadesContext))
@ -193,8 +204,8 @@ public class BindExpression implements AnalysisRuleFactory {
})
),
RuleType.BINDING_AGGREGATE_SLOT.build(
logicalAggregate().when(Plan::canBind).thenApply(ctx -> {
LogicalAggregate<GroupPlan> agg = ctx.root;
logicalAggregate().thenApply(ctx -> {
LogicalAggregate<Plan> agg = ctx.root;
List<NamedExpression> output = agg.getOutputExpressions().stream()
.map(expr -> bindSlot(expr, agg.children(), ctx.cascadesContext))
.map(expr -> bindFunction(expr, ctx.cascadesContext))
@ -291,9 +302,10 @@ public class BindExpression implements AnalysisRuleFactory {
boundSlots.addAll(outputSlots);
SlotBinder binder = new SlotBinder(
toScope(Lists.newArrayList(boundSlots)), ctx.cascadesContext);
toScope(ctx.cascadesContext, ImmutableList.copyOf(boundSlots)), ctx.cascadesContext);
SlotBinder childBinder = new SlotBinder(
toScope(new ArrayList<>(agg.child().getOutputSet())), ctx.cascadesContext);
toScope(ctx.cascadesContext, ImmutableList.copyOf(agg.child().getOutputSet())),
ctx.cascadesContext);
List<Expression> groupBy = replacedGroupBy.stream()
.map(expression -> {
@@ -323,8 +335,8 @@ public class BindExpression implements AnalysisRuleFactory {
})
),
RuleType.BINDING_REPEAT_SLOT.build(
logicalRepeat().when(Plan::canBind).thenApply(ctx -> {
LogicalRepeat<GroupPlan> repeat = ctx.root;
logicalRepeat().thenApply(ctx -> {
LogicalRepeat<Plan> repeat = ctx.root;
List<NamedExpression> output = repeat.getOutputExpressions().stream()
.map(expr -> bindSlot(expr, repeat.children(), ctx.cascadesContext))
.map(expr -> bindFunction(expr, ctx.cascadesContext))
@@ -368,35 +380,35 @@ public class BindExpression implements AnalysisRuleFactory {
})
),
RuleType.BINDING_SORT_SLOT.build(
logicalSort(aggregate()).when(Plan::canBind).thenApply(ctx -> {
LogicalSort<Aggregate<GroupPlan>> sort = ctx.root;
Aggregate<GroupPlan> aggregate = sort.child();
logicalSort(aggregate()).thenApply(ctx -> {
LogicalSort<Aggregate<Plan>> sort = ctx.root;
Aggregate<Plan> aggregate = sort.child();
return bindSort(sort, aggregate, ctx.cascadesContext);
})
),
RuleType.BINDING_SORT_SLOT.build(
logicalSort(logicalHaving(aggregate())).when(Plan::canBind).thenApply(ctx -> {
LogicalSort<LogicalHaving<Aggregate<GroupPlan>>> sort = ctx.root;
Aggregate<GroupPlan> aggregate = sort.child().child();
logicalSort(logicalHaving(aggregate())).thenApply(ctx -> {
LogicalSort<LogicalHaving<Aggregate<Plan>>> sort = ctx.root;
Aggregate<Plan> aggregate = sort.child().child();
return bindSort(sort, aggregate, ctx.cascadesContext);
})
),
RuleType.BINDING_SORT_SLOT.build(
logicalSort(logicalHaving(logicalProject())).when(Plan::canBind).thenApply(ctx -> {
LogicalSort<LogicalHaving<LogicalProject<GroupPlan>>> sort = ctx.root;
LogicalProject<GroupPlan> project = sort.child().child();
logicalSort(logicalHaving(logicalProject())).thenApply(ctx -> {
LogicalSort<LogicalHaving<LogicalProject<Plan>>> sort = ctx.root;
LogicalProject<Plan> project = sort.child().child();
return bindSort(sort, project, ctx.cascadesContext);
})
),
RuleType.BINDING_SORT_SLOT.build(
logicalSort(logicalProject()).when(Plan::canBind).thenApply(ctx -> {
LogicalSort<LogicalProject<GroupPlan>> sort = ctx.root;
LogicalProject<GroupPlan> project = sort.child();
logicalSort(logicalProject()).thenApply(ctx -> {
LogicalSort<LogicalProject<Plan>> sort = ctx.root;
LogicalProject<Plan> project = sort.child();
return bindSort(sort, project, ctx.cascadesContext);
})
),
RuleType.BINDING_SORT_SET_OPERATION_SLOT.build(
logicalSort(logicalSetOperation()).when(Plan::canBind).thenApply(ctx -> {
logicalSort(logicalSetOperation()).thenApply(ctx -> {
LogicalSort<LogicalSetOperation> sort = ctx.root;
List<OrderKey> sortItemList = sort.getOrderKeys()
.stream()
@@ -409,8 +421,8 @@ public class BindExpression implements AnalysisRuleFactory {
})
),
RuleType.BINDING_HAVING_SLOT.build(
logicalHaving(aggregate()).when(Plan::canBind).thenApply(ctx -> {
LogicalHaving<Aggregate<GroupPlan>> having = ctx.root;
logicalHaving(aggregate()).thenApply(ctx -> {
LogicalHaving<Aggregate<Plan>> having = ctx.root;
Plan childPlan = having.child();
Set<Expression> boundConjuncts = having.getConjuncts().stream()
.map(expr -> {
@@ -423,7 +435,7 @@ public class BindExpression implements AnalysisRuleFactory {
})
),
RuleType.BINDING_HAVING_SLOT.build(
logicalHaving(any()).when(Plan::canBind).thenApply(ctx -> {
logicalHaving(any()).thenApply(ctx -> {
LogicalHaving<Plan> having = ctx.root;
Plan childPlan = having.child();
Set<Expression> boundConjuncts = having.getConjuncts().stream()
@@ -449,7 +461,7 @@ public class BindExpression implements AnalysisRuleFactory {
})
),
RuleType.BINDING_SET_OPERATION_SLOT.build(
logicalSetOperation().when(Plan::canBind).then(setOperation -> {
logicalSetOperation().then(setOperation -> {
// check whether the left and right child output columns are the same
if (setOperation.child(0).getOutput().size() != setOperation.child(1).getOutput().size()) {
throw new AnalysisException("Operands have unequal number of columns:\n"
@@ -472,8 +484,8 @@ public class BindExpression implements AnalysisRuleFactory {
})
),
RuleType.BINDING_GENERATE_SLOT.build(
logicalGenerate().when(Plan::canBind).thenApply(ctx -> {
LogicalGenerate<GroupPlan> generate = ctx.root;
logicalGenerate().thenApply(ctx -> {
LogicalGenerate<Plan> generate = ctx.root;
List<Function> boundSlotGenerators
= bindSlot(generate.getGenerators(), generate.children(), ctx.cascadesContext);
List<Function> boundFunctionGenerators = boundSlotGenerators.stream()
@@ -497,16 +509,8 @@ public class BindExpression implements AnalysisRuleFactory {
UnboundTVFRelation relation = ctx.root;
return bindTableValuedFunction(relation, ctx.statementContext);
})
),
// when child update, we need update current plan's logical properties,
// since we use cache to avoid compute more than once.
RuleType.BINDING_NON_LEAF_LOGICAL_PLAN.build(
logicalPlan()
.when(plan -> plan.canBind() && !(plan instanceof LeafPlan))
.then(LogicalPlan::recomputeLogicalProperties)
)
);
).stream().map(ruleCondition).collect(ImmutableList.toImmutableList());
}
private Plan bindSort(LogicalSort<? extends Plan> sort, Plan plan, CascadesContext ctx) {
@@ -528,8 +532,8 @@ public class BindExpression implements AnalysisRuleFactory {
List<OrderKey> sortItemList = sort.getOrderKeys()
.stream()
.map(orderKey -> {
Expression item = bindSlot(orderKey.getExpr(), plan, ctx);
item = bindSlot(item, plan.children(), ctx);
Expression item = bindSlot(orderKey.getExpr(), plan, ctx, true, false);
item = bindSlot(item, plan.children(), ctx, true, false);
item = bindFunction(item, ctx);
return new OrderKey(item, orderKey.isAsc(), orderKey.isNullFirst());
}).collect(Collectors.toList());
@@ -561,21 +565,33 @@ public class BindExpression implements AnalysisRuleFactory {
@SuppressWarnings("unchecked")
private <E extends Expression> E bindSlot(E expr, Plan input, CascadesContext cascadesContext) {
return bindSlot(expr, input, cascadesContext, true);
return bindSlot(expr, input, cascadesContext, true, true);
}
private <E extends Expression> E bindSlot(E expr, Plan input, CascadesContext cascadesContext,
boolean enableExactMatch) {
return (E) new SlotBinder(toScope(input.getOutput()), cascadesContext, enableExactMatch).bind(expr);
return bindSlot(expr, input, cascadesContext, enableExactMatch, true);
}
private <E extends Expression> E bindSlot(E expr, Plan input, CascadesContext cascadesContext,
boolean enableExactMatch, boolean bindSlotInOuterScope) {
return (E) new SlotBinder(toScope(cascadesContext, input.getOutput()), cascadesContext,
enableExactMatch, bindSlotInOuterScope).bind(expr);
}
@SuppressWarnings("unchecked")
private <E extends Expression> E bindSlot(E expr, List<Plan> inputs, CascadesContext cascadesContext,
boolean enableExactMatch) {
return bindSlot(expr, inputs, cascadesContext, enableExactMatch, true);
}
private <E extends Expression> E bindSlot(E expr, List<Plan> inputs, CascadesContext cascadesContext,
boolean enableExactMatch, boolean bindSlotInOuterScope) {
List<Slot> boundedSlots = inputs.stream()
.flatMap(input -> input.getOutput().stream())
.collect(Collectors.toList());
return (E) new SlotBinder(toScope(boundedSlots), cascadesContext, enableExactMatch).bind(expr);
return (E) new SlotBinder(toScope(cascadesContext, boundedSlots), cascadesContext,
enableExactMatch, bindSlotInOuterScope).bind(expr);
}
@SuppressWarnings("unchecked")
@@ -648,4 +664,8 @@ public class BindExpression implements AnalysisRuleFactory {
function = (BoundFunction) TypeCoercion.INSTANCE.rewrite(function, null);
return function;
}
public boolean canBind(Plan plan) {
return !plan.hasUnboundExpression() || plan.canBind();
}
}
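The `.map(ruleCondition)` call in this file replaces the per-rule `.when(Plan::canBind)` guards with a shared condition applied to every rule, which pairs with the PR's `AppliedAwareRuleCondition` and the new `Plan.mutableState`. A minimal, self-contained sketch of that idea; `MiniPlan`, `MiniRule`, and `AppliedAwareCondition` are illustrative stand-ins, not the real Doris types:

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch: framework-managed mutable state on the plan node records
// whether a rule already fired, so repeated bottom-up traversals can skip it
// cheaply instead of re-matching the pattern or copying the plan.
class MiniPlan {
    final String name;
    // mutable state, owned by the rewrite framework (rules should not touch it)
    final Map<String, Boolean> mutableState = new HashMap<>();

    MiniPlan(String name) {
        this.name = name;
    }
}

interface MiniRule {
    String name();

    MiniPlan apply(MiniPlan plan);
}

class AppliedAwareCondition {
    // wrap a rule so it fires at most once per plan node
    static MiniRule wrap(MiniRule rule) {
        return new MiniRule() {
            @Override
            public String name() {
                return rule.name();
            }

            @Override
            public MiniPlan apply(MiniPlan plan) {
                if (Boolean.TRUE.equals(plan.mutableState.get(rule.name()))) {
                    return plan; // already applied: skip without rewriting
                }
                plan.mutableState.put(rule.name(), true);
                return rule.apply(plan);
            }
        };
    }
}
```

The trade-off, as the PR description notes, is that mutable state is efficient but should stay wrapped inside the framework rather than being read directly by rules.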


@@ -30,10 +30,13 @@ import org.apache.doris.common.util.Util;
import org.apache.doris.datasource.CatalogIf;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.CTEContext;
import org.apache.doris.nereids.analyzer.Unbound;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.pattern.MatchingContext;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.EqualTo;
@@ -70,26 +73,37 @@ public class BindRelation extends OneAnalysisRuleFactory {
@Override
public Rule build() {
return unboundRelation().thenApply(ctx -> {
List<String> nameParts = ctx.root.getNameParts();
switch (nameParts.size()) {
case 1: { // table
// Use current database name from catalog.
return bindWithCurrentDb(ctx.cascadesContext, ctx.root);
}
case 2: { // db.table
// Use database name from table name parts.
return bindWithDbNameFromNamePart(ctx.cascadesContext, ctx.root);
}
case 3: { // catalog.db.table
// Use catalog and database name from name parts.
return bindWithCatalogNameFromNamePart(ctx.cascadesContext, ctx.root);
}
default:
throw new IllegalStateException("Table name [" + ctx.root.getTableName() + "] is invalid.");
Plan plan = doBindRelation(ctx);
if (!(plan instanceof Unbound)) {
// init output and allocate slot ids immediately, so that slot ids increase
// in the order in which the tables appear.
LogicalProperties logicalProperties = plan.getLogicalProperties();
logicalProperties.getOutput();
}
return plan;
}).toRule(RuleType.BINDING_RELATION);
}
private Plan doBindRelation(MatchingContext<UnboundRelation> ctx) {
List<String> nameParts = ctx.root.getNameParts();
switch (nameParts.size()) {
case 1: { // table
// Use current database name from catalog.
return bindWithCurrentDb(ctx.cascadesContext, ctx.root);
}
case 2: { // db.table
// Use database name from table name parts.
return bindWithDbNameFromNamePart(ctx.cascadesContext, ctx.root);
}
case 3: { // catalog.db.table
// Use catalog and database name from name parts.
return bindWithCatalogNameFromNamePart(ctx.cascadesContext, ctx.root);
}
default:
throw new IllegalStateException("Table name [" + ctx.root.getTableName() + "] is invalid.");
}
}
private TableIf getTable(String catalogName, String dbName, String tableName, Env env) {
CatalogIf catalog = env.getCatalogMgr().getCatalog(catalogName);
if (catalog == null) {
@@ -209,12 +223,12 @@ public class BindRelation extends OneAnalysisRuleFactory {
private Plan parseAndAnalyzeView(String viewSql, CascadesContext parentContext) {
LogicalPlan parsedViewPlan = new NereidsParser().parseSingle(viewSql);
CascadesContext viewContext = new Memo(parsedViewPlan)
.newCascadesContext(parentContext.getStatementContext());
CascadesContext viewContext = CascadesContext.newRewriteContext(
parentContext.getStatementContext(), parsedViewPlan, PhysicalProperties.ANY);
viewContext.newAnalyzer().analyze();
// we should remove all group expression of the plan which in other memo, so the groupId would not conflict
return viewContext.getMemo().copyOut(false);
return viewContext.getRewritePlan();
}
private List<Long> getPartitionIds(TableIf t, UnboundRelation unboundRelation) {
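The comment in `build()` above explains why `BindRelation` now touches `getLogicalProperties().getOutput()` eagerly: slot ids should follow the order tables appear in the query. A toy sketch of why laziness matters here; `LazySlot` and its counter are hypothetical, not Doris code:

```java
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical sketch: slot ids come from a shared counter. If the id is only
// allocated on first access, evaluation order (not table-appearance order)
// decides the numbering; forcing output computation at bind time pins the ids
// to the order tables appear.
class LazySlot {
    static final AtomicInteger NEXT_ID = new AtomicInteger(0);

    private Integer id;

    // id is allocated on first access (lazy)
    int id() {
        if (id == null) {
            id = NEXT_ID.getAndIncrement();
        }
        return id;
    }
}
```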


@@ -17,9 +17,11 @@
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
@ -39,15 +41,13 @@ public class CheckPolicy implements AnalysisRuleFactory {
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.CHECK_ROW_POLICY.build(
logicalCheckPolicy(logicalSubQueryAlias()).then(checkPolicy -> checkPolicy.child())
),
RuleType.CHECK_ROW_POLICY.build(
logicalCheckPolicy(logicalFilter()).then(checkPolicy -> checkPolicy.child())
),
RuleType.CHECK_ROW_POLICY.build(
logicalCheckPolicy(logicalRelation()).thenApply(ctx -> {
LogicalCheckPolicy<LogicalRelation> checkPolicy = ctx.root;
LogicalRelation relation = checkPolicy.child();
logicalCheckPolicy(any().when(child -> !(child instanceof UnboundRelation))).thenApply(ctx -> {
LogicalCheckPolicy<Plan> checkPolicy = ctx.root;
Plan child = checkPolicy.child();
if (!(child instanceof LogicalRelation)) {
return child;
}
LogicalRelation relation = (LogicalRelation) child;
Optional<Expression> filter = checkPolicy.getFilter(relation, ctx.connectContext);
if (!filter.isPresent()) {
return relation;


@@ -20,11 +20,9 @@ package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.CTEContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTE;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
@@ -48,7 +46,7 @@ public class RegisterCTE extends OneAnalysisRuleFactory {
@Override
public Rule build() {
return logicalCTE().thenApply(ctx -> {
LogicalCTE<GroupPlan> logicalCTE = ctx.root;
LogicalCTE<Plan> logicalCTE = ctx.root;
register(logicalCTE.getAliasQueries(), ctx.cascadesContext);
return logicalCTE.child();
}).toRule(RuleType.REGISTER_CTE);
@@ -69,10 +67,10 @@ public class RegisterCTE extends OneAnalysisRuleFactory {
CTEContext localCteContext = cteCtx;
Function<Plan, LogicalPlan> analyzeCte = parsePlan -> {
CascadesContext localCascadesContext = new Memo(parsePlan)
.newCascadesContext(cascadesContext.getStatementContext(), localCteContext);
CascadesContext localCascadesContext = CascadesContext.newRewriteContext(
cascadesContext.getStatementContext(), parsePlan, localCteContext);
localCascadesContext.newAnalyzer().analyze();
return (LogicalPlan) localCascadesContext.getMemo().copyOut(false);
return (LogicalPlan) localCascadesContext.getRewritePlan();
};
LogicalPlan analyzedCteBody = analyzeCte.apply(aliasQuery.child());


@@ -23,7 +23,7 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
@@ -46,7 +46,7 @@ public class ReplaceExpressionByChildOutput implements AnalysisRuleFactory {
return ImmutableList.<Rule>builder()
.add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build(
logicalSort(logicalProject()).then(sort -> {
LogicalProject<GroupPlan> project = sort.child();
LogicalProject<Plan> project = sort.child();
Map<Expression, Slot> sMap = Maps.newHashMap();
project.getProjects().stream()
.filter(Alias.class::isInstance)
@@ -57,7 +57,7 @@ public class ReplaceExpressionByChildOutput implements AnalysisRuleFactory {
))
.add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build(
logicalSort(logicalAggregate()).then(sort -> {
LogicalAggregate<GroupPlan> aggregate = sort.child();
LogicalAggregate<Plan> aggregate = sort.child();
Map<Expression, Slot> sMap = Maps.newHashMap();
aggregate.getOutputExpressions().stream()
.filter(Alias.class::isInstance)
@@ -67,7 +67,7 @@ public class ReplaceExpressionByChildOutput implements AnalysisRuleFactory {
})
)).add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build(
logicalSort(logicalHaving(logicalAggregate())).then(sort -> {
LogicalAggregate<GroupPlan> aggregate = sort.child().child();
LogicalAggregate<Plan> aggregate = sort.child().child();
Map<Expression, Slot> sMap = Maps.newHashMap();
aggregate.getOutputExpressions().stream()
.filter(Alias.class::isInstance)


@@ -20,13 +20,16 @@ package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rewrite.rules.FoldConstantRule;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import com.google.common.collect.ImmutableList;
@@ -43,13 +46,15 @@ public class ResolveOrdinalInOrderByAndGroupBy implements AnalysisRuleFactory {
public List<Rule> buildRules() {
return ImmutableList.<Rule>builder()
.add(RuleType.RESOLVE_ORDINAL_IN_ORDER_BY.build(
logicalSort().then(sort -> {
logicalSort().thenApply(ctx -> {
LogicalSort<Plan> sort = ctx.root;
List<Slot> childOutput = sort.child().getOutput();
List<OrderKey> orderKeys = sort.getOrderKeys();
List<OrderKey> orderKeysWithoutOrd = new ArrayList<>();
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
for (OrderKey k : orderKeys) {
Expression expression = k.getExpr();
expression = FoldConstantRule.INSTANCE.rewrite(expression);
expression = FoldConstantRule.INSTANCE.rewrite(expression, context);
if (expression instanceof IntegerLikeLiteral) {
IntegerLikeLiteral i = (IntegerLikeLiteral) expression;
int ord = i.getIntValue();
@@ -64,12 +69,14 @@ public class ResolveOrdinalInOrderByAndGroupBy implements AnalysisRuleFactory {
})
))
.add(RuleType.RESOLVE_ORDINAL_IN_GROUP_BY.build(
logicalAggregate().whenNot(agg -> agg.isOrdinalIsResolved()).then(agg -> {
logicalAggregate().whenNot(agg -> agg.isOrdinalIsResolved()).thenApply(ctx -> {
LogicalAggregate<Plan> agg = ctx.root;
List<NamedExpression> aggOutput = agg.getOutputExpressions();
List<Expression> groupByWithoutOrd = new ArrayList<>();
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
boolean ordExists = false;
for (Expression groupByExpr : agg.getGroupByExpressions()) {
groupByExpr = FoldConstantRule.INSTANCE.rewrite(groupByExpr);
groupByExpr = FoldConstantRule.INSTANCE.rewrite(groupByExpr, context);
if (groupByExpr instanceof IntegerLikeLiteral) {
IntegerLikeLiteral i = (IntegerLikeLiteral) groupByExpr;
int ord = i.getIntValue();


@@ -47,15 +47,18 @@ class SlotBinder extends SubExprAnalyzer {
but enabled for order by clause
TODO: after removing the original planner, always enable exact match mode.
*/
private boolean enableExactMatch = true;
private boolean enableExactMatch;
private final boolean bindSlotInOuterScope;
public SlotBinder(Scope scope, CascadesContext cascadesContext) {
this(scope, cascadesContext, true);
this(scope, cascadesContext, true, true);
}
public SlotBinder(Scope scope, CascadesContext cascadesContext, boolean enableExactMatch) {
public SlotBinder(Scope scope, CascadesContext cascadesContext,
boolean enableExactMatch, boolean bindSlotInOuterScope) {
super(scope, cascadesContext);
this.enableExactMatch = enableExactMatch;
this.bindSlotInOuterScope = bindSlotInOuterScope;
}
public Expression bind(Expression expression) {
@@ -81,7 +84,7 @@ class SlotBinder extends SubExprAnalyzer {
Optional<List<Slot>> boundedOpt = Optional.of(bindSlot(unboundSlot, getScope().getSlots()));
boolean foundInThisScope = !boundedOpt.get().isEmpty();
// Currently only looking for symbols on the previous level.
if (!foundInThisScope && getScope().getOuterScope().isPresent()) {
if (bindSlotInOuterScope && !foundInThisScope && getScope().getOuterScope().isPresent()) {
boundedOpt = Optional.of(bindSlot(unboundSlot,
getScope()
.getOuterScope()
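The new `bindSlotInOuterScope` flag lets callers (for example, sort-key binding in `BindExpression`) disable fallback to the outer scope. A self-contained sketch of that lookup policy; `MiniScope` is a stand-in, not the real `Scope` class:

```java
import java.util.List;
import java.util.Optional;

// Hypothetical sketch: a scope chain where name lookup can optionally fall
// back to the outer (parent) scope, e.g. for correlated subqueries. Passing
// bindSlotInOuterScope = false restricts resolution to the current scope.
class MiniScope {
    final List<String> slots;
    final Optional<MiniScope> outer;

    MiniScope(List<String> slots, Optional<MiniScope> outer) {
        this.slots = slots;
        this.outer = outer;
    }

    // resolve a name in this scope, then (optionally) in the outer scopes
    static Optional<String> bind(MiniScope scope, String name, boolean bindSlotInOuterScope) {
        if (scope.slots.contains(name)) {
            return Optional.of(name);
        }
        if (bindSlotInOuterScope && scope.outer.isPresent()) {
            return bind(scope.outer.get(), name, true);
        }
        return Optional.empty();
    }
}
```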


@@ -20,7 +20,6 @@ package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.trees.expressions.Exists;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
@@ -38,6 +37,7 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
@@ -140,13 +140,12 @@ class SubExprAnalyzer extends DefaultExpressionRewriter<CascadesContext> {
}
private AnalyzedResult analyzeSubquery(SubqueryExpr expr) {
CascadesContext subqueryContext = new Memo(expr.getQueryPlan())
.newCascadesContext((cascadesContext.getStatementContext()), cascadesContext.getCteContext());
CascadesContext subqueryContext = CascadesContext.newRewriteContext(
cascadesContext.getStatementContext(), expr.getQueryPlan(), cascadesContext.getCteContext());
Scope subqueryScope = genScopeWithSubquery(expr);
subqueryContext
.newAnalyzer(Optional.of(subqueryScope))
.analyze();
return new AnalyzedResult((LogicalPlan) subqueryContext.getMemo().copyOut(false),
subqueryContext.setOuterScope(subqueryScope);
subqueryContext.newAnalyzer().analyze();
return new AnalyzedResult((LogicalPlan) subqueryContext.getRewritePlan(),
subqueryScope.getCorrelatedSlots());
}
@@ -168,7 +167,7 @@ class SubExprAnalyzer extends DefaultExpressionRewriter<CascadesContext> {
private final LogicalPlan logicalPlan;
private final List<Slot> correlatedSlots;
public AnalyzedResult(LogicalPlan logicalPlan, List<Slot> correlatedSlots) {
public AnalyzedResult(LogicalPlan logicalPlan, Collection<Slot> correlatedSlots) {
this.logicalPlan = Objects.requireNonNull(logicalPlan, "logicalPlan can not be null");
this.correlatedSlots = correlatedSlots == null ? new ArrayList<>() : ImmutableList.copyOf(correlatedSlots);
}


@@ -30,7 +30,7 @@ import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
@@ -46,18 +46,18 @@ import java.util.Optional;
import java.util.Set;
/**
* AnalyzeSubquery. translate from subquery to LogicalApply.
* SubqueryToApply. translate from subquery to LogicalApply.
* In two steps
* The first step is to replace the predicate corresponding to the filter where the subquery is located.
* The second step converts the subquery into an apply node.
*/
public class AnalyzeSubquery implements AnalysisRuleFactory {
public class SubqueryToApply implements AnalysisRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.ANALYZE_FILTER_SUBQUERY.build(
RuleType.FILTER_SUBQUERY_TO_APPLY.build(
logicalFilter().thenApply(ctx -> {
LogicalFilter<GroupPlan> filter = ctx.root;
LogicalFilter<Plan> filter = ctx.root;
Set<SubqueryExpr> subqueryExprs = filter.getPredicate().collect(SubqueryExpr.class::isInstance);
if (subqueryExprs.isEmpty()) {
return filter;
@@ -66,14 +66,14 @@ public class AnalyzeSubquery implements AnalysisRuleFactory {
// first step: Replace the subquery of predicate in LogicalFilter
// second step: Replace subquery with LogicalApply
return new LogicalFilter<>(new ReplaceSubquery().replace(filter.getConjuncts()),
analyzedSubquery(
subqueryToApply(
subqueryExprs, filter.child(), ctx.cascadesContext
));
})
),
RuleType.ANALYZE_PROJECT_SUBQUERY.build(
RuleType.PROJECT_SUBQUERY_TO_APPLY.build(
logicalProject().thenApply(ctx -> {
LogicalProject<GroupPlan> project = ctx.root;
LogicalProject<Plan> project = ctx.root;
Set<SubqueryExpr> subqueryExprs = new HashSet<>();
project.getProjects().stream()
.filter(Alias.class::isInstance)
@@ -89,17 +89,17 @@ public class AnalyzeSubquery implements AnalysisRuleFactory {
return new LogicalProject(project.getProjects().stream()
.map(p -> p.withChildren(new ReplaceSubquery().replace(p)))
.collect(ImmutableList.toImmutableList()),
analyzedSubquery(
subqueryToApply(
subqueryExprs, project.child(), ctx.cascadesContext
));
})
)
)
);
}
private LogicalPlan analyzedSubquery(Set<SubqueryExpr> subqueryExprs,
LogicalPlan childPlan, CascadesContext ctx) {
LogicalPlan tmpPlan = childPlan;
private Plan subqueryToApply(Set<SubqueryExpr> subqueryExprs,
Plan childPlan, CascadesContext ctx) {
Plan tmpPlan = childPlan;
for (SubqueryExpr subqueryExpr : subqueryExprs) {
if (!ctx.subqueryIsAnalyzed(subqueryExpr)) {
tmpPlan = addApply(subqueryExpr, tmpPlan, ctx);
@@ -108,8 +108,7 @@ public class AnalyzeSubquery implements AnalysisRuleFactory {
return tmpPlan;
}
private LogicalPlan addApply(SubqueryExpr subquery,
LogicalPlan childPlan, CascadesContext ctx) {
private LogicalPlan addApply(SubqueryExpr subquery, Plan childPlan, CascadesContext ctx) {
ctx.setSubqueryExprIsAnalyzed(subquery, true);
LogicalApply newApply = new LogicalApply(
subquery.getCorrelateSlots(),


@@ -33,7 +33,8 @@ public class UserAuthentication extends OneAnalysisRuleFactory {
@Override
public Rule build() {
return logicalRelation().thenApply(ctx -> checkPermission(ctx.root, ctx.connectContext))
return logicalRelation()
.thenApply(ctx -> checkPermission(ctx.root, ctx.connectContext))
.toRule(RuleType.RELATION_AUTHENTICATION);
}
@@ -46,7 +47,6 @@ public class UserAuthentication extends OneAnalysisRuleFactory {
ConnectContext.get().getQualifiedUser(), ConnectContext.get().getRemoteIP(),
dbName + ": " + tableName);
throw new AnalysisException(message);
}
return relation;
}


@@ -17,13 +17,14 @@
package org.apache.doris.nereids.rules.exploration;
import org.apache.doris.nereids.pattern.GeneratedMemoPatterns;
import org.apache.doris.nereids.rules.PlanRuleFactory;
import org.apache.doris.nereids.rules.RulePromise;
/**
* interface for all exploration rule factories.
*/
public interface ExplorationRuleFactory extends PlanRuleFactory {
public interface ExplorationRuleFactory extends PlanRuleFactory, GeneratedMemoPatterns {
@Override
default RulePromise defaultPromise() {
return RulePromise.EXPLORE;


@@ -29,7 +29,7 @@ import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyCastRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyNotExprRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.SupportJavaDateFormatter;
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import com.google.common.collect.ImmutableList;
@@ -56,8 +56,13 @@ public class ExpressionNormalization extends ExpressionRewrite {
SupportJavaDateFormatter.INSTANCE
);
public ExpressionNormalization(ConnectContext context) {
super(new ExpressionRuleExecutor(NORMALIZE_REWRITE_RULES, context));
public ExpressionNormalization() {
super(new ExpressionRuleExecutor(NORMALIZE_REWRITE_RULES));
}
@Override
public Expression rewrite(Expression expression, ExpressionRewriteContext context) {
return super.rewrite(expression, context);
}
}
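Moving the `ConnectContext` out of the constructor and into `rewrite(expression, context)` is what lets `NORMALIZE_REWRITE_RULES` be a shared static list instead of being rebuilt per query. A small sketch of the pattern using string "expressions"; the rule bodies here are illustrative stand-ins, not the real `FoldConstantRule`/`SimplifyNotExprRule`:

```java
import java.util.List;
import java.util.function.BiFunction;

// Hypothetical sketch: stateless rules take the context as a call argument,
// so the rule list is built once and shared by every query.
class StatelessRewriter {
    // built once, reused everywhere (rule bodies are illustrative)
    static final List<BiFunction<String, String, String>> RULES = List.of(
            (expr, ctx) -> expr.replace("1 + 1", "2"),     // stand-in for constant folding
            (expr, ctx) -> expr.replace("NOT NOT ", ""));  // stand-in for NOT simplification

    static String rewrite(String expr, String context) {
        for (BiFunction<String, String, String> rule : RULES) {
            expr = rule.apply(expr, context);
        }
        return expr;
    }
}
```

By contrast, a rule holding the context in a field has to be re-instantiated for every query, which is part of the per-query overhead this PR removes.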


@@ -25,11 +25,16 @@ import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
@@ -56,8 +61,8 @@ public class ExpressionRewrite implements RewriteRuleFactory {
this.rewriter = Objects.requireNonNull(rewriter, "rewriter is null");
}
public Expression rewrite(Expression expression) {
return rewriter.rewrite(expression);
public Expression rewrite(Expression expression, ExpressionRewriteContext expressionRewriteContext) {
return rewriter.rewrite(expression, expressionRewriteContext);
}
@Override
@@ -77,10 +82,12 @@ public class ExpressionRewrite implements RewriteRuleFactory {
private class GenerateExpressionRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalGenerate().then(generate -> {
return logicalGenerate().thenApply(ctx -> {
LogicalGenerate<Plan> generate = ctx.root;
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
List<Function> generators = generate.getGenerators();
List<Function> newGenerators = generators.stream()
.map(func -> (Function) rewriter.rewrite(func))
.map(func -> (Function) rewriter.rewrite(func, context))
.collect(ImmutableList.toImmutableList());
if (generators.equals(newGenerators)) {
return generate;
@@ -93,11 +100,14 @@ public class ExpressionRewrite implements RewriteRuleFactory {
private class OneRowRelationExpressionRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalOneRowRelation().then(oneRowRelation -> {
return logicalOneRowRelation().thenApply(ctx -> {
LogicalOneRowRelation oneRowRelation = ctx.root;
List<NamedExpression> projects = oneRowRelation.getProjects();
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
List<NamedExpression> newProjects = projects
.stream()
.map(expr -> (NamedExpression) rewriter.rewrite(expr))
.map(expr -> (NamedExpression) rewriter.rewrite(expr, context))
.collect(ImmutableList.toImmutableList());
if (projects.equals(newProjects)) {
return oneRowRelation;
@@ -110,10 +120,13 @@ public class ExpressionRewrite implements RewriteRuleFactory {
private class ProjectExpressionRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalProject().then(project -> {
return logicalProject().thenApply(ctx -> {
LogicalProject<Plan> project = ctx.root;
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
List<NamedExpression> projects = project.getProjects();
List<NamedExpression> newProjects = projects.stream()
.map(expr -> (NamedExpression) rewriter.rewrite(expr)).collect(ImmutableList.toImmutableList());
.map(expr -> (NamedExpression) rewriter.rewrite(expr, context))
.collect(ImmutableList.toImmutableList());
if (projects.equals(newProjects)) {
return project;
}
@@ -125,9 +138,11 @@ public class ExpressionRewrite implements RewriteRuleFactory {
private class FilterExpressionRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalFilter().then(filter -> {
return logicalFilter().thenApply(ctx -> {
LogicalFilter<Plan> filter = ctx.root;
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
Set<Expression> newConjuncts = ImmutableSet.copyOf(ExpressionUtils.extractConjunction(
rewriter.rewrite(filter.getPredicate())));
rewriter.rewrite(filter.getPredicate(), context)));
if (newConjuncts.equals(filter.getConjuncts())) {
return filter;
}
@@ -139,13 +154,16 @@ public class ExpressionRewrite implements RewriteRuleFactory {
private class AggExpressionRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalAggregate().then(agg -> {
return logicalAggregate().thenApply(ctx -> {
LogicalAggregate<Plan> agg = ctx.root;
List<Expression> groupByExprs = agg.getGroupByExpressions();
List<Expression> newGroupByExprs = rewriter.rewrite(groupByExprs);
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
List<Expression> newGroupByExprs = rewriter.rewrite(groupByExprs, context);
List<NamedExpression> outputExpressions = agg.getOutputExpressions();
List<NamedExpression> newOutputExpressions = outputExpressions.stream()
.map(expr -> (NamedExpression) rewriter.rewrite(expr)).collect(ImmutableList.toImmutableList());
.map(expr -> (NamedExpression) rewriter.rewrite(expr, context))
.collect(ImmutableList.toImmutableList());
if (outputExpressions.equals(newOutputExpressions)) {
return agg;
}
@@ -158,16 +176,18 @@ public class ExpressionRewrite implements RewriteRuleFactory {
private class JoinExpressionRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalJoin().then(join -> {
return logicalJoin().thenApply(ctx -> {
LogicalJoin<Plan, Plan> join = ctx.root;
List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts();
List<Expression> otherJoinConjuncts = join.getOtherJoinConjuncts();
if (otherJoinConjuncts.isEmpty() && hashJoinConjuncts.isEmpty()) {
return join;
}
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
List<Expression> rewriteHashJoinConjuncts = Lists.newArrayList();
boolean hashJoinConjunctsChanged = false;
for (Expression expr : hashJoinConjuncts) {
Expression newExpr = rewriter.rewrite(expr);
Expression newExpr = rewriter.rewrite(expr, context);
hashJoinConjunctsChanged = hashJoinConjunctsChanged || !newExpr.equals(expr);
rewriteHashJoinConjuncts.add(newExpr);
}
@@ -175,7 +195,7 @@ public class ExpressionRewrite implements RewriteRuleFactory {
List<Expression> rewriteOtherJoinConjuncts = Lists.newArrayList();
boolean otherJoinConjunctsChanged = false;
for (Expression expr : otherJoinConjuncts) {
Expression newExpr = rewriter.rewrite(expr);
Expression newExpr = rewriter.rewrite(expr, context);
otherJoinConjunctsChanged = otherJoinConjunctsChanged || !newExpr.equals(expr);
rewriteOtherJoinConjuncts.add(newExpr);
}
@@ -193,11 +213,13 @@ public class ExpressionRewrite implements RewriteRuleFactory {
@Override
public Rule build() {
return logicalSort().then(sort -> {
return logicalSort().thenApply(ctx -> {
LogicalSort<Plan> sort = ctx.root;
List<OrderKey> orderKeys = sort.getOrderKeys();
List<OrderKey> rewrittenOrderKeys = new ArrayList<>();
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
for (OrderKey k : orderKeys) {
Expression expression = rewriter.rewrite(k.getExpr());
Expression expression = rewriter.rewrite(k.getExpr(), context);
rewrittenOrderKeys.add(new OrderKey(expression, k.isAsc(), k.isNullFirst()));
}
return sort.withOrderKeys(rewrittenOrderKeys);
@@ -208,10 +230,12 @@ public class ExpressionRewrite implements RewriteRuleFactory {
private class HavingExpressionRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalHaving().then(having -> {
return logicalHaving().thenApply(ctx -> {
LogicalHaving<Plan> having = ctx.root;
Set<Expression> rewrittenExpr = new HashSet<>();
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
for (Expression e : having.getExpressions()) {
rewrittenExpr.add(rewriter.rewrite(e));
rewrittenExpr.add(rewriter.rewrite(e, context));
}
return having.withExpressions(rewrittenExpr);
}).toRule(RuleType.REWRITE_HAVING_EXPRESSION);
@@ -221,15 +245,19 @@ public class ExpressionRewrite implements RewriteRuleFactory {
private class LogicalRepeatRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalRepeat().then(r -> {
return logicalRepeat().thenApply(ctx -> {
LogicalRepeat<Plan> repeat = ctx.root;
ImmutableList.Builder<List<Expression>> groupingExprs = ImmutableList.builder();
for (List<Expression> expressions : r.getGroupingSets()) {
groupingExprs.add(expressions.stream().map(rewriter::rewrite)
.collect(ImmutableList.toImmutableList()));
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
for (List<Expression> expressions : repeat.getGroupingSets()) {
groupingExprs.add(expressions.stream()
.map(expr -> rewriter.rewrite(expr, context))
.collect(ImmutableList.toImmutableList())
);
}
return r.withGroupSetsAndOutput(groupingExprs.build(),
r.getOutputExpressions().stream()
.map(rewriter::rewrite)
return repeat.withGroupSetsAndOutput(groupingExprs.build(),
repeat.getOutputExpressions().stream()
.map(output -> rewriter.rewrite(output, context))
.map(e -> (NamedExpression) e)
.collect(ImmutableList.toImmutableList()));
}).toRule(RuleType.REWRITE_REPEAT_EXPRESSION);
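The hunks above all follow the same migration: `then(plan -> ...)` becomes `thenApply(ctx -> ...)`, with the matched plan read from `ctx.root` and an `ExpressionRewriteContext` built from `ctx.cascadesContext` at match time. A hypothetical, minimal model of why this matters (the real `MatchingContext`/`CascadesContext` types are richer; all names here are stand-ins):

```java
// With thenApply, the rule receives the matched root *and* the per-query context
// together, so rules can be static singletons that build their rewrite context
// only when they fire, instead of capturing one in a field at construction time.
final class ThenApplySketch {
    static final class MatchingContext<P> {
        final P root;
        final String cascadesContext; // stand-in for the real CascadesContext
        MatchingContext(P root, String cascadesContext) {
            this.root = root;
            this.cascadesContext = cascadesContext;
        }
    }

    interface Rule<P> {
        P apply(MatchingContext<P> ctx);
    }

    // A static, stateless rule: safe to share across queries because the
    // context arrives per invocation rather than being stored in the rule.
    static final Rule<String> UPPERCASE = ctx -> ctx.root.toUpperCase() + "@" + ctx.cascadesContext;

    static String fire(String plan, String queryContext) {
        return UPPERCASE.apply(new MatchingContext<>(plan, queryContext));
    }
}
```

Because nothing query-specific is captured at construction, one rule instance can be registered once and reused by every query, which is exactly what makes the "static rule" optimization in this PR possible.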


@@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.expression.rewrite;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.qe.ConnectContext;
/**
@@ -25,7 +26,7 @@ import org.apache.doris.qe.ConnectContext;
public class ExpressionRewriteContext {
public final ConnectContext connectContext;
public ExpressionRewriteContext(ConnectContext connectContext) {
this.connectContext = connectContext;
public ExpressionRewriteContext(CascadesContext cascadesContext) {
this.connectContext = cascadesContext.getConnectContext();
}
}


@@ -19,7 +19,6 @@ package org.apache.doris.nereids.rules.expression.rewrite;
import org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeBinaryPredicatesRule;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
@@ -30,30 +29,23 @@ import java.util.Optional;
* Expression rewrite entry, which contains all rewrite rules.
*/
public class ExpressionRuleExecutor {
private final ExpressionRewriteContext ctx;
private final List<ExpressionRewriteRule> rules;
public ExpressionRuleExecutor(List<ExpressionRewriteRule> rules, ConnectContext context) {
this.rules = rules;
this.ctx = new ExpressionRewriteContext(context);
}
public ExpressionRuleExecutor(List<ExpressionRewriteRule> rules) {
this(rules, null);
this.rules = rules;
}
public List<Expression> rewrite(List<Expression> exprs) {
return exprs.stream().map(this::rewrite).collect(ImmutableList.toImmutableList());
public List<Expression> rewrite(List<Expression> exprs, ExpressionRewriteContext ctx) {
return exprs.stream().map(expr -> rewrite(expr, ctx)).collect(ImmutableList.toImmutableList());
}
/**
* Given an expression, returns a rewritten expression.
*/
public Expression rewrite(Expression root) {
public Expression rewrite(Expression root, ExpressionRewriteContext ctx) {
Expression result = root;
for (ExpressionRewriteRule rule : rules) {
result = applyRule(result, rule);
result = applyRule(result, rule, ctx);
}
return result;
}
@@ -61,11 +53,11 @@ public class ExpressionRuleExecutor {
/**
* Given an expression, returns a rewritten expression.
*/
public Optional<Expression> rewrite(Optional<Expression> root) {
return root.map(this::rewrite);
public Optional<Expression> rewrite(Optional<Expression> root, ExpressionRewriteContext ctx) {
return root.map(r -> this.rewrite(r, ctx));
}
private Expression applyRule(Expression expr, ExpressionRewriteRule rule) {
private Expression applyRule(Expression expr, ExpressionRewriteRule rule, ExpressionRewriteContext ctx) {
return rule.rewrite(expr, ctx);
}
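The executor change above removes the `ExpressionRewriteContext` field entirely and threads the context through each `rewrite` call instead. A self-contained toy version of that shape (expressions and contexts modeled as plain strings; names hypothetical):

```java
import java.util.List;

// Sketch of a stateless rule executor: because no per-query context is stored,
// a single executor instance (holding a static rule list) can be shared by all
// queries; the per-query context is passed into every rewrite call.
final class StatelessExecutorSketch {
    interface ExprRule {
        String rewrite(String expr, String ctx);
    }

    private final List<ExprRule> rules;

    StatelessExecutorSketch(List<ExprRule> rules) {
        this.rules = rules;
    }

    String rewrite(String expr, String ctx) {
        String result = expr;
        for (ExprRule rule : rules) {
            result = rule.rewrite(result, ctx); // context passed per call, not stored
        }
        return result;
    }
}
```

This is the same trade the PR makes: the small cost of one extra parameter buys the ability to build the rule list once instead of per query.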


@@ -20,7 +20,6 @@ package org.apache.doris.nereids.rules.expression.rewrite.rules;
import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.qe.ConnectContext;
/**
* Constant evaluation of an expression.
@@ -36,11 +35,5 @@ public class FoldConstantRule extends AbstractExpressionRewriteRule {
}
return FoldConstantRuleOnFE.INSTANCE.rewrite(expr, ctx);
}
public Expression rewrite(Expression expr) {
ExpressionRewriteContext ctx = new ExpressionRewriteContext(ConnectContext.get());
return rewrite(expr, ctx);
}
}


@@ -87,8 +87,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
@Override
public List<Rule> buildRules() {
PatternDescriptor<LogicalAggregate<GroupPlan>> basePattern = logicalAggregate()
.when(LogicalAggregate::isNormalized);
PatternDescriptor<LogicalAggregate<GroupPlan>> basePattern = logicalAggregate();
return ImmutableList.of(
RuleType.STORAGE_LAYER_AGGREGATE_WITHOUT_PROJECT.build(
@@ -125,12 +124,12 @@ public class AggregateStrategies implements ImplementationRuleFactory {
RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
basePattern
.when(this::containsCountDistinctMultiExpr)
.thenApplyMulti(ctx -> twoPhaseAggregateWithCountDistinctMulti(ctx.root, ctx.connectContext))
.thenApplyMulti(ctx -> twoPhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext))
),
RuleType.THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
basePattern
.when(this::containsCountDistinctMultiExpr)
.thenApplyMulti(ctx -> threePhaseAggregateWithCountDistinctMulti(ctx.root, ctx.connectContext))
.thenApplyMulti(ctx -> threePhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext))
),
RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT.build(
basePattern
@@ -393,7 +392,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
*
*/
private List<PhysicalHashAggregate<Plan>> twoPhaseAggregateWithCountDistinctMulti(
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
LogicalAggregate<? extends Plan> logicalAgg, CascadesContext cascadesContext) {
AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER);
Set<Expression> countDistinctArguments = logicalAgg.getDistinctArguments();
@@ -423,13 +422,14 @@ public class AggregateStrategies implements ImplementationRuleFactory {
PhysicalHashAggregate<Plan> gatherLocalAgg = new PhysicalHashAggregate<>(
localAggGroupBy, localOutput, Optional.of(partitionExpressions),
new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER),
maybeUsingStreamAgg(connectContext, logicalAgg),
maybeUsingStreamAgg(cascadesContext.getConnectContext(), logicalAgg),
logicalAgg.getLogicalProperties(), requireGather, logicalAgg.child()
);
List<Expression> distinctGroupBy = logicalAgg.getGroupByExpressions();
LogicalAggregate<? extends Plan> countIfAgg = countDistinctMultiExprToCountIf(logicalAgg, connectContext).first;
LogicalAggregate<? extends Plan> countIfAgg = countDistinctMultiExprToCountIf(
logicalAgg, cascadesContext).first;
AggregateParam distinctInputToResultParam
= new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT);
@@ -509,7 +509,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
*
*/
private List<PhysicalHashAggregate<? extends Plan>> threePhaseAggregateWithCountDistinctMulti(
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
LogicalAggregate<? extends Plan> logicalAgg, CascadesContext cascadesContext) {
AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER);
Set<Expression> countDistinctArguments = logicalAgg.getDistinctArguments();
@@ -539,7 +539,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
PhysicalHashAggregate<Plan> anyLocalAgg = new PhysicalHashAggregate<>(
localAggGroupBy, localOutput, Optional.of(partitionExpressions),
new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER),
maybeUsingStreamAgg(connectContext, logicalAgg),
maybeUsingStreamAgg(cascadesContext.getConnectContext(), logicalAgg),
logicalAgg.getLogicalProperties(), requireAny, logicalAgg.child()
);
@@ -571,7 +571,8 @@ public class AggregateStrategies implements ImplementationRuleFactory {
bufferToBufferParam, false, logicalAgg.getLogicalProperties(),
requireGather, anyLocalAgg);
LogicalAggregate<? extends Plan> countIfAgg = countDistinctMultiExprToCountIf(logicalAgg, connectContext).first;
LogicalAggregate<? extends Plan> countIfAgg = countDistinctMultiExprToCountIf(
logicalAgg, cascadesContext).first;
AggregateParam distinctInputToResultParam
= new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT);
@@ -1194,7 +1195,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
* phase of aggregate, please normalize to slot and create a bottom project like NormalizeAggregate.
*/
private Pair<LogicalAggregate<? extends Plan>, List<Count>> countDistinctMultiExprToCountIf(
LogicalAggregate<? extends Plan> aggregate, ConnectContext connectContext) {
LogicalAggregate<? extends Plan> aggregate, CascadesContext cascadesContext) {
ImmutableList.Builder<Count> countIfList = ImmutableList.builder();
List<NamedExpression> newOutput = ExpressionUtils.rewriteDownShortCircuit(
aggregate.getOutputExpressions(), outputChild -> {
@@ -1206,7 +1207,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
for (int i = arguments.size() - 2; i >= 0; --i) {
Expression argument = count.getArgument(i);
If ifNull = new If(new IsNull(argument), NullLiteral.INSTANCE, countExpr);
countExpr = assignNullType(ifNull, connectContext);
countExpr = assignNullType(ifNull, cascadesContext);
}
Count countIf = new Count(countExpr);
countIfList.add(countIf);
@@ -1224,8 +1225,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
}
// don't invoke the ExpressionNormalization, because the expression maybe simplified and get rid of some slots
private If assignNullType(If ifExpr, ConnectContext context) {
If ifWithCoercion = (If) TypeCoercion.INSTANCE.rewrite(ifExpr, new ExpressionRewriteContext(context));
private If assignNullType(If ifExpr, CascadesContext cascadesContext) {
ExpressionRewriteContext context = new ExpressionRewriteContext(cascadesContext);
If ifWithCoercion = (If) TypeCoercion.INSTANCE.rewrite(ifExpr, context);
Expression trueValue = ifWithCoercion.getArgument(1);
if (trueValue instanceof Cast && trueValue.child(0) instanceof NullLiteral) {
List<Expression> newArgs = Lists.newArrayList(ifWithCoercion.getArguments());
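Both the two- and three-phase strategies above funnel multi-column count distinct through `countDistinctMultiExprToCountIf`, whose loop wraps the count argument in nested null guards so the count only sees rows where every argument is non-null. A toy sketch of that loop (expressions modeled as plain strings; the assumption that the chain is seeded from the last argument is mine, inferred from the loop bounds):

```java
import java.util.List;

// Rewrites the argument list of count(distinct a, b, c) into a nested
// if(x is null, null, ...) chain, mirroring the `for (int i = size - 2; ...)`
// loop in the diff: each outer argument guards everything inside it.
final class CountDistinctMultiSketch {
    static String toCountIfArgument(List<String> arguments) {
        String countExpr = arguments.get(arguments.size() - 1);
        for (int i = arguments.size() - 2; i >= 0; --i) { // same iteration order as the diff
            countExpr = "if(" + arguments.get(i) + " is null, null, " + countExpr + ")";
        }
        return "count(" + countExpr + ")";
    }
}
```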


@@ -17,13 +17,14 @@
package org.apache.doris.nereids.rules.implementation;
import org.apache.doris.nereids.pattern.GeneratedMemoPatterns;
import org.apache.doris.nereids.rules.PlanRuleFactory;
import org.apache.doris.nereids.rules.RulePromise;
/**
* interface for all implementation rule factories.
*/
public interface ImplementationRuleFactory extends PlanRuleFactory {
public interface ImplementationRuleFactory extends PlanRuleFactory, GeneratedMemoPatterns {
@Override
default RulePromise defaultPromise() {
return RulePromise.IMPLEMENT;


@@ -24,6 +24,7 @@ import org.apache.doris.catalog.MaterializedIndex;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
@@ -56,7 +57,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
@@ -484,9 +484,14 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
Set<Slot> nonVirtualRequiredScanOutput = requiredScanOutput.stream()
.filter(slot -> !(slot instanceof VirtualSlotReference))
.collect(ImmutableSet.toImmutableSet());
Preconditions.checkArgument(scan.getOutputSet().containsAll(nonVirtualRequiredScanOutput),
String.format("Scan's output (%s) should contains all the input required scan output (%s).",
scan.getOutput(), nonVirtualRequiredScanOutput));
// use a plain if check instead of Preconditions.checkArgument, so String.format() only runs on failure
if (!scan.getOutputSet().containsAll(nonVirtualRequiredScanOutput)) {
throw new AnalysisException(
String.format("Scan's output (%s) should contains all the input required scan output (%s).",
scan.getOutput(), nonVirtualRequiredScanOutput));
}
OlapTable table = scan.getTable();
switch (scan.getTable().getKeysType()) {
case AGG_KEYS:
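The hunk above swaps `Preconditions.checkArgument(cond, String.format(...))` for an explicit `if`, because the eager overload builds the formatted message even when the check passes, which is the hot path. A self-contained sketch of the difference (names hypothetical; the counter exists only to make the cost observable):

```java
// Eager vs. lazy failure messages: the eager shape formats unconditionally,
// the lazy shape defers formatting to the failure branch only.
final class LazyMessageSketch {
    static long formatCalls = 0;

    static String expensiveMessage() {
        formatCalls++;
        return String.format("Scan output should contain required columns (%s)", "...");
    }

    static void checkEager(boolean ok) { // old shape: message built even on success
        String msg = expensiveMessage();
        if (!ok) {
            throw new IllegalStateException(msg);
        }
    }

    static void checkLazy(boolean ok) { // new shape: message built only on failure
        if (!ok) {
            throw new IllegalStateException(expensiveMessage());
        }
    }
}
```

For a check that fires on every query but almost never fails, the lazy shape removes the formatting cost entirely from the common path.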


@@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.mv;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.MaterializedIndex;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.nereids.rules.Rule;
@@ -114,13 +115,20 @@ public class SelectMaterializedIndexWithoutAggregate extends AbstractSelectMater
LogicalOlapScan scan,
Supplier<Set<Slot>> requiredScanOutputSupplier,
Supplier<Set<Expression>> predicatesSupplier) {
switch (scan.getTable().getKeysType()) {
OlapTable table = scan.getTable();
long baseIndexId = table.getBaseIndexId();
KeysType keysType = scan.getTable().getKeysType();
switch (keysType) {
case AGG_KEYS:
case UNIQUE_KEYS:
break;
case DUP_KEYS:
if (table.getIndexIdToMeta().size() == 1) {
return scan.withMaterializedIndexSelected(PreAggStatus.on(), baseIndexId);
}
break;
default:
throw new RuntimeException("Not supported keys type: " + scan.getTable().getKeysType());
throw new RuntimeException("Not supported keys type: " + keysType);
}
if (scan.getTable().isDupKeysOrMergeOnWrite()) {
// Set pre-aggregation to `on` to keep consistency with legacy logic.
@@ -132,8 +140,16 @@ public class SelectMaterializedIndexWithoutAggregate extends AbstractSelectMater
return scan.withMaterializedIndexSelected(PreAggStatus.on(),
selectBestIndex(candidate, scan, predicatesSupplier.get()));
} else {
OlapTable table = scan.getTable();
long baseIndexId = table.getBaseIndexId();
final PreAggStatus preAggStatus;
if (preAggEnabledByHint(scan)) {
// PreAggStatus could be enabled by pre-aggregation hint for agg-keys and unique-keys.
preAggStatus = PreAggStatus.on();
} else {
preAggStatus = PreAggStatus.off("No aggregate on scan.");
}
if (table.getIndexIdToMeta().size() == 1) {
return scan.withMaterializedIndexSelected(preAggStatus, baseIndexId);
}
int baseIndexKeySize = table.getKeyColumnsByIndexId(table.getBaseIndexId()).size();
// No aggregate on scan.
// So only base index and indexes that have all the keys could be used.
@@ -143,13 +159,6 @@ public class SelectMaterializedIndexWithoutAggregate extends AbstractSelectMater
.filter(index -> containAllRequiredColumns(index, scan, requiredScanOutputSupplier.get()))
.collect(Collectors.toList());
final PreAggStatus preAggStatus;
if (preAggEnabledByHint(scan)) {
// PreAggStatus could be enabled by pre-aggregation hint for agg-keys and unique-keys.
preAggStatus = PreAggStatus.on();
} else {
preAggStatus = PreAggStatus.off("No aggregate on scan.");
}
if (candidates.size() == 1) {
// `candidates` only have base index.
return scan.withMaterializedIndexSelected(preAggStatus, baseIndexId);


@@ -0,0 +1,47 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.PlanRuleFactory;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleFactory;
import org.apache.doris.nereids.rules.RulePromise;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import java.util.List;
/** BatchRewriteRuleFactory */
public interface BatchRewriteRuleFactory extends PlanRuleFactory {
@Override
default RulePromise defaultPromise() {
return RulePromise.REWRITE;
}
@Override
default List<Rule> buildRules() {
Builder<Rule> rules = ImmutableList.builder();
for (RuleFactory ruleFactory : getRuleFactories()) {
rules.addAll(ruleFactory.buildRules());
}
return rules.build();
}
List<RuleFactory> getRuleFactories();
}
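The new `BatchRewriteRuleFactory` above just flattens the rules of several factories into one list, so a single traversal can try all of them instead of scheduling one rewrite pass per factory. A self-contained toy version (rules modeled as plain strings):

```java
import java.util.ArrayList;
import java.util.List;

// Concatenates the rules of several factories, preserving factory order,
// mirroring the buildRules() default method in the interface above.
final class BatchFactorySketch {
    interface RuleFactory {
        List<String> buildRules();
    }

    static List<String> buildAll(List<RuleFactory> factories) {
        List<String> rules = new ArrayList<>();
        for (RuleFactory factory : factories) {
            rules.addAll(factory.buildRules());
        }
        return rules;
    }
}
```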


@@ -17,13 +17,14 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.pattern.GeneratedPlanPatterns;
import org.apache.doris.nereids.rules.PlanRuleFactory;
import org.apache.doris.nereids.rules.RulePromise;
/**
* interface for all rewrite rule factories.
*/
public interface RewriteRuleFactory extends PlanRuleFactory {
public interface RewriteRuleFactory extends PlanRuleFactory, GeneratedPlanPatterns {
@Override
default RulePromise defaultPromise() {
return RulePromise.REWRITE;


@@ -24,7 +24,7 @@ import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import com.google.common.base.Preconditions;
@@ -45,7 +45,7 @@ public class CheckAndStandardizeWindowFunctionAndFrame extends OneRewriteRuleFac
);
}
private LogicalWindow checkAndStandardize(LogicalWindow<GroupPlan> logicalWindow) {
private LogicalWindow checkAndStandardize(LogicalWindow<Plan> logicalWindow) {
List<NamedExpression> newOutputExpressions = logicalWindow.getWindowExpressions().stream()
.map(expr -> {


@@ -23,7 +23,7 @@ import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Project;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
@@ -34,7 +34,7 @@ public class EliminateAggregate extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalAggregate(logicalAggregate()).then(outerAgg -> {
LogicalAggregate<GroupPlan> innerAgg = outerAgg.child();
LogicalAggregate<Plan> innerAgg = outerAgg.child();
if (!isSame(outerAgg.getGroupByExpressions(), innerAgg.getGroupByExpressions())) {
return outerAgg;


@@ -19,11 +19,14 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rewrite.rules.FoldConstantRule;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
@@ -46,13 +49,15 @@ import java.util.Set;
public class EliminateGroupByConstant extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalAggregate().then(aggregate -> {
return logicalAggregate().thenApply(ctx -> {
LogicalAggregate<Plan> aggregate = ctx.root;
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
List<Expression> groupByExprs = aggregate.getGroupByExpressions();
List<NamedExpression> outputExprs = aggregate.getOutputExpressions();
Set<Expression> slotGroupByExprs = Sets.newLinkedHashSet();
Expression lit = null;
for (Expression expression : groupByExprs) {
expression = FoldConstantRule.INSTANCE.rewrite(expression);
expression = FoldConstantRule.INSTANCE.rewrite(expression, context);
if (!(expression instanceof Literal)) {
slotGroupByExprs.add(expression);
} else {


@@ -32,7 +32,7 @@ import java.util.List;
public class EliminateLimitUnderApply extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalApply(group(), logicalLimit()).then(apply -> {
return logicalApply(any(), logicalLimit()).then(apply -> {
List<Plan> children = new ImmutableList.Builder<Plan>()
.add(apply.left())
.add(apply.right().child())


@@ -25,6 +25,8 @@ import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.nereids.util.TypeUtils;
@@ -48,7 +50,8 @@ public class EliminateNotNull extends OneRewriteRuleFactory {
public Rule build() {
return logicalFilter()
.when(filter -> filter.getConjuncts().stream().anyMatch(expr -> expr.isGeneratedIsNotNull))
.then(filter -> {
.thenApply(ctx -> {
LogicalFilter<Plan> filter = ctx.root;
// Progress Example: `id > 0 and id is not null and name is not null(generated)`
// predicatesNotContainIsNotNull: `id > 0`
// predicatesNotContainIsNotNull infer nonNullable slots: `id`
@@ -66,7 +69,8 @@ public class EliminateNotNull extends OneRewriteRuleFactory {
predicatesNotContainIsNotNull.add(expr);
}
});
Set<Slot> inferNonNotSlots = ExpressionUtils.inferNotNullSlots(predicatesNotContainIsNotNull);
Set<Slot> inferNonNotSlots = ExpressionUtils.inferNotNullSlots(
predicatesNotContainIsNotNull, ctx.cascadesContext);
Set<Expression> keepIsNotNull = slotsFromIsNotNull.stream()
.filter(ExpressionTrait::nullable)


@@ -22,8 +22,8 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.TypeUtils;
import org.apache.doris.nereids.util.Utils;
@@ -45,7 +45,7 @@ public class EliminateOuterJoin extends OneRewriteRuleFactory {
return logicalFilter(
logicalJoin().when(join -> join.getJoinType().isOuterJoin())
).then(filter -> {
LogicalJoin<GroupPlan, GroupPlan> join = filter.child();
LogicalJoin<Plan, Plan> join = filter.child();
Builder<Expression> conjunctsBuilder = ImmutableSet.builder();
Set<Slot> notNullSlots = new HashSet<>();


@@ -34,7 +34,7 @@ public class EliminateSortUnderApply implements RewriteRuleFactory {
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.ELIMINATE_SORT_UNDER_APPLY.build(
logicalApply(group(), logicalSort()).then(apply -> {
logicalApply(any(), logicalSort()).then(apply -> {
List<Plan> children = new ImmutableList.Builder<Plan>()
.add(apply.left())
.add(apply.right().child())
@@ -42,8 +42,8 @@ public class EliminateSortUnderApply implements RewriteRuleFactory {
return apply.withChildren(children);
})
),
RuleType.ELIMINATE_SORT_UNDER_APPLY.build(
logicalApply(group(), logicalProject(logicalSort())).then(apply -> {
RuleType.ELIMINATE_SORT_UNDER_APPLY_PROJECT.build(
logicalApply(any(), logicalProject(logicalSort())).then(apply -> {
List<Plan> children = new ImmutableList.Builder<Plan>()
.add(apply.left())
.add(apply.right().withChildren(apply.right().child().child()))


@@ -17,63 +17,92 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.OutputSavePoint;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
/**
* remove a project whose output is the same as its child's, to avoid two consecutive projects in the best plan.
* for more information, please see <a href="https://github.com/apache/doris/pull/13886">this PR</a>
*/
public class EliminateUnnecessaryProject implements RewriteRuleFactory {
@DependsRules(ColumnPruning.class)
public class EliminateUnnecessaryProject implements CustomRewriter {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.MARK_NECESSARY_PROJECT.build(
logicalSetOperation(logicalProject(), group())
.thenApply(ctx -> {
LogicalProject project = (LogicalProject) ctx.root.child(0);
return ctx.root.withChildren(project.withEliminate(false), ctx.root.child(1));
})
),
RuleType.MARK_NECESSARY_PROJECT.build(
logicalSetOperation(group(), logicalProject())
.thenApply(ctx -> {
LogicalProject project = (LogicalProject) ctx.root.child(1);
return ctx.root.withChildren(ctx.root.child(0), project.withEliminate(false));
})
),
RuleType.ELIMINATE_UNNECESSARY_PROJECT.build(
logicalProject(any())
.when(LogicalProject::canEliminate)
.when(project -> project.getOutputSet().equals(project.child().getOutputSet()))
.thenApply(ctx -> {
int rootGroupId = ctx.cascadesContext.getMemo().getRoot().getGroupId().asInt();
LogicalProject<Plan> project = ctx.root;
// if project is root, we need to ensure the output order is same.
if (project.getGroupExpression().get().getOwnerGroup().getGroupId().asInt()
== rootGroupId) {
if (project.getOutput().equals(project.child().getOutput())) {
return project.child();
} else {
return null;
}
} else {
return project.child();
}
})
),
RuleType.ELIMINATE_UNNECESSARY_PROJECT.build(
logicalProject(logicalEmptyRelation())
.then(project -> new LogicalEmptyRelation(project.getProjects()))
)
);
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
return rewrite(plan, false);
}
private Plan rewrite(Plan plan, boolean outputSavePoint) {
if (plan instanceof LogicalSetOperation) {
return rewriteLogicalSetOperation((LogicalSetOperation) plan, outputSavePoint);
} else if (plan instanceof LogicalProject) {
return rewriteProject((LogicalProject) plan, outputSavePoint);
} else if (plan instanceof OutputSavePoint) {
return rewriteChildren(plan, true);
} else {
return rewriteChildren(plan, outputSavePoint);
}
}
private Plan rewriteProject(LogicalProject<Plan> project, boolean outputSavePoint) {
if (project.child() instanceof LogicalEmptyRelation) {
// eliminate unnecessary project
return new LogicalEmptyRelation(project.getProjects());
} else if (project.canEliminate() && outputSavePoint
&& project.getOutputSet().equals(project.child().getOutputSet())) {
// eliminate unnecessary project
return rewrite(project.child(), outputSavePoint);
} else if (project.canEliminate() && project.getOutput().equals(project.child().getOutput())) {
// eliminate unnecessary project
return rewrite(project.child(), outputSavePoint);
} else {
return rewriteChildren(project, true);
}
}
private Plan rewriteLogicalSetOperation(LogicalSetOperation set, boolean outputSavePoint) {
if (set.arity() == 2) {
Plan left = set.child(0);
Plan right = set.child(1);
boolean changed = false;
if (isCanEliminateProject(left)) {
changed = true;
left = ((LogicalProject) left).withEliminate(false);
}
if (isCanEliminateProject(right)) {
changed = true;
right = ((LogicalProject) right).withEliminate(false);
}
if (changed) {
set = (LogicalSetOperation) set.withChildren(left, right);
}
}
return rewriteChildren(set, outputSavePoint);
}
private Plan rewriteChildren(Plan plan, boolean outputSavePoint) {
List<Plan> newChildren = new ArrayList<>();
boolean hasNewChildren = false;
for (Plan child : plan.children()) {
Plan newChild = rewrite(child, outputSavePoint);
if (newChild != child) {
hasNewChildren = true;
}
newChildren.add(newChild);
}
return hasNewChildren ? plan.withChildren(newChildren) : plan;
}
private static boolean isCanEliminateProject(Plan plan) {
return plan instanceof LogicalProject && ((LogicalProject<?>) plan).canEliminate();
}
}
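The `outputSavePoint` flag above can be illustrated with a minimal standalone sketch (hypothetical classes, not Doris code). An order-changing project at the root is kept because it decides the query output order, which is exactly the limit(project)/sort(project) case this PR fixes, while the flag turns true below a kept project, so a deeper project that only reorders the same columns can be dropped.

```java
import java.util.List;

public class EliminateProjectSketch {
    interface Plan { List<String> output(); Plan child(); }

    record Scan(List<String> cols) implements Plan {
        public List<String> output() { return cols; }
        public Plan child() { return null; }
    }

    record Project(List<String> cols, Plan c) implements Plan {
        public List<String> output() { return cols; }
        public Plan child() { return c; }
    }

    record Limit(Plan c) implements Plan {
        public List<String> output() { return c.output(); }
        public Plan child() { return c; }
    }

    static List<String> sorted(List<String> l) {
        return l.stream().sorted().toList();
    }

    // outputSavePoint=true means some ancestor (e.g. a kept project) will
    // re-establish the column order, so a project that merely reorders the
    // same columns may be removed. At the root the flag starts as false,
    // which preserves the project that decides the query output order.
    static Plan rewrite(Plan plan, boolean outputSavePoint) {
        if (plan instanceof Project p) {
            boolean sameOrder = p.output().equals(p.child().output());
            boolean sameSet = sorted(p.output()).equals(sorted(p.child().output()));
            if (sameOrder || (outputSavePoint && sameSet)) {
                return rewrite(p.child(), outputSavePoint); // unnecessary project
            }
            // a kept project fixes the column order for everything below it
            return new Project(p.cols(), rewrite(p.child(), true));
        }
        if (plan instanceof Limit l) {
            // limit preserves its child's output order: propagate the flag as-is
            return new Limit(rewrite(l.child(), outputSavePoint));
        }
        return plan; // scan: nothing to do
    }

    public static void main(String[] args) {
        Plan scan = new Scan(List.of("a", "b"));
        // limit(project) at the root: the reordering project must survive
        Plan kept = rewrite(new Limit(new Project(List.of("b", "a"), scan)), false);
        System.out.println(kept.child() instanceof Project); // true
        // the same project under another project is unnecessary and is removed
        Plan merged = rewrite(new Project(List.of("a", "b"),
                new Project(List.of("b", "a"), scan)), false);
        System.out.println(((Project) merged).child() instanceof Scan); // true
    }
}
```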


@ -21,8 +21,8 @@ import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
@ -40,7 +40,7 @@ public class ExtractFilterFromCrossJoin extends OneRewriteRuleFactory {
public Rule build() {
return crossLogicalJoin()
.then(join -> {
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(JoinType.CROSS_JOIN,
LogicalJoin<Plan, Plan> newJoin = new LogicalJoin<>(JoinType.CROSS_JOIN,
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, join.getHint(),
join.left(), join.right());
Set<Expression> predicates = Stream.concat(join.getHashJoinConjuncts().stream(),


@ -115,7 +115,6 @@ public class ExtractSingleTableExpressionFromDisjunction extends OneRewriteRuleF
redundants.add(ExpressionUtils.or(extractForAll));
}
}
}
if (redundants.isEmpty()) {
return new LogicalFilter<>(filter.getConjuncts(), true, filter.child());


@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
@ -43,6 +44,9 @@ import java.util.List;
* CAUTION:
* This rule must be applied after BindSlotReference
*/
@DependsRules({
PushFilterInsideJoin.class
})
public class FindHashConditionForJoin extends OneRewriteRuleFactory {
@Override
public Rule build() {


@ -50,7 +50,7 @@ public class HideOneRowRelationUnderUnion implements AnalysisRuleFactory {
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.HIDE_ONE_ROW_RELATION_UNDER_UNION.build(
logicalUnion(logicalOneRowRelation().when(LogicalOneRowRelation::buildUnionNode), group())
logicalUnion(logicalOneRowRelation().when(LogicalOneRowRelation::buildUnionNode), any())
.then(union -> {
List<Plan> newChildren = new ImmutableList.Builder<Plan>()
.add(((LogicalOneRowRelation) union.child(0)).withBuildUnionNode(false))
@ -60,7 +60,7 @@ public class HideOneRowRelationUnderUnion implements AnalysisRuleFactory {
})
),
RuleType.HIDE_ONE_ROW_RELATION_UNDER_UNION.build(
logicalUnion(group(), logicalOneRowRelation().when(LogicalOneRowRelation::buildUnionNode))
logicalUnion(any(), logicalOneRowRelation().when(LogicalOneRowRelation::buildUnionNode))
.then(union -> {
List<Plan> children = new ImmutableList.Builder<Plan>()
.add(union.child(0))


@ -21,6 +21,8 @@ import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
@ -43,9 +45,10 @@ public class InferFilterNotNull extends OneRewriteRuleFactory {
public Rule build() {
return logicalFilter()
.when(filter -> filter.getConjuncts().stream().noneMatch(expr -> expr.isGeneratedIsNotNull))
.then(filter -> {
.thenApply(ctx -> {
LogicalFilter<Plan> filter = ctx.root;
Set<Expression> predicates = filter.getConjuncts();
Set<Expression> isNotNull = ExpressionUtils.inferNotNull(predicates);
Set<Expression> isNotNull = ExpressionUtils.inferNotNull(predicates, ctx.cascadesContext);
if (isNotNull.isEmpty() || predicates.containsAll(isNotNull)) {
return null;
}


@ -44,7 +44,8 @@ public class InferJoinNotNull extends OneRewriteRuleFactory {
// TODO: maybe consider ANTI?
return logicalJoin().when(join -> join.getJoinType().isInnerJoin() || join.getJoinType().isSemiJoin())
.whenNot(LogicalJoin::isGenerateIsNotNull)
.then(join -> {
.thenApply(ctx -> {
LogicalJoin<Plan, Plan> join = ctx.root;
Set<Expression> conjuncts = new HashSet<>();
conjuncts.addAll(join.getHashJoinConjuncts());
conjuncts.addAll(join.getOtherJoinConjuncts());
@ -52,15 +53,19 @@ public class InferJoinNotNull extends OneRewriteRuleFactory {
Plan left = join.left();
Plan right = join.right();
if (join.getJoinType().isInnerJoin()) {
Set<Expression> leftNotNull = ExpressionUtils.inferNotNull(conjuncts, join.left().getOutputSet());
Set<Expression> rightNotNull = ExpressionUtils.inferNotNull(conjuncts, join.right().getOutputSet());
Set<Expression> leftNotNull = ExpressionUtils.inferNotNull(
conjuncts, join.left().getOutputSet(), ctx.cascadesContext);
Set<Expression> rightNotNull = ExpressionUtils.inferNotNull(
conjuncts, join.right().getOutputSet(), ctx.cascadesContext);
left = PlanUtils.filterOrSelf(leftNotNull, join.left());
right = PlanUtils.filterOrSelf(rightNotNull, join.right());
} else if (join.getJoinType() == JoinType.LEFT_SEMI_JOIN) {
Set<Expression> leftNotNull = ExpressionUtils.inferNotNull(conjuncts, join.left().getOutputSet());
Set<Expression> leftNotNull = ExpressionUtils.inferNotNull(
conjuncts, join.left().getOutputSet(), ctx.cascadesContext);
left = PlanUtils.filterOrSelf(leftNotNull, join.left());
} else {
Set<Expression> rightNotNull = ExpressionUtils.inferNotNull(conjuncts, join.right().getOutputSet());
Set<Expression> rightNotNull = ExpressionUtils.inferNotNull(
conjuncts, join.right().getOutputSet(), ctx.cascadesContext);
right = PlanUtils.filterOrSelf(rightNotNull, join.right());
}
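The `ExpressionUtils.inferNotNull` calls above derive IS NOT NULL predicates from null-rejecting join conjuncts and restrict them to one side's output slots. A minimal sketch of that idea (hypothetical names and a string-based model, not the real `ExpressionUtils` API):

```java
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

public class InferNotNullSketch {
    /**
     * Conjuncts are modelled as slot-name pairs of an equality a = b, which is
     * null-rejecting: if either side is NULL the predicate cannot be true, so
     * both slots can be inferred IS NOT NULL. Restricting the result to one
     * join side's output slots yields the predicates pushable to that side.
     */
    static Set<String> inferNotNull(List<String[]> equalConjuncts, Set<String> sideOutput) {
        Set<String> result = new LinkedHashSet<>();
        for (String[] eq : equalConjuncts) {
            for (String slot : eq) {
                if (sideOutput.contains(slot)) {
                    result.add(slot + " IS NOT NULL");
                }
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<String[]> conjuncts = List.of(new String[] {"t1.id", "t2.id"});
        // only t1's side of the equality produces a predicate for the left child
        System.out.println(inferNotNull(conjuncts, Set.of("t1.id", "t1.v")));
        // [t1.id IS NOT NULL]
    }
}
```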


@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
@ -49,10 +50,15 @@ import java.util.stream.Collectors;
* 2. put these predicates into `otherJoinConjuncts` , these predicates are processed in the next
* round of predicate push-down
*/
public class InferPredicates extends DefaultPlanRewriter<JobContext> {
public class InferPredicates extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
private final PredicatePropagation propagation = new PredicatePropagation();
private final PullUpPredicates pollUpPredicates = new PullUpPredicates();
@Override
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
return plan.accept(this, jobContext);
}
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, JobContext context) {
join = (LogicalJoin<? extends Plan, ? extends Plan>) super.visit(join, context);


@ -30,16 +30,12 @@ import com.google.common.collect.ImmutableSet;
* this rule aims to merge consecutive filters.
* For example:
* logical plan tree:
* project
* |
* filter(a>0)
* |
* filter(b>0)
* |
* scan
* transformed to:
* project
* |
* filter(a>0 and b>0)
* |
* scan
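The transformation described in the comment above can be sketched as a tiny standalone model (hypothetical classes, not Doris code): consecutive filters collapse into one filter whose conjunct set is the union of both.

```java
import java.util.LinkedHashSet;
import java.util.Set;

public class MergeFiltersSketch {
    interface Plan {}
    record Scan() implements Plan {}
    record Filter(Set<String> conjuncts, Plan child) implements Plan {}

    // filter(a>0) over filter(b>0) becomes filter(a>0 and b>0): the conjunct
    // sets are unioned and the bottom filter disappears; recursion handles
    // chains of more than two filters.
    static Plan mergeFilters(Plan plan) {
        if (plan instanceof Filter top && top.child() instanceof Filter bottom) {
            Set<String> merged = new LinkedHashSet<>(top.conjuncts());
            merged.addAll(bottom.conjuncts());
            return mergeFilters(new Filter(merged, bottom.child()));
        }
        return plan;
    }

    public static void main(String[] args) {
        Plan plan = new Filter(Set.of("a > 0"), new Filter(Set.of("b > 0"), new Scan()));
        Filter merged = (Filter) mergeFilters(plan);
        System.out.println(merged.conjuncts()); // [a > 0, b > 0]
        System.out.println(merged.child() instanceof Scan); // true
    }
}
```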


@ -22,7 +22,7 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
import com.google.common.collect.Lists;
@ -37,7 +37,7 @@ public class MergeGenerates extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalGenerate(logicalGenerate()).then(top -> {
LogicalGenerate<GroupPlan> bottom = top.child();
LogicalGenerate<Plan> bottom = top.child();
Set<Slot> topGeneratorSlots = top.getInputSlots();
if (bottom.getGeneratorOutput().stream().anyMatch(topGeneratorSlots::contains)) {
// the top generator consumes the bottom generator's output, so they cannot be merged.


@ -21,7 +21,7 @@ import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import java.util.List;
@ -47,7 +47,7 @@ public class MergeProjects extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalProject(logicalProject()).then(project -> {
LogicalProject<GroupPlan> childProject = project.child();
LogicalProject<Plan> childProject = project.child();
List<NamedExpression> projectExpressions = project.mergeProjections(childProject);
return new LogicalProject<>(projectExpressions, childProject.child(0));
}).toRule(RuleType.MERGE_PROJECTS);


@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.stream.Stream;
/**
* optimization.
@ -51,48 +52,44 @@ public class MergeSetOperations implements RewriteRuleFactory {
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.MERGE_SET_OPERATION.build(
logicalSetOperation(logicalSetOperation(), group()).thenApply(ctx -> {
LogicalSetOperation parentSetOperation = ctx.root;
LogicalSetOperation childSetOperation = (LogicalSetOperation) parentSetOperation.child(0);
logicalSetOperation(any(), any()).when(MergeSetOperations::canMerge).then(parentSetOperation -> {
List<Plan> newChildren = parentSetOperation.children()
.stream()
.flatMap(child -> {
if (canMerge(parentSetOperation, child)) {
return child.children().stream();
} else {
return Stream.of(child);
}
}).collect(ImmutableList.toImmutableList());
if (isSameClass(parentSetOperation, childSetOperation)
&& isSameQualifierOrChildQualifierIsAll(parentSetOperation, childSetOperation)) {
List<Plan> newChildren = new ImmutableList.Builder<Plan>()
.addAll(childSetOperation.children())
.add(parentSetOperation.child(1))
.build();
return parentSetOperation.withChildren(newChildren);
}
return parentSetOperation;
})
),
RuleType.MERGE_SET_OPERATION.build(
logicalSetOperation(group(), logicalSetOperation()).thenApply(ctx -> {
LogicalSetOperation parentSetOperation = ctx.root;
LogicalSetOperation childSetOperation = (LogicalSetOperation) parentSetOperation.child(1);
if (isSameClass(parentSetOperation, childSetOperation)
&& isSameQualifierOrChildQualifierIsAll(parentSetOperation, childSetOperation)) {
List<Plan> newChildren = new ImmutableList.Builder<Plan>()
.add(parentSetOperation.child(0))
.addAll(childSetOperation.children())
.build();
return parentSetOperation.withNewChildren(newChildren);
}
return parentSetOperation;
return parentSetOperation.withChildren(newChildren);
})
)
);
}
private boolean isSameQualifierOrChildQualifierIsAll(LogicalSetOperation parentSetOperation,
/** canMerge */
public static boolean canMerge(LogicalSetOperation parent) {
Plan left = parent.child(0);
if (canMerge(parent, left)) {
return true;
}
Plan right = parent.child(1);
if (canMerge(parent, right)) {
return true;
}
return false;
}
public static final boolean canMerge(LogicalSetOperation parent, Plan child) {
return child.getClass().equals(parent.getClass())
&& isSameQualifierOrChildQualifierIsAll(parent, (LogicalSetOperation) child);
}
public static final boolean isSameQualifierOrChildQualifierIsAll(LogicalSetOperation parentSetOperation,
LogicalSetOperation childSetOperation) {
return parentSetOperation.getQualifier() == childSetOperation.getQualifier()
|| childSetOperation.getQualifier() == Qualifier.ALL;
}
private boolean isSameClass(LogicalSetOperation parentSetOperation,
LogicalSetOperation childSetOperation) {
return parentSetOperation.getClass().isAssignableFrom(childSetOperation.getClass());
}
}
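The qualifier rule above (merge when the qualifiers match, or when the child is ALL) can be checked with a small standalone model (hypothetical names, not Doris code):

```java
import java.util.List;
import java.util.stream.Stream;

public class MergeSetOpsSketch {
    enum Qualifier { ALL, DISTINCT }
    interface Plan {}
    record Scan(String name) implements Plan {}
    record Union(Qualifier qualifier, List<Plan> children) implements Plan {}

    // a child union merges into its parent when the qualifiers match, or when
    // the child is ALL (a DISTINCT parent deduplicates its input anyway)
    static boolean canMerge(Union parent, Plan child) {
        return child instanceof Union u
                && (parent.qualifier() == u.qualifier() || u.qualifier() == Qualifier.ALL);
    }

    static Union merge(Union parent) {
        List<Plan> newChildren = parent.children().stream()
                .flatMap(c -> canMerge(parent, c)
                        ? ((Union) c).children().stream()
                        : Stream.of(c))
                .toList();
        return new Union(parent.qualifier(), newChildren);
    }

    public static void main(String[] args) {
        Plan a = new Scan("a");
        Plan b = new Scan("b");
        Plan c = new Scan("c");
        // DISTINCT over ALL: flattened into one three-way union
        System.out.println(merge(new Union(Qualifier.DISTINCT,
                List.of(new Union(Qualifier.ALL, List.of(a, b)), c))).children().size()); // 3
        // ALL over DISTINCT: must not merge, the inner dedup is semantically required
        System.out.println(merge(new Union(Qualifier.ALL,
                List.of(new Union(Qualifier.DISTINCT, List.of(a, b)), c))).children().size()); // 2
    }
}
```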


@ -23,7 +23,7 @@ import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
@ -78,7 +78,7 @@ public class PruneAggChildColumns extends OneRewriteRuleFactory {
* @return null if there exists an aggregate function whose parameters contain a non-constant expression;
* otherwise return a slot with the minimal data type.
*/
private boolean isAggregateWithConstant(LogicalAggregate<GroupPlan> agg) {
private boolean isAggregateWithConstant(LogicalAggregate<Plan> agg) {
for (NamedExpression output : agg.getOutputExpressions()) {
if (output.anyMatch(SlotReference.class::isInstance)) {
return false;


@ -19,7 +19,6 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
@ -51,7 +50,7 @@ import java.util.stream.Stream;
* |
* scan(k1,k2,k3,v1)
*/
public class PruneFilterChildColumns extends AbstractPushDownProjectRule<LogicalFilter<GroupPlan>> {
public class PruneFilterChildColumns extends AbstractPushDownProjectRule<LogicalFilter<Plan>> {
public PruneFilterChildColumns() {
setRuleType(RuleType.COLUMN_PRUNE_FILTER_CHILD);
@ -59,7 +58,7 @@ public class PruneFilterChildColumns extends AbstractPushDownProjectRule<Logical
}
@Override
protected Plan pushDownProject(LogicalFilter<GroupPlan> filter, Set<Slot> references) {
protected Plan pushDownProject(LogicalFilter<Plan> filter, Set<Slot> references) {
Set<Slot> filterInputSlots = filter.getInputSlots();
Set<Slot> required = Stream.concat(references.stream(), filterInputSlots.stream()).collect(Collectors.toSet());
if (required.containsAll(filter.child().getOutput())) {


@ -21,7 +21,6 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
@ -56,7 +55,7 @@ import java.util.stream.Stream;
* scan scan
*/
public class PruneJoinChildrenColumns
extends AbstractPushDownProjectRule<LogicalJoin<GroupPlan, GroupPlan>> {
extends AbstractPushDownProjectRule<LogicalJoin<Plan, Plan>> {
public PruneJoinChildrenColumns() {
setRuleType(RuleType.COLUMN_PRUNE_JOIN_CHILD);
@ -64,7 +63,7 @@ public class PruneJoinChildrenColumns
}
@Override
protected Plan pushDownProject(LogicalJoin<GroupPlan, GroupPlan> joinPlan,
protected Plan pushDownProject(LogicalJoin<Plan, Plan> joinPlan,
Set<Slot> references) {
Set<ExprId> exprIds = Stream.of(references, joinPlan.getInputSlots())


@ -19,7 +19,6 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
@ -34,7 +33,7 @@ import java.util.stream.Stream;
* prune sort child output.
* pattern: project(sort())
*/
public class PruneSortChildColumns extends AbstractPushDownProjectRule<LogicalSort<GroupPlan>> {
public class PruneSortChildColumns extends AbstractPushDownProjectRule<LogicalSort<Plan>> {
public PruneSortChildColumns() {
setRuleType(RuleType.COLUMN_PRUNE_SORT_CHILD);
@ -42,7 +41,7 @@ public class PruneSortChildColumns extends AbstractPushDownProjectRule<LogicalSo
}
@Override
protected Plan pushDownProject(LogicalSort<GroupPlan> sortPlan, Set<Slot> references) {
protected Plan pushDownProject(LogicalSort<Plan> sortPlan, Set<Slot> references) {
Set<Slot> sortSlots = sortPlan.getOutputSet();
Set<Slot> required = Stream.concat(references.stream(), sortSlots.stream()).collect(Collectors.toSet());
if (required.containsAll(sortPlan.child().getOutput())) {
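The `pushDownProject` pattern shared by these `Prune*ChildColumns` rules, inserting a pruning project only when the child outputs more slots than are required, can be sketched with hypothetical simplified classes (not the real `AbstractPushDownProjectRule` API):

```java
import java.util.List;
import java.util.Set;

public class PruneColumnsSketch {
    interface Plan { List<String> output(); }
    record Scan(List<String> cols) implements Plan {
        public List<String> output() { return cols; }
    }
    record Project(List<String> cols, Plan child) implements Plan {
        public List<String> output() { return cols; }
    }

    // shared shape of the pruning rules: if the child already outputs only
    // required slots, keep it unchanged; otherwise insert a pruning project
    static Plan pushDownProject(Plan child, Set<String> required) {
        if (required.containsAll(child.output())) {
            return child; // nothing to prune
        }
        List<String> kept = child.output().stream().filter(required::contains).toList();
        return new Project(kept, child);
    }

    public static void main(String[] args) {
        Plan scan = new Scan(List.of("k1", "k2", "k3", "v1"));
        Project pruned = (Project) pushDownProject(scan, Set.of("k1", "k2"));
        System.out.println(pruned.cols()); // [k1, k2]
        // when every output slot is required, the plan is returned unchanged
        System.out.println(pushDownProject(scan, Set.of("k1", "k2", "k3", "v1")) == scan); // true
    }
}
```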


@ -21,12 +21,13 @@ import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.List;
@ -57,15 +58,15 @@ import java.util.List;
* child
* </pre>
*/
public class ApplyPullFilterOnProjectUnderAgg extends OneRewriteRuleFactory {
public class PullUpCorrelatedFilterUnderApplyAggregateProject extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalApply(group(), logicalAggregate(logicalProject(logicalFilter())))
return logicalApply(any(), logicalAggregate(logicalProject(logicalFilter())))
.when(LogicalApply::isCorrelated).then(apply -> {
LogicalAggregate<LogicalProject<LogicalFilter<GroupPlan>>> agg = apply.right();
LogicalAggregate<LogicalProject<LogicalFilter<Plan>>> agg = apply.right();
LogicalProject<LogicalFilter<GroupPlan>> project = agg.child();
LogicalFilter<GroupPlan> filter = project.child();
LogicalProject<LogicalFilter<Plan>> project = agg.child();
LogicalFilter<Plan> filter = project.child();
List<NamedExpression> newProjects = Lists.newArrayList();
newProjects.addAll(project.getProjects());
filter.child().getOutput().forEach(slot -> {
@ -76,10 +77,9 @@ public class ApplyPullFilterOnProjectUnderAgg extends OneRewriteRuleFactory {
LogicalProject newProject = new LogicalProject<>(newProjects, filter.child());
LogicalFilter newFilter = new LogicalFilter<>(filter.getConjuncts(), newProject);
LogicalAggregate newAgg = new LogicalAggregate<>(agg.getGroupByExpressions(),
agg.getOutputExpressions(), agg.isOrdinalIsResolved(), newFilter);
LogicalAggregate newAgg = agg.withChildren(ImmutableList.of(newFilter));
return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),
apply.getCorrelationFilter(), apply.left(), newAgg);
}).toRule(RuleType.APPLY_PULL_FILTER_ON_PROJECT_UNDER_AGG);
}).toRule(RuleType.PULL_UP_CORRELATED_FILTER_UNDER_APPLY_AGGREGATE_PROJECT);
}
}

Some files were not shown because too many files have changed in this diff.