[enhancement](Nereids) support StatementContext, SET_VAR, and Plan pre/post processor (#11654)

1. add StatementContext, and PlannerContext is renamed to CascadsContext. CascadsContext belong to a StatementContext, and StatementContext belong to a ConnectionContext, and the lifecycle increases in turn. StatementContext can wrap some statement's lifecycle-related state, such as ExpressionId, TableLock. MemoTestUtil can simplify create a CascadesContext and Memo for test.
2. add PlanPreprocessor to process parsed logical plan before copy into memo. and add a PlanPostprocessor to process physical plan after copy out from memo.
3. utilize PlanPreprocessor to process SET_VAR hint, the class is EliminateLogicalSelectHint
4. pass the limit clause in regression test case, in set_var.groovy
This commit is contained in:
924060929
2022-08-12 14:49:11 +08:00
committed by GitHub
parent e353be7dcb
commit ec4347ad39
73 changed files with 1345 additions and 472 deletions

View File

@ -397,8 +397,8 @@ CONCAT_PIPE: '||';
HAT: '^';
COLON: ':';
ARROW: '->';
HENT_START: '/*+';
HENT_END: '*/';
HINT_START: '/*+';
HINT_END: '*/';
STRING
: '\'' ( ~('\''|'\\') | ('\\' .) )* '\''

View File

@ -78,7 +78,7 @@ querySpecification
;
selectClause
: SELECT namedExpressionSeq
: SELECT selectHint? namedExpressionSeq
;
whereClause
@ -109,6 +109,16 @@ havingClause
: HAVING booleanExpression
;
selectHint: HINT_START hintStatements+=hintStatement (COMMA? hintStatements+=hintStatement)* HINT_END;
hintStatement
: hintName=identifier LEFT_PAREN parameters+=hintAssignment (COMMA parameters+=hintAssignment)* RIGHT_PAREN
;
hintAssignment
: key=identifier (EQ (constantValue=constant | identifierValue=identifier))?
;
queryOrganization
: sortClause? limitClause?
;

View File

@ -46,6 +46,14 @@ public class Pair<F, S> {
return new Pair<F, S>(first, second);
}
public F getFirst() {
return first;
}
public S getSecond() {
return second;
}
@Override
/**
* A pair is equal if both parts are equal().

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids;
import org.apache.doris.nereids.analyzer.NereidsAnalyzer;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.rewrite.RewriteBottomUpJob;
@ -30,18 +31,21 @@ import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleFactory;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.rules.analysis.Scope;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Optional;
/**
* Context used in memo.
*/
public class PlannerContext {
public class CascadesContext {
private final Memo memo;
private final ConnectContext connectContext;
private final StatementContext statementContext;
private RuleSet ruleSet;
private JobPool jobPool;
private final JobScheduler jobScheduler;
@ -51,13 +55,27 @@ public class PlannerContext {
* Constructor of OptimizerContext.
*
* @param memo {@link Memo} reference
* @param statementContext {@link StatementContext} reference
*/
public PlannerContext(Memo memo, ConnectContext connectContext) {
public CascadesContext(Memo memo, StatementContext statementContext) {
this.memo = memo;
this.connectContext = connectContext;
this.statementContext = statementContext;
this.ruleSet = new RuleSet();
this.jobPool = new JobStack();
this.jobScheduler = new SimpleJobScheduler();
this.currentJobContext = new JobContext(this, PhysicalProperties.ANY, Double.MAX_VALUE);
}
public static CascadesContext newContext(StatementContext statementContext, Plan initPlan) {
return new CascadesContext(new Memo(initPlan), statementContext);
}
public NereidsAnalyzer newAnalyzer() {
return new NereidsAnalyzer(this);
}
public NereidsAnalyzer newAnalyzer(Optional<Scope> outerScope) {
return new NereidsAnalyzer(this, outerScope);
}
public void pushJob(Job job) {
@ -69,7 +87,11 @@ public class PlannerContext {
}
public ConnectContext getConnectContext() {
return connectContext;
return statementContext.getConnectContext();
}
public StatementContext getStatementContext() {
return statementContext;
}
public RuleSet getRuleSet() {
@ -100,41 +122,36 @@ public class PlannerContext {
this.currentJobContext = currentJobContext;
}
public PlannerContext setDefaultJobContext() {
this.currentJobContext = new JobContext(this, PhysicalProperties.ANY, Double.MAX_VALUE);
return this;
}
public PlannerContext setJobContext(PhysicalProperties physicalProperties) {
public CascadesContext setJobContext(PhysicalProperties physicalProperties) {
this.currentJobContext = new JobContext(this, physicalProperties, Double.MAX_VALUE);
return this;
}
public PlannerContext bottomUpRewrite(RuleFactory... rules) {
public CascadesContext bottomUpRewrite(RuleFactory... rules) {
return execute(new RewriteBottomUpJob(memo.getRoot(), currentJobContext, ImmutableList.copyOf(rules)));
}
public PlannerContext bottomUpRewrite(Rule... rules) {
public CascadesContext bottomUpRewrite(Rule... rules) {
return bottomUpRewrite(ImmutableList.copyOf(rules));
}
public PlannerContext bottomUpRewrite(List<Rule> rules) {
public CascadesContext bottomUpRewrite(List<Rule> rules) {
return execute(new RewriteBottomUpJob(memo.getRoot(), rules, currentJobContext));
}
public PlannerContext topDownRewrite(RuleFactory... rules) {
public CascadesContext topDownRewrite(RuleFactory... rules) {
return execute(new RewriteTopDownJob(memo.getRoot(), currentJobContext, ImmutableList.copyOf(rules)));
}
public PlannerContext topDownRewrite(Rule... rules) {
public CascadesContext topDownRewrite(Rule... rules) {
return topDownRewrite(ImmutableList.copyOf(rules));
}
public PlannerContext topDownRewrite(List<Rule> rules) {
public CascadesContext topDownRewrite(List<Rule> rules) {
return execute(new RewriteTopDownJob(memo.getRoot(), rules, currentJobContext));
}
private PlannerContext execute(Job job) {
private CascadesContext execute(Job job) {
pushJob(job);
jobScheduler.executeJobPool(this);
return this;

View File

@ -21,12 +21,10 @@ import org.apache.doris.analysis.DescriptorTable;
import org.apache.doris.analysis.StatementBase;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.UserException;
import org.apache.doris.nereids.analyzer.NereidsAnalyzer;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator;
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
import org.apache.doris.nereids.jobs.batch.DisassembleRulesJob;
import org.apache.doris.nereids.jobs.batch.FinalizeAnalyzeJob;
import org.apache.doris.nereids.jobs.batch.JoinReorderRulesJob;
import org.apache.doris.nereids.jobs.batch.NormalizeExpressionRulesJob;
import org.apache.doris.nereids.jobs.batch.OptimizeRulesJob;
@ -34,6 +32,8 @@ import org.apache.doris.nereids.jobs.batch.PredicatePushDownRulesJob;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.processor.post.PlanPostprocessors;
import org.apache.doris.nereids.processor.pre.PlanPreprocessors;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
@ -55,13 +55,13 @@ import java.util.stream.Collectors;
*/
public class NereidsPlanner extends Planner {
private PlannerContext plannerContext;
private final ConnectContext ctx;
private CascadesContext cascadesContext;
private final StatementContext statementContext;
private List<ScanNode> scanNodeList = null;
private DescriptorTable descTable;
public NereidsPlanner(ConnectContext ctx) {
this.ctx = ctx;
public NereidsPlanner(StatementContext statementContext) {
this.statementContext = statementContext;
}
@Override
@ -72,7 +72,7 @@ public class NereidsPlanner extends Planner {
}
LogicalPlanAdapter logicalPlanAdapter = (LogicalPlanAdapter) queryStmt;
PhysicalPlan physicalPlan = plan(logicalPlanAdapter.getLogicalPlan(), PhysicalProperties.ANY, ctx);
PhysicalPlan physicalPlan = plan(logicalPlanAdapter.getLogicalPlan(), PhysicalProperties.ANY);
PhysicalPlanTranslator physicalPlanTranslator = new PhysicalPlanTranslator();
PlanTranslatorContext planTranslatorContext = new PlanTranslatorContext();
@ -94,46 +94,65 @@ public class NereidsPlanner extends Planner {
*
* @param plan wait for plan
* @param outputProperties physical properties constraints
* @param connectContext connect context for this query
* @return physical plan generated by this planner
* @throws AnalysisException throw exception if failed in ant stage
*/
// TODO: refactor, just demo code here
public PhysicalPlan plan(LogicalPlan plan, PhysicalProperties outputProperties, ConnectContext connectContext)
throws AnalysisException {
plannerContext = new NereidsAnalyzer(connectContext)
.analyzeWithPlannerContext(plan)
// TODO: revisit this. What is the appropriate time to set physical properties? Maybe before enter
// cascades style optimize phase.
.setJobContext(outputProperties);
public PhysicalPlan plan(LogicalPlan plan, PhysicalProperties outputProperties) throws AnalysisException {
finalizeAnalyze();
// pre-process logical plan out of memo, e.g. process SET_VAR hint
plan = preprocess(plan);
initCascadesContext(plan);
// resolve column, table and function
analyze();
// rule-based optimize
rewrite();
// TODO: remove this condition, when stats collector is fully developed.
if (ConnectContext.get().getSessionVariable().isEnableNereidsCBO()) {
deriveStats();
}
// TODO: What is the appropriate time to set physical properties? Maybe before enter.
// cascades style optimize phase.
// cost-based optimize and explode plan space
optimize();
// Get plan directly. Just for SSB.
return getRoot().extractPlan();
PhysicalPlan physicalPlan = getRoot().extractPlan();
// post-process physical plan out of memo, just for future use.
return postprocess(physicalPlan);
}
private void finalizeAnalyze() {
new FinalizeAnalyzeJob(plannerContext).execute();
private LogicalPlan preprocess(LogicalPlan logicalPlan) {
return new PlanPreprocessors(statementContext).process(logicalPlan);
}
private void initCascadesContext(LogicalPlan plan) {
cascadesContext = CascadesContext.newContext(statementContext, plan);
}
private void analyze() {
cascadesContext.newAnalyzer().analyze();
}
/**
* Logical plan rewrite based on a series of heuristic rules.
*/
private void rewrite() {
new NormalizeExpressionRulesJob(plannerContext).execute();
new JoinReorderRulesJob(plannerContext).execute();
new PredicatePushDownRulesJob(plannerContext).execute();
new DisassembleRulesJob(plannerContext).execute();
new NormalizeExpressionRulesJob(cascadesContext).execute();
new JoinReorderRulesJob(cascadesContext).execute();
new PredicatePushDownRulesJob(cascadesContext).execute();
new DisassembleRulesJob(cascadesContext).execute();
}
private void deriveStats() {
new DeriveStatsJob(getRoot().getLogicalExpression(), plannerContext.getCurrentJobContext()).execute();
new DeriveStatsJob(getRoot().getLogicalExpression(), cascadesContext.getCurrentJobContext()).execute();
}
/**
@ -141,7 +160,11 @@ public class NereidsPlanner extends Planner {
* try to find best plan under the guidance of statistic information and cost model.
*/
private void optimize() {
new OptimizeRulesJob(plannerContext).execute();
new OptimizeRulesJob(cascadesContext).execute();
}
private PhysicalPlan postprocess(PhysicalPlan physicalPlan) {
return new PlanPostprocessors(cascadesContext).process(physicalPlan);
}
@Override
@ -150,7 +173,7 @@ public class NereidsPlanner extends Planner {
}
public Group getRoot() {
return plannerContext.getMemo().getRoot();
return cascadesContext.getMemo().getRoot();
}
private PhysicalPlan chooseBestPlan(Group rootGroup, PhysicalProperties physicalProperties)

View File

@ -0,0 +1,53 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids;
import org.apache.doris.analysis.StatementBase;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.OriginStatement;
/**
* Statement context for nereids
*/
public class StatementContext {
private final ConnectContext connectContext;
private final OriginStatement originStatement;
private StatementBase parsedStatement;
public StatementContext(ConnectContext connectContext, OriginStatement originStatement) {
this.connectContext = connectContext;
this.originStatement = originStatement;
}
public ConnectContext getConnectContext() {
return connectContext;
}
public OriginStatement getOriginStatement() {
return originStatement;
}
public StatementBase getParsedStatement() {
return parsedStatement;
}
public void setParsedStatement(StatementBase parsedStatement) {
this.parsedStatement = parsedStatement;
}
}

View File

@ -17,15 +17,12 @@
package org.apache.doris.nereids.analyzer;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.batch.AnalyzeRulesJob;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.jobs.batch.FinalizeAnalyzeJob;
import org.apache.doris.nereids.rules.analysis.Scope;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.qe.ConnectContext;
import java.util.Objects;
import java.util.Optional;
/**
@ -33,64 +30,28 @@ import java.util.Optional;
* TODO: revisit the interface after subquery analysis is supported.
*/
public class NereidsAnalyzer {
private final ConnectContext connectContext;
private final CascadesContext cascadesContext;
private final Optional<Scope> outerScope;
public NereidsAnalyzer(ConnectContext connectContext) {
this.connectContext = connectContext;
public NereidsAnalyzer(CascadesContext cascadesContext) {
this(cascadesContext, Optional.empty());
}
/**
* Analyze plan.
*/
public LogicalPlan analyze(Plan plan) {
return analyze(plan, Optional.empty());
public NereidsAnalyzer(CascadesContext cascadesContext, Optional<Scope> outerScope) {
this.cascadesContext = Objects.requireNonNull(cascadesContext, "cascadesContext can not be null");
this.outerScope = Objects.requireNonNull(outerScope, "outerScope can not be null");
}
/**
* Analyze plan with scope.
*/
public LogicalPlan analyze(Plan plan, Optional<Scope> scope) {
return (LogicalPlan) analyzeWithPlannerContext(plan, scope).getMemo().copyOut();
public void analyze() {
new AnalyzeRulesJob(cascadesContext, outerScope).execute();
new FinalizeAnalyzeJob(cascadesContext).execute();
}
/**
* Convert SQL String to analyzed plan.
*/
public LogicalPlan analyze(String sql) {
return analyze(parse(sql), Optional.empty());
public CascadesContext getCascadesContext() {
return cascadesContext;
}
/**
* Analyze plan and return {@link PlannerContext}.
* Thus returned {@link PlannerContext} could be reused to do
* further plan optimization without creating new {@link Memo} and {@link PlannerContext}.
*/
public PlannerContext analyzeWithPlannerContext(Plan plan) {
return analyzeWithPlannerContext(plan, Optional.empty());
}
/**
* Analyze plan with scope.
*/
public PlannerContext analyzeWithPlannerContext(Plan plan, Optional<Scope> scope) {
PlannerContext plannerContext = new Memo(plan)
.newPlannerContext(connectContext)
.setDefaultJobContext();
new AnalyzeRulesJob(plannerContext, scope).execute();
return plannerContext;
}
/**
* Convert SQL String to analyzed plan without copying out of {@link Memo}.
* Thus returned {@link PlannerContext} could be reused to do
* further plan optimization without creating new {@link Memo} and {@link PlannerContext}.
*/
public PlannerContext analyzeWithPlannerContext(String sql) {
return analyzeWithPlannerContext(parse(sql));
}
private Plan parse(String sql) {
return new NereidsParser().parseSingle(sql);
public Optional<Scope> getOuterScope() {
return outerScope;
}
}

View File

@ -34,11 +34,13 @@ import org.apache.doris.analysis.LikePredicate;
import org.apache.doris.analysis.NullLiteral;
import org.apache.doris.analysis.StringLiteral;
import org.apache.doris.analysis.TimestampArithmeticExpr;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Arithmetic;
import org.apache.doris.nereids.trees.expressions.Between;
import org.apache.doris.nereids.trees.expressions.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
@ -189,6 +191,15 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
return new IntLiteral(integerLiteral.getValue());
}
@Override
public Expr visitBigIntLiteral(BigIntLiteral bigIntLiteral, PlanTranslatorContext context) {
try {
return new IntLiteral(bigIntLiteral.getValue(), Type.BIGINT);
} catch (Throwable t) {
throw new IllegalStateException("Can not translate BigIntLiteral: " + bigIntLiteral.getValue(), t);
}
}
@Override
public Expr visitNullLiteral(org.apache.doris.nereids.trees.expressions.NullLiteral nullLiteral,
PlanTranslatorContext context) {

View File

@ -17,25 +17,25 @@
package org.apache.doris.nereids.jobs;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.properties.PhysicalProperties;
/**
* Context for one job in Nereids' cascades framework.
*/
public class JobContext {
private final PlannerContext plannerContext;
private final CascadesContext cascadesContext;
private final PhysicalProperties requiredProperties;
private double costUpperBound;
public JobContext(PlannerContext plannerContext, PhysicalProperties requiredProperties, double costUpperBound) {
this.plannerContext = plannerContext;
public JobContext(CascadesContext cascadesContext, PhysicalProperties requiredProperties, double costUpperBound) {
this.cascadesContext = cascadesContext;
this.requiredProperties = requiredProperties;
this.costUpperBound = costUpperBound;
}
public PlannerContext getPlannerContext() {
return plannerContext;
public CascadesContext getPlannerContext() {
return cascadesContext;
}
public PhysicalProperties getRequiredProperties() {

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.analysis.BindFunction;
import org.apache.doris.nereids.rules.analysis.BindRelation;
import org.apache.doris.nereids.rules.analysis.BindSlotReference;
@ -35,11 +35,11 @@ public class AnalyzeRulesJob extends BatchRulesJob {
/**
* Execute the analysis job with scope.
* @param plannerContext planner context for execute job
* @param cascadesContext planner context for execute job
* @param scope Parse the symbolic scope of the field
*/
public AnalyzeRulesJob(PlannerContext plannerContext, Optional<Scope> scope) {
super(plannerContext);
public AnalyzeRulesJob(CascadesContext cascadesContext, Optional<Scope> scope) {
super(cascadesContext);
rulesJob.addAll(ImmutableList.of(
bottomUpBatch(ImmutableList.of(
new BindRelation(),

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.cascades.OptimizeGroupJob;
import org.apache.doris.nereids.jobs.rewrite.RewriteBottomUpJob;
@ -35,11 +35,11 @@ import java.util.Objects;
* Each batch of rules will be uniformly executed.
*/
public abstract class BatchRulesJob {
protected PlannerContext plannerContext;
protected CascadesContext cascadesContext;
protected List<Job> rulesJob = new ArrayList<>();
BatchRulesJob(PlannerContext plannerContext) {
this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext can not null");
BatchRulesJob(CascadesContext cascadesContext) {
this.cascadesContext = Objects.requireNonNull(cascadesContext, "cascadesContext can not null");
}
protected Job bottomUpBatch(List<RuleFactory> ruleFactories) {
@ -48,9 +48,9 @@ public abstract class BatchRulesJob {
rules.addAll(ruleFactory.buildRules());
}
return new RewriteBottomUpJob(
plannerContext.getMemo().getRoot(),
cascadesContext.getMemo().getRoot(),
rules,
plannerContext.getCurrentJobContext());
cascadesContext.getCurrentJobContext());
}
protected Job topDownBatch(List<RuleFactory> ruleFactories) {
@ -59,21 +59,21 @@ public abstract class BatchRulesJob {
rules.addAll(ruleFactory.buildRules());
}
return new RewriteTopDownJob(
plannerContext.getMemo().getRoot(),
cascadesContext.getMemo().getRoot(),
rules,
plannerContext.getCurrentJobContext());
cascadesContext.getCurrentJobContext());
}
protected Job optimize() {
return new OptimizeGroupJob(
plannerContext.getMemo().getRoot(),
plannerContext.getCurrentJobContext());
cascadesContext.getMemo().getRoot(),
cascadesContext.getCurrentJobContext());
}
public void execute() {
for (Job job : rulesJob) {
plannerContext.pushJob(job);
plannerContext.getJobScheduler().executeJobPool(plannerContext);
cascadesContext.pushJob(job);
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
}
}
}

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
import com.google.common.collect.ImmutableList;
@ -26,8 +26,8 @@ import com.google.common.collect.ImmutableList;
* Execute the disassemble rules.
*/
public class DisassembleRulesJob extends BatchRulesJob {
public DisassembleRulesJob(PlannerContext plannerContext) {
super(plannerContext);
public DisassembleRulesJob(CascadesContext cascadesContext) {
super(cascadesContext);
rulesJob.addAll(ImmutableList.of(
topDownBatch(ImmutableList.of(
new AggregateDisassemble())

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.analysis.EliminateAliasNode;
import com.google.common.collect.ImmutableList;
@ -29,10 +29,10 @@ public class FinalizeAnalyzeJob extends BatchRulesJob {
/**
* constructor
* @param plannerContext ctx
* @param cascadesContext ctx
*/
public FinalizeAnalyzeJob(PlannerContext plannerContext) {
super(plannerContext);
public FinalizeAnalyzeJob(CascadesContext cascadesContext) {
super(cascadesContext);
rulesJob.addAll(ImmutableList.of(
bottomUpBatch(ImmutableList.of(new EliminateAliasNode()))
));

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
import com.google.common.collect.ImmutableList;
@ -27,8 +27,8 @@ import com.google.common.collect.ImmutableList;
*/
public class JoinReorderRulesJob extends BatchRulesJob {
public JoinReorderRulesJob(PlannerContext plannerContext) {
super(plannerContext);
public JoinReorderRulesJob(CascadesContext cascadesContext) {
super(cascadesContext);
rulesJob.addAll(ImmutableList.of(
topDownBatch(ImmutableList.of(new ReorderJoin()))
));

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization;
import com.google.common.collect.ImmutableList;
@ -29,10 +29,10 @@ public class NormalizeExpressionRulesJob extends BatchRulesJob {
/**
* Constructor.
* @param plannerContext context for applying rules.
* @param cascadesContext context for applying rules.
*/
public NormalizeExpressionRulesJob(PlannerContext plannerContext) {
super(plannerContext);
public NormalizeExpressionRulesJob(CascadesContext cascadesContext) {
super(cascadesContext);
rulesJob.addAll(ImmutableList.of(
topDownBatch(ImmutableList.of(
new ExpressionNormalization()

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import com.google.common.collect.ImmutableList;
@ -25,8 +25,8 @@ import com.google.common.collect.ImmutableList;
* cascade optimizer added.
*/
public class OptimizeRulesJob extends BatchRulesJob {
public OptimizeRulesJob(PlannerContext plannerContext) {
super(plannerContext);
public OptimizeRulesJob(CascadesContext cascadesContext) {
super(cascadesContext);
rulesJob.addAll(ImmutableList.of(
optimize()
));

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.rewrite.logical.PushPredicateThroughJoin;
import com.google.common.collect.ImmutableList;
@ -26,8 +26,8 @@ import com.google.common.collect.ImmutableList;
* execute predicate push down job.
*/
public class PredicatePushDownRulesJob extends BatchRulesJob {
public PredicatePushDownRulesJob(PlannerContext plannerContext) {
super(plannerContext);
public PredicatePushDownRulesJob(CascadesContext cascadesContext) {
super(cascadesContext);
rulesJob.addAll(ImmutableList.of(
topDownBatch(ImmutableList.of(
new PushPredicateThroughJoin())

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.jobs.scheduler;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.jobs.Job;
@ -25,7 +25,7 @@ import org.apache.doris.nereids.jobs.Job;
* Scheduler to schedule jobs in Nereids.
*/
public interface JobScheduler {
void executeJob(Job job, PlannerContext context);
void executeJob(Job job, CascadesContext context);
void executeJobPool(PlannerContext plannerContext) throws AnalysisException;
void executeJobPool(CascadesContext cascadesContext) throws AnalysisException;
}

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.jobs.scheduler;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.jobs.Job;
@ -26,13 +26,13 @@ import org.apache.doris.nereids.jobs.Job;
*/
public class SimpleJobScheduler implements JobScheduler {
@Override
public void executeJob(Job job, PlannerContext context) {
public void executeJob(Job job, CascadesContext context) {
}
@Override
public void executeJobPool(PlannerContext plannerContext) throws AnalysisException {
JobPool pool = plannerContext.getJobPool();
public void executeJobPool(CascadesContext cascadesContext) throws AnalysisException {
JobPool pool = cascadesContext.getJobPool();
while (!pool.isEmpty()) {
Job job = pool.pop();
job.execute();

View File

@ -19,12 +19,12 @@ package org.apache.doris.nereids.memo;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
@ -100,21 +100,19 @@ public class Memo {
}
public Plan copyOut() {
return groupToTreeNode(root);
return copyOut(root);
}
/**
* Utility function to create a new {@link PlannerContext} with this Memo.
* copyOut the group.
* @param group the group what want to copyOut
* @return plan
*/
public PlannerContext newPlannerContext(ConnectContext connectContext) {
return new PlannerContext(this, connectContext);
}
private Plan groupToTreeNode(Group group) {
public Plan copyOut(Group group) {
GroupExpression logicalExpression = group.getLogicalExpression();
List<Plan> childrenNode = Lists.newArrayList();
for (Group child : logicalExpression.children()) {
childrenNode.add(groupToTreeNode(child));
childrenNode.add(copyOut(child));
}
Plan result = logicalExpression.getPlan();
if (result.children().size() == 0) {
@ -123,6 +121,13 @@ public class Memo {
return result.withChildren(childrenNode);
}
/**
* Utility function to create a new {@link CascadesContext} with this Memo.
*/
public CascadesContext newCascadesContext(StatementContext statementContext) {
return new CascadesContext(this, statementContext);
}
/**
* Insert groupExpression to target group.
* If group expression is already in memo and target group is not null, we merge two groups.

View File

@ -32,6 +32,8 @@ import org.apache.doris.nereids.DorisParser.DereferenceContext;
import org.apache.doris.nereids.DorisParser.ExistContext;
import org.apache.doris.nereids.DorisParser.ExplainContext;
import org.apache.doris.nereids.DorisParser.FromClauseContext;
import org.apache.doris.nereids.DorisParser.HintAssignmentContext;
import org.apache.doris.nereids.DorisParser.HintStatementContext;
import org.apache.doris.nereids.DorisParser.IdentifierListContext;
import org.apache.doris.nereids.DorisParser.IdentifierSeqContext;
import org.apache.doris.nereids.DorisParser.IntegerLiteralContext;
@ -55,6 +57,7 @@ import org.apache.doris.nereids.DorisParser.QueryOrganizationContext;
import org.apache.doris.nereids.DorisParser.RegularQuerySpecificationContext;
import org.apache.doris.nereids.DorisParser.RelationContext;
import org.apache.doris.nereids.DorisParser.SelectClauseContext;
import org.apache.doris.nereids.DorisParser.SelectHintContext;
import org.apache.doris.nereids.DorisParser.SingleStatementContext;
import org.apache.doris.nereids.DorisParser.SortClauseContext;
import org.apache.doris.nereids.DorisParser.SortItemContext;
@ -75,10 +78,12 @@ import org.apache.doris.nereids.analyzer.UnboundStar;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.ParseException;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.properties.SelectHint;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Between;
import org.apache.doris.nereids.trees.expressions.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
@ -122,11 +127,13 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSelectHint;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.RuleContext;
import org.antlr.v4.runtime.Token;
@ -134,11 +141,13 @@ import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.tree.RuleNode;
import org.antlr.v4.runtime.tree.TerminalNode;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
@ -570,9 +579,13 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
@Override
public Literal visitIntegerLiteral(IntegerLiteralContext ctx) {
// TODO: throw NumberFormatException
Integer l = Integer.valueOf(ctx.getText());
return new IntegerLiteral(l);
BigInteger bigInt = new BigInteger(ctx.getText());
if (BigInteger.valueOf(bigInt.intValue()).equals(bigInt)) {
return new IntegerLiteral(bigInt.intValue());
} else {
// throw exception if out of long range
return new BigIntLiteral(bigInt.longValueExact());
}
}
@Override
@ -740,7 +753,6 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
Optional<WhereClauseContext> whereClause,
Optional<AggClauseContext> aggClause) {
return ParserUtils.withOrigin(ctx, () -> {
// TODO: process hint
// TODO: add lateral views
// from -> where -> group by -> having -> select
@ -750,7 +762,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
// TODO: replace and process having at this position
LogicalPlan having = aggregate; // LogicalPlan having = withFilter(aggregate, havingClause);
LogicalPlan projection = withProjection(having, selectClause, aggClause);
return projection;
return withSelectHint(projection, selectClause.selectHint());
});
}
@ -807,6 +819,31 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
return last;
}
private LogicalPlan withSelectHint(LogicalPlan logicalPlan, SelectHintContext hintContext) {
if (hintContext == null) {
return logicalPlan;
}
Map<String, SelectHint> hints = Maps.newLinkedHashMap();
for (HintStatementContext hintStatement : hintContext.hintStatements) {
String hintName = hintStatement.hintName.getText().toLowerCase(Locale.ROOT);
Map<String, Optional<String>> parameters = Maps.newLinkedHashMap();
for (HintAssignmentContext kv : hintStatement.parameters) {
String parameterName = kv.key.getText();
Optional<String> value = Optional.empty();
if (kv.constantValue != null) {
Literal literal = (Literal) visit(kv.constantValue);
value = Optional.ofNullable(literal.toLegacyLiteral().getStringValue());
} else if (kv.identifierValue != null) {
// maybe we should throw exception when the identifierValue is quoted identifier
value = Optional.ofNullable(kv.identifierValue.getText());
}
parameters.put(parameterName, value);
}
hints.put(hintName, new SelectHint(hintName, parameters));
}
return new LogicalSelectHint<>(hints, logicalPlan);
}
private LogicalPlan withProjection(LogicalPlan input, SelectClauseContext selectCtx,
Optional<AggClauseContext> aggCtx) {
return ParserUtils.withOrigin(selectCtx, () -> {

View File

@ -17,8 +17,10 @@
package org.apache.doris.nereids.pattern;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.qe.ConnectContext;
/**
* Define a context when match a pattern pass through a MatchedAction.
@ -26,18 +28,22 @@ import org.apache.doris.nereids.trees.plans.Plan;
public class MatchingContext<TYPE extends Plan> {
public final TYPE root;
public final Pattern<TYPE> pattern;
public final PlannerContext plannerContext;
public final CascadesContext cascadesContext;
public final StatementContext statementContext;
public final ConnectContext connectContext;
/**
* the MatchingContext is the param pass through the MatchedAction.
*
* @param root the matched tree node root
* @param pattern the defined pattern
* @param plannerContext the planner context
* @param cascadesContext the planner context
*/
public MatchingContext(TYPE root, Pattern<TYPE> pattern, PlannerContext plannerContext) {
public MatchingContext(TYPE root, Pattern<TYPE> pattern, CascadesContext cascadesContext) {
this.root = root;
this.pattern = pattern;
this.plannerContext = plannerContext;
this.cascadesContext = cascadesContext;
this.statementContext = cascadesContext.getStatementContext();
this.connectContext = cascadesContext.getConnectContext();
}
}

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.pattern;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RulePromise;
import org.apache.doris.nereids.rules.RuleType;
@ -67,7 +67,7 @@ public class PatternMatcher<INPUT_TYPE extends Plan, OUTPUT_TYPE extends Plan> {
public Rule toRule(RuleType ruleType, RulePromise rulePromise) {
return new Rule(ruleType, pattern, rulePromise) {
@Override
public List<Plan> transform(Plan originPlan, PlannerContext context) {
public List<Plan> transform(Plan originPlan, CascadesContext context) {
MatchingContext<INPUT_TYPE> matchingContext =
new MatchingContext<>((INPUT_TYPE) originPlan, pattern, context);
OUTPUT_TYPE replacePlan = matchedAction.apply(matchingContext);

View File

@ -0,0 +1,27 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.processor.post;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
/**
* PlanPostprocessor: a PlanVisitor to rewrite PhysicalPlan to new PhysicalPlan.
*/
public class PlanPostprocessor extends DefaultPlanRewriter<CascadesContext> {
}

View File

@ -0,0 +1,50 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.processor.post;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
/**
* PlanPostprocessors: after copy out the plan from the memo, we use this rewriter to rewrite plan by visitor.
*/
public class PlanPostprocessors {
private final CascadesContext cascadesContext;
public PlanPostprocessors(CascadesContext cascadesContext) {
this.cascadesContext = Objects.requireNonNull(cascadesContext, "cascadesContext can not be null");
}
public PhysicalPlan process(PhysicalPlan physicalPlan) {
PhysicalPlan resultPlan = physicalPlan;
for (PlanPostprocessor processor : getProcessors()) {
resultPlan = (PhysicalPlan) physicalPlan.accept(processor, cascadesContext);
}
return resultPlan;
}
public List<PlanPostprocessor> getProcessors() {
// add processor if we need
return ImmutableList.of();
}
}

View File

@ -0,0 +1,74 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.processor.pre;
import org.apache.doris.analysis.SetVar;
import org.apache.doris.analysis.StringLiteral;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.properties.SelectHint;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalSelectHint;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.qe.VariableMgr;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Map.Entry;
import java.util.Optional;
/**
* eliminate set var hint, and set var to session variables.
*/
public class EliminateLogicalSelectHint extends PlanPreprocessor {
private Logger logger = LoggerFactory.getLogger(getClass());
@Override
public LogicalPlan visitLogicalSelectHint(LogicalSelectHint<Plan> selectHintPlan, StatementContext context) {
for (Entry<String, SelectHint> hint : selectHintPlan.getHints().entrySet()) {
String hintName = hint.getKey();
if (hintName.equalsIgnoreCase("SET_VAR")) {
setVar(hint.getValue(), context);
} else {
logger.warn("Can not process select hint '{}' and skip it", hint.getKey());
}
}
return (LogicalPlan) selectHintPlan.child();
}
private void setVar(SelectHint selectHint, StatementContext context) {
SessionVariable sessionVariable = context.getConnectContext().getSessionVariable();
// set temporary session value, and then revert value in the 'finally block' of StmtExecutor#execute
sessionVariable.setIsSingleSetVar(true);
for (Entry<String, Optional<String>> kv : selectHint.getParameters().entrySet()) {
String key = kv.getKey();
Optional<String> value = kv.getValue();
if (value.isPresent()) {
try {
VariableMgr.setVar(sessionVariable, new SetVar(key, new StringLiteral(value.get())));
} catch (Throwable t) {
throw new AnalysisException("Can not set session variable '" + key + "' = '"
+ value.get() + "'", t);
}
}
}
}
}

View File

@ -0,0 +1,29 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.processor.pre;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
/**
* PlanPreprocessor: a PlanVisitor to rewrite LogicalPlan to new LogicalPlan.
*/
public abstract class PlanPreprocessor extends DefaultPlanRewriter<StatementContext> {
}

View File

@ -0,0 +1,52 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.processor.pre;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
/**
* PlanPreprocessors: before copy the plan into the memo, we use this rewriter to rewrite plan by visitor.
*/
public class PlanPreprocessors {
private final StatementContext statementContext;
public PlanPreprocessors(StatementContext statementContext) {
this.statementContext = Objects.requireNonNull(statementContext, "statementContext can not be null");
}
public LogicalPlan process(LogicalPlan logicalPlan) {
LogicalPlan resultPlan = logicalPlan;
for (PlanPreprocessor processor : getProcessors()) {
resultPlan = (LogicalPlan) logicalPlan.accept(processor, statementContext);
}
return resultPlan;
}
public List<PlanPreprocessor> getProcessors() {
// add processor if we need
return ImmutableList.of(
new EliminateLogicalSelectHint()
);
}
}

View File

@ -0,0 +1,61 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.properties;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
/**
* select hint.
* e.g. set_var(query_timeout='1800', exec_mem_limit='2147483648')
*/
public class SelectHint {
// e.g. set_var
private final String hintName;
// e.g. query_timeout='1800', exec_mem_limit='2147483648'
private final Map<String, Optional<String>> parameters;
public SelectHint(String hintName, Map<String, Optional<String>> parameters) {
this.hintName = Objects.requireNonNull(hintName, "hintName can not be null");
this.parameters = Objects.requireNonNull(parameters, "parameters can not be null");
}
public String getHintName() {
return hintName;
}
public Map<String, Optional<String>> getParameters() {
return parameters;
}
@Override
public String toString() {
String kvString = parameters
.entrySet()
.stream()
.map(kv ->
kv.getValue().isPresent()
? kv.getKey() + "='" + kv.getValue().get() + "'"
: kv.getKey()
)
.collect(Collectors.joining(", "));
return hintName + "(" + kvString + ")";
}
}

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.rules;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.TransformException;
import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.rules.RuleType.RuleTypeClass;
@ -67,5 +67,5 @@ public abstract class Rule {
return getRuleType().toString();
}
public abstract List<Plan> transform(Plan node, PlannerContext context) throws TransformException;
public abstract List<Plan> transform(Plan node, CascadesContext context) throws TransformException;
}

View File

@ -31,6 +31,7 @@ public enum RuleType {
BINDING_JOIN_SLOT(RuleTypeClass.REWRITE),
BINDING_AGGREGATE_SLOT(RuleTypeClass.REWRITE),
BINDING_SORT_SLOT(RuleTypeClass.REWRITE),
BINDING_LIMIT_SLOT(RuleTypeClass.REWRITE),
BINDING_PROJECT_FUNCTION(RuleTypeClass.REWRITE),
BINDING_AGGREGATE_FUNCTION(RuleTypeClass.REWRITE),
BINDING_SUBQUERY_ALIAS_SLOT(RuleTypeClass.REWRITE),

View File

@ -21,9 +21,13 @@ import org.apache.doris.catalog.Database;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.TableIf.TableType;
import org.apache.doris.nereids.analyzer.NereidsAnalyzer;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
@ -40,16 +44,15 @@ public class BindRelation extends OneAnalysisRuleFactory {
@Override
public Rule build() {
return unboundRelation().thenApply(ctx -> {
ConnectContext connectContext = ctx.plannerContext.getConnectContext();
List<String> nameParts = ctx.root.getNameParts();
switch (nameParts.size()) {
case 1: {
case 1: { // table
// Use current database name from catalog.
return bindWithCurrentDb(connectContext, nameParts);
return bindWithCurrentDb(ctx.cascadesContext, nameParts);
}
case 2: {
case 2: { // db.table
// Use database name from table name parts.
return bindWithDbNameFromNamePart(connectContext, nameParts);
return bindWithDbNameFromNamePart(ctx.cascadesContext, nameParts);
}
default:
throw new IllegalStateException("Table name [" + ctx.root.getTableName() + "] is invalid.");
@ -69,32 +72,41 @@ public class BindRelation extends OneAnalysisRuleFactory {
}
}
private LogicalPlan bindWithCurrentDb(ConnectContext ctx, List<String> nameParts) {
String dbName = ctx.getDatabase();
Table table = getTable(dbName, nameParts.get(0), ctx.getEnv());
private LogicalPlan bindWithCurrentDb(CascadesContext cascadesContext, List<String> nameParts) {
String dbName = cascadesContext.getConnectContext().getDatabase();
Table table = getTable(dbName, nameParts.get(0), cascadesContext.getConnectContext().getEnv());
// TODO: should generate different Scan sub class according to table's type
if (table.getType() == TableType.OLAP) {
return new LogicalOlapScan(table, ImmutableList.of(dbName));
} else if (table.getType() == TableType.VIEW) {
LogicalPlan viewPlan = new NereidsAnalyzer(ctx).analyze(table.getDdlSql());
Plan viewPlan = parseAndAnalyzeView(table.getDdlSql(), cascadesContext);
return new LogicalSubQueryAlias<>(table.getName(), viewPlan);
}
throw new RuntimeException("Unsupported tableType:" + table.getType());
throw new AnalysisException("Unsupported tableType:" + table.getType());
}
private LogicalPlan bindWithDbNameFromNamePart(ConnectContext ctx, List<String> nameParts) {
private LogicalPlan bindWithDbNameFromNamePart(CascadesContext cascadesContext, List<String> nameParts) {
ConnectContext connectContext = cascadesContext.getConnectContext();
// if the relation is view, nameParts.get(0) is dbName.
String dbName = nameParts.get(0);
if (!dbName.equals(ctx.getDatabase())) {
dbName = ctx.getClusterName() + ":" + nameParts.get(0);
if (!dbName.equals(connectContext.getDatabase())) {
dbName = connectContext.getClusterName() + ":" + dbName;
}
Table table = getTable(dbName, nameParts.get(1), ctx.getEnv());
Table table = getTable(dbName, nameParts.get(1), connectContext.getEnv());
if (table.getType() == TableType.OLAP) {
return new LogicalOlapScan(table, ImmutableList.of(dbName));
} else if (table.getType() == TableType.VIEW) {
LogicalPlan viewPlan = new NereidsAnalyzer(ctx).analyze(table.getDdlSql());
Plan viewPlan = parseAndAnalyzeView(table.getDdlSql(), cascadesContext);
return new LogicalSubQueryAlias<>(table.getName(), viewPlan);
}
throw new RuntimeException("Unsupported tableType:" + table.getType());
throw new AnalysisException("Unsupported tableType:" + table.getType());
}
private Plan parseAndAnalyzeView(String viewSql, CascadesContext parentContext) {
LogicalPlan parsedViewPlan = new NereidsParser().parseSingle(viewSql);
CascadesContext viewContext = new Memo(parsedViewPlan)
.newCascadesContext(parentContext.getStatementContext());
viewContext.newAnalyzer().analyze();
return viewContext.getMemo().copyOut();
}
}

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.UnboundAlias;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.analyzer.UnboundStar;
@ -29,18 +30,20 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultSubExprRewriter;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import org.apache.commons.lang.StringUtils;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -51,8 +54,12 @@ import java.util.stream.Stream;
public class BindSlotReference implements AnalysisRuleFactory {
private final Optional<Scope> outerScope;
public BindSlotReference() {
this(Optional.empty());
}
public BindSlotReference(Optional<Scope> outputScope) {
this.outerScope = outputScope;
this.outerScope = Objects.requireNonNull(outputScope, "outerScope can not be null");
}
private Scope toScope(List<Slot> slots) {
@ -67,46 +74,62 @@ public class BindSlotReference implements AnalysisRuleFactory {
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.BINDING_PROJECT_SLOT.build(
logicalProject().then(project -> {
logicalProject().thenApply(ctx -> {
LogicalProject<GroupPlan> project = ctx.root;
List<NamedExpression> boundSlots =
bind(project.getProjects(), project.children(), project);
bind(project.getProjects(), project.children(), project, ctx.cascadesContext);
return new LogicalProject<>(flatBoundStar(boundSlots), project.child());
})
),
RuleType.BINDING_FILTER_SLOT.build(
logicalFilter().then(filter -> {
Expression boundPredicates = bind(filter.getPredicates(), filter.children(), filter);
logicalFilter().thenApply(ctx -> {
LogicalFilter<GroupPlan> filter = ctx.root;
Expression boundPredicates = bind(filter.getPredicates(), filter.children(),
filter, ctx.cascadesContext);
return new LogicalFilter<>(boundPredicates, filter.child());
})
),
RuleType.BINDING_JOIN_SLOT.build(
logicalJoin().then(join -> {
logicalJoin().thenApply(ctx -> {
LogicalJoin<GroupPlan, GroupPlan> join = ctx.root;
Optional<Expression> cond = join.getCondition()
.map(expr -> bind(expr, join.children(), join));
.map(expr -> bind(expr, join.children(), join, ctx.cascadesContext));
return new LogicalJoin<>(join.getJoinType(), cond, join.left(), join.right());
})
),
RuleType.BINDING_AGGREGATE_SLOT.build(
logicalAggregate().then(agg -> {
List<Expression> groupBy = bind(agg.getGroupByExpressions(), agg.children(), agg);
List<NamedExpression> output = bind(agg.getOutputExpressions(), agg.children(), agg);
logicalAggregate().thenApply(ctx -> {
LogicalAggregate<GroupPlan> agg = ctx.root;
List<Expression> groupBy =
bind(agg.getGroupByExpressions(), agg.children(), agg, ctx.cascadesContext);
List<NamedExpression> output =
bind(agg.getOutputExpressions(), agg.children(), agg, ctx.cascadesContext);
return agg.withGroupByAndOutput(groupBy, output);
})
),
RuleType.BINDING_SORT_SLOT.build(
logicalSort().then(sort -> {
logicalSort().thenApply(ctx -> {
LogicalSort<GroupPlan> sort = ctx.root;
List<OrderKey> sortItemList = sort.getOrderKeys()
.stream()
.map(orderKey -> {
Expression item = bind(orderKey.getExpr(), sort.children(), sort);
Expression item = bind(orderKey.getExpr(), sort.children(), sort, ctx.cascadesContext);
return new OrderKey(item, orderKey.isAsc(), orderKey.isNullFirst());
}).collect(Collectors.toList());
return new LogicalSort<>(sortItemList, sort.child());
})
),
// this rewrite is necessary because we should replace the logicalProperties which refer the child
// unboundLogicalProperties to a new LogicalProperties. This restriction is because we move the
// analysis stage after build the memo, and cause parent's plan can not update logical properties
// when the children are changed. we should discuss later and refactor it.
RuleType.BINDING_SUBQUERY_ALIAS_SLOT.build(
logicalSubQueryAlias().then(alias -> new LogicalSubQueryAlias<>(alias.getAlias(), alias.child()))
logicalSubQueryAlias().then(alias -> alias.withChildren(ImmutableList.of(alias.child())))
),
RuleType.BINDING_LIMIT_SLOT.build(
logicalLimit().then(limit -> limit.withChildren(ImmutableList.of(limit.child())))
)
);
}
@ -123,24 +146,25 @@ public class BindSlotReference implements AnalysisRuleFactory {
}).collect(Collectors.toList());
}
private <E extends Expression> List<E> bind(List<E> exprList, List<Plan> inputs, Plan plan) {
private <E extends Expression> List<E> bind(List<E> exprList, List<Plan> inputs, Plan plan,
CascadesContext cascadesContext) {
return exprList.stream()
.map(expr -> bind(expr, inputs, plan))
.map(expr -> bind(expr, inputs, plan, cascadesContext))
.collect(Collectors.toList());
}
private <E extends Expression> E bind(E expr, List<Plan> inputs, Plan plan) {
private <E extends Expression> E bind(E expr, List<Plan> inputs, Plan plan, CascadesContext cascadesContext) {
List<Slot> boundedSlots = inputs.stream()
.flatMap(input -> input.getOutput().stream())
.collect(Collectors.toList());
return (E) new SlotBinder(toScope(boundedSlots), plan).bind(expr);
return (E) new SlotBinder(toScope(boundedSlots), plan, cascadesContext).bind(expr);
}
private class SlotBinder extends DefaultSubExprRewriter<Void> {
private final Plan plan;
public SlotBinder(Scope scope, Plan plan) {
super(scope);
public SlotBinder(Scope scope, Plan plan, CascadesContext cascadesContext) {
super(scope, cascadesContext);
this.plan = plan;
}

View File

@ -45,7 +45,7 @@ public class ReorderJoin extends OneRewriteRuleFactory {
public Rule build() {
return logicalFilter(subTree(LogicalJoin.class, LogicalFilter.class)).thenApply(ctx -> {
LogicalFilter<Plan> filter = ctx.root;
if (!ctx.plannerContext.getConnectContext().getSessionVariable()
if (!ctx.cascadesContext.getConnectContext().getSessionVariable()
.isEnableNereidsReorderToEliminateCrossJoin()) {
return filter;
}

View File

@ -0,0 +1,58 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.IntLiteral;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
/**
* Represents Bigint literal
*/
public class BigIntLiteral extends Literal {
private final long value;
public BigIntLiteral(long value) {
super(BigIntType.INSTANCE);
this.value = value;
}
@Override
public Long getValue() {
return value;
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitBigIntLiteral(this, context);
}
@Override
public LiteralExpr toLegacyLiteral() {
try {
return new IntLiteral(value, Type.BIGINT);
} catch (AnalysisException e) {
throw new org.apache.doris.nereids.exceptions.AnalysisException(
"Can not convert to legacy literal: " + value, e);
}
}
}

View File

@ -17,6 +17,8 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.BoolLiteral;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BooleanType;
@ -50,4 +52,9 @@ public class BooleanLiteral extends Literal {
public String toString() {
return Boolean.valueOf(value).toString().toUpperCase();
}
@Override
public LiteralExpr toLegacyLiteral() {
return new BoolLiteral(value);
}
}

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
@ -145,6 +146,11 @@ public class DateLiteral extends Literal {
return String.format("%04d-%02d-%02d", year, month, day);
}
@Override
public LiteralExpr toLegacyLiteral() {
return new org.apache.doris.analysis.DateLiteral(year, month, day);
}
public long getYear() {
return year;
}

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
@ -137,6 +138,11 @@ public class DateTimeLiteral extends DateLiteral {
return String.format("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second);
}
@Override
public LiteralExpr toLegacyLiteral() {
return new org.apache.doris.analysis.DateLiteral(year, month, day, hour, minute, second);
}
public long getHour() {
return hour;
}

View File

@ -17,6 +17,9 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.FloatLiteral;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DoubleType;
@ -41,4 +44,9 @@ public class DoubleLiteral extends Literal {
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitDoubleLiteral(this, context);
}
@Override
public LiteralExpr toLegacyLiteral() {
return new FloatLiteral(value, Type.DOUBLE);
}
}

View File

@ -17,6 +17,10 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.IntLiteral;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.IntegerType;
@ -41,4 +45,14 @@ public class IntegerLiteral extends Literal {
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitIntegerLiteral(this, context);
}
@Override
public LiteralExpr toLegacyLiteral() {
try {
return new IntLiteral(value, Type.INT);
} catch (AnalysisException e) {
throw new org.apache.doris.nereids.exceptions.AnalysisException(
"Can not convert to legacy literal: " + value, e);
}
}
}

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
@ -100,4 +101,6 @@ public abstract class Literal extends Expression implements LeafExpression {
public String toString() {
return String.valueOf(getValue());
}
public abstract LiteralExpr toLegacyLiteral();
}

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.NullType;
@ -43,4 +44,9 @@ public class NullLiteral extends Literal {
public String toString() {
return "NULL";
}
@Override
public LiteralExpr toLegacyLiteral() {
return new org.apache.doris.analysis.NullLiteral();
}
}

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
@ -49,6 +50,11 @@ public class StringLiteral extends Literal {
return visitor.visitStringLiteral(this, context);
}
@Override
public LiteralExpr toLegacyLiteral() {
return new org.apache.doris.analysis.StringLiteral(value);
}
@Override
protected Expression uncheckedCastTo(DataType targetType) throws AnalysisException {
if (getDataType().equals(targetType)) {

View File

@ -17,8 +17,9 @@
package org.apache.doris.nereids.trees.expressions.visitor;
import org.apache.doris.nereids.analyzer.NereidsAnalyzer;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.rules.analysis.Scope;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
@ -26,7 +27,6 @@ import org.apache.doris.nereids.trees.expressions.ListQuery;
import org.apache.doris.nereids.trees.expressions.ScalarSubquery;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.qe.ConnectContext;
import java.util.Optional;
@ -35,9 +35,11 @@ import java.util.Optional;
*/
public class DefaultSubExprRewriter<C> extends DefaultExpressionRewriter<C> {
private final Scope scope;
private final CascadesContext cascadesContext;
public DefaultSubExprRewriter(Scope scope) {
public DefaultSubExprRewriter(Scope scope, CascadesContext cascadesContext) {
this.scope = scope;
this.cascadesContext = cascadesContext;
}
@Override
@ -61,10 +63,10 @@ public class DefaultSubExprRewriter<C> extends DefaultExpressionRewriter<C> {
}
private LogicalPlan analyzeSubquery(SubqueryExpr expr) {
NereidsAnalyzer subAnalyzer = new NereidsAnalyzer(ConnectContext.get());
LogicalPlan analyzed = subAnalyzer.analyze(
expr.getQueryPlan(), Optional.ofNullable(scope));
return analyzed;
CascadesContext subqueryContext = new Memo(expr.getQueryPlan())
.newCascadesContext(cascadesContext.getStatementContext());
subqueryContext.newAnalyzer(Optional.ofNullable(getScope())).analyze();
return (LogicalPlan) subqueryContext.getMemo().copyOut();
}
public Scope getScope() {

View File

@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Arithmetic;
import org.apache.doris.nereids.trees.expressions.Between;
import org.apache.doris.nereids.trees.expressions.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
@ -128,31 +129,35 @@ public abstract class ExpressionVisitor<R, C> {
}
public R visitBooleanLiteral(BooleanLiteral booleanLiteral, C context) {
return visit(booleanLiteral, context);
return visitLiteral(booleanLiteral, context);
}
public R visitStringLiteral(StringLiteral stringLiteral, C context) {
return visit(stringLiteral, context);
return visitLiteral(stringLiteral, context);
}
public R visitIntegerLiteral(IntegerLiteral integerLiteral, C context) {
return visit(integerLiteral, context);
return visitLiteral(integerLiteral, context);
}
public R visitBigIntLiteral(BigIntLiteral bigIntLiteral, C context) {
return visitLiteral(bigIntLiteral, context);
}
public R visitNullLiteral(NullLiteral nullLiteral, C context) {
return visit(nullLiteral, context);
return visitLiteral(nullLiteral, context);
}
public R visitDoubleLiteral(DoubleLiteral doubleLiteral, C context) {
return visit(doubleLiteral, context);
return visitLiteral(doubleLiteral, context);
}
public R visitDateLiteral(DateLiteral dateLiteral, C context) {
return visit(dateLiteral, context);
return visitLiteral(dateLiteral, context);
}
public R visitDateTimeLiteral(DateTimeLiteral dateTimeLiteral, C context) {
return visit(dateTimeLiteral, context);
return visitLiteral(dateTimeLiteral, context);
}
public R visitBetween(Between between, C context) {

View File

@ -37,9 +37,7 @@ public interface Plan extends TreeNode<Plan> {
// cache GroupExpression for fast exit from Memo.copyIn.
Optional<GroupExpression> getGroupExpression();
default <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
throw new RuntimeException("accept() is not implemented by plan " + this.getClass().getSimpleName());
}
<R, C> R accept(PlanVisitor<R, C> visitor, C context);
List<Expression> getExpressions();

View File

@ -37,6 +37,7 @@ public enum PlanType {
LOGICAL_APPLY,
LOGICAL_CORRELATED_JOIN,
LOGICAL_ENFORCE_SINGLE_ROW,
LOGICAL_SELECT_HINT,
GROUP_PLAN,
// physical plan

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.plans.commands;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
/**
* explain command.
@ -42,6 +43,11 @@ public class ExplainCommand implements Command {
this.logicalPlan = logicalPlan;
}
@Override
public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
return visitor.visitExplainCommand(this, context);
}
public ExplainLevel getLevel() {
return level;
}

View File

@ -0,0 +1,112 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans.logical;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.SelectHint;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
/**
* select hint plan.
* e.g. LogicalSelectHint (set_var(query_timeout='1800', exec_mem_limit='2147483648'))
*/
public class LogicalSelectHint<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TYPE> {
private final Map<String, SelectHint> hints;
public LogicalSelectHint(Map<String, SelectHint> hints, CHILD_TYPE child) {
this(hints, Optional.empty(), Optional.empty(), child);
}
public LogicalSelectHint(Map<String, SelectHint> hints,
Optional<LogicalProperties> logicalProperties, CHILD_TYPE child) {
this(hints, Optional.empty(), logicalProperties, child);
}
/**
* LogicalSelectHint's full parameter constructor.
* @param hints hint maps, key is hint name, e.g. 'SET_VAR', and value is parameter pairs, e.g. query_time=100
* @param groupExpression groupExpression exists when this plan is copy out from memo.
* @param logicalProperties logicalProperties is use for compute output
* @param child child plan
*/
public LogicalSelectHint(Map<String, SelectHint> hints,
Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, CHILD_TYPE child) {
super(PlanType.LOGICAL_SELECT_HINT, groupExpression, logicalProperties, child);
this.hints = ImmutableMap.copyOf(Objects.requireNonNull(hints, "hints can not be null"));
}
public Map<String, SelectHint> getHints() {
return hints;
}
@Override
public LogicalSelectHint<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
return new LogicalSelectHint<>(hints, children.get(0));
}
@Override
public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
return visitor.visitLogicalSelectHint((LogicalSelectHint<Plan>) this, context);
}
@Override
public List<Expression> getExpressions() {
return ImmutableList.of();
}
@Override
public LogicalSelectHint<CHILD_TYPE> withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalSelectHint<>(hints, groupExpression, Optional.of(logicalProperties), child());
}
@Override
public LogicalSelectHint<CHILD_TYPE> withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
return new LogicalSelectHint<>(hints, Optional.empty(), logicalProperties, child());
}
@Override
public List<Slot> computeOutput(Plan input) {
return child().getOutput();
}
@Override
public String toString() {
String hintStr = this.hints.entrySet()
.stream()
.map(entry -> entry.getValue().toString())
.collect(Collectors.joining(", "));
return "LogicalSelectHint (" + hintStr + ")";
}
}

View File

@ -20,6 +20,8 @@ package org.apache.doris.nereids.trees.plans.visitor;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.commands.Command;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalCorrelatedJoin;
@ -30,6 +32,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalSelectHint;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
@ -53,6 +56,18 @@ public abstract class PlanVisitor<R, C> {
public abstract R visit(Plan plan, C context);
// *******************************
// commands
// *******************************
public R visitCommand(Command command, C context) {
return visit(command, context);
}
public R visitExplainCommand(ExplainCommand explain, C context) {
return visitCommand(explain, context);
}
// *******************************
// Logical plans
// *******************************
@ -69,6 +84,10 @@ public abstract class PlanVisitor<R, C> {
return visit(relation, context);
}
public R visitLogicalSelectHint(LogicalSelectHint<Plan> hint, C context) {
return visit(hint, context);
}
public R visitLogicalAggregate(LogicalAggregate<Plan> aggregate, C context) {
return visit(aggregate, context);

View File

@ -86,6 +86,7 @@ import org.apache.doris.mysql.MysqlEofPacket;
import org.apache.doris.mysql.MysqlSerializer;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.planner.OriginalPlanner;
@ -156,6 +157,7 @@ public class StmtExecutor implements ProfileWriter {
private static final int MAX_DATA_TO_SEND_FOR_TXN = 100;
private ConnectContext context;
private StatementContext statementContext;
private MysqlSerializer serializer;
private OriginStatement originStmt;
private StatementBase parsedStmt;
@ -183,6 +185,7 @@ public class StmtExecutor implements ProfileWriter {
this.originStmt = originStmt;
this.serializer = context.getSerializer();
this.isProxy = isProxy;
this.statementContext = new StatementContext(context, originStmt);
}
// this constructor is only for test now.
@ -197,6 +200,8 @@ public class StmtExecutor implements ProfileWriter {
this.originStmt = parsedStmt.getOrigStmt();
this.serializer = context.getSerializer();
this.isProxy = false;
this.statementContext = new StatementContext(ctx, originStmt);
this.statementContext.setParsedStatement(parsedStmt);
}
public void setCoord(Coordinator coord) {
@ -600,7 +605,7 @@ public class StmtExecutor implements ProfileWriter {
if (parsedStmt instanceof ShowStmt) {
SelectStmt selectStmt = ((ShowStmt) parsedStmt).toSelectStmt(analyzer);
if (selectStmt != null) {
parsedStmt = selectStmt;
setParsedStmt(selectStmt);
}
}
@ -671,7 +676,7 @@ public class StmtExecutor implements ProfileWriter {
context.getSessionVariable().getSqlMode());
SqlParser parser = new SqlParser(input);
try {
parsedStmt = SqlParserUtils.getStmt(parser, originStmt.idx);
StatementBase parsedStmt = setParsedStmt(SqlParserUtils.getStmt(parser, originStmt.idx));
parsedStmt.setOrigStmt(originStmt);
parsedStmt.setUserInfo(context.getCurrentUserIdentity());
} catch (Error e) {
@ -716,7 +721,7 @@ public class StmtExecutor implements ProfileWriter {
parsedStmt.rewriteExprs(rewriter);
reAnalyze = rewriter.changed();
if (analyzer.containSubquery()) {
parsedStmt = StmtRewriter.rewrite(analyzer, parsedStmt);
parsedStmt = setParsedStmt(StmtRewriter.rewrite(analyzer, parsedStmt));
reAnalyze = true;
}
if (parsedStmt instanceof SelectStmt) {
@ -771,7 +776,7 @@ public class StmtExecutor implements ProfileWriter {
if (parsedStmt instanceof LogicalPlanAdapter) {
// create plan
planner = new NereidsPlanner(context);
planner = new NereidsPlanner(statementContext);
} else {
planner = new OriginalPlanner(analyzer);
}
@ -921,7 +926,7 @@ public class StmtExecutor implements ProfileWriter {
analyzer = new Analyzer(context.getEnv(), context);
newSelectStmt.analyze(analyzer);
if (parsedStmt instanceof LogicalPlanAdapter) {
planner = new NereidsPlanner(context);
planner = new NereidsPlanner(statementContext);
} else {
planner = new OriginalPlanner(analyzer);
}
@ -1650,7 +1655,7 @@ public class StmtExecutor implements ProfileWriter {
// after success create table insert data
if (MysqlStateType.OK.equals(context.getState().getStateType())) {
try {
parsedStmt = ctasStmt.getInsertStmt();
parsedStmt = setParsedStmt(ctasStmt.getInsertStmt());
execute();
} catch (Exception e) {
LOG.warn("CTAS insert data error, stmt={}", parsedStmt.toSql(), e);
@ -1689,4 +1694,10 @@ public class StmtExecutor implements ProfileWriter {
private List<PrimitiveType> exprToType(List<Expr> exprs) {
return exprs.stream().map(e -> e.getType().getPrimitiveType()).collect(Collectors.toList());
}
private StatementBase setParsedStmt(StatementBase parsedStmt) {
this.parsedStmt = parsedStmt;
this.statementContext.setParsedStatement(parsedStmt);
return parsedStmt;
}
}

View File

@ -17,7 +17,6 @@
package org.apache.doris.nereids.datasets.ssb;
import org.apache.doris.nereids.analyzer.NereidsAnalyzer;
import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.Plan;
@ -82,7 +81,7 @@ public class SSBJoinReorderTest extends SSBTestBase {
}
private void test(String sql, List<String> expectJoinConditions, List<String> expectFilterPredicates) {
LogicalPlan analyzed = new NereidsAnalyzer(connectContext).analyze(sql);
LogicalPlan analyzed = analyze(sql);
LogicalPlan plan = testJoinReorder(analyzed);
new PlanChecker(expectJoinConditions, expectFilterPredicates).check(plan);
}

View File

@ -17,7 +17,6 @@
package org.apache.doris.nereids.datasets.tpch;
import org.apache.doris.nereids.analyzer.NereidsAnalyzer;
import org.apache.doris.nereids.analyzer.Unbound;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.Plan;
@ -30,7 +29,7 @@ import java.util.List;
public abstract class AnalyzeCheckTestBase extends TestWithFeService {
protected void checkAnalyze(String sql) {
LogicalPlan analyzed = new NereidsAnalyzer(connectContext).analyze(sql);
LogicalPlan analyzed = analyze(sql);
Assertions.assertTrue(checkBound(analyzed));
}

View File

@ -17,19 +17,18 @@
package org.apache.doris.nereids.jobs;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.jobs.cascades.OptimizeGroupJob;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.Lists;
import mockit.Mock;
@ -93,12 +92,11 @@ public class CostAndEnforcerJobTest {
LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN,
Optional.of(bottomJoinOnCondition), scans.get(0), scans.get(1));
PlannerContext plannerContext = new Memo(bottomJoin).newPlannerContext(new ConnectContext())
.setDefaultJobContext();
plannerContext.pushJob(
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(bottomJoin);
cascadesContext.pushJob(
new OptimizeGroupJob(
plannerContext.getMemo().getRoot(),
plannerContext.getCurrentJobContext()));
plannerContext.getJobScheduler().executeJobPool(plannerContext);
cascadesContext.getMemo().getRoot(),
cascadesContext.getCurrentJobContext()));
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
}
}

View File

@ -18,11 +18,10 @@
package org.apache.doris.nereids.jobs;
import org.apache.doris.catalog.Table;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
@ -35,8 +34,8 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
@ -63,14 +62,12 @@ public class RewriteTopDownJobTest {
new SlotReference("name", StringType.INSTANCE, true, ImmutableList.of("test"))),
leaf
);
PlannerContext plannerContext = new Memo(project)
.newPlannerContext(new ConnectContext())
.setDefaultJobContext();
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(project);
List<Rule> fakeRules = Lists.newArrayList(new FakeRule().build());
plannerContext.topDownRewrite(fakeRules);
cascadesContext.topDownRewrite(fakeRules);
Group rootGroup = plannerContext.getMemo().getRoot();
Group rootGroup = cascadesContext.getMemo().getRoot();
Assertions.assertEquals(1, rootGroup.getLogicalExpressions().size());
GroupExpression rootGroupExpression = rootGroup.getLogicalExpression();
List<Slot> output = rootGroup.getLogicalProperties().getOutput();

View File

@ -19,9 +19,8 @@ package org.apache.doris.nereids.jobs.cascades;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -32,6 +31,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.ColumnStats;
@ -66,14 +66,13 @@ public class DeriveStatsJobTest {
public void testExecute() throws Exception {
LogicalOlapScan olapScan = constructOlapSCan();
LogicalAggregate agg = constructAgg(olapScan);
Memo memo = new Memo(agg);
PlannerContext plannerContext = new PlannerContext(memo, context);
new DeriveStatsJob(memo.getRoot().getLogicalExpression(),
new JobContext(plannerContext, null, Double.MAX_VALUE)).execute();
while (!plannerContext.getJobPool().isEmpty()) {
plannerContext.getJobPool().pop().execute();
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(agg);
new DeriveStatsJob(cascadesContext.getMemo().getRoot().getLogicalExpression(),
new JobContext(cascadesContext, null, Double.MAX_VALUE)).execute();
while (!cascadesContext.getJobPool().isEmpty()) {
cascadesContext.getJobPool().pop().execute();
}
StatsDeriveResult statistics = memo.getRoot().getStatistics();
StatsDeriveResult statistics = cascadesContext.getMemo().getRoot().getStatistics();
Assertions.assertNotNull(statistics);
Assertions.assertEquals(10, statistics.getRowCount());
}

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -27,10 +27,10 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import mockit.Mocked;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
@ -39,7 +39,7 @@ import java.util.Optional;
public class JoinCommuteTest {
@Test
public void testInnerJoinCommute(@Mocked PlannerContext plannerContext) {
public void testInnerJoinCommute() {
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
@ -49,9 +49,10 @@ public class JoinCommuteTest {
LogicalJoin<LogicalOlapScan, LogicalOlapScan> join = new LogicalJoin<>(
JoinType.INNER_JOIN, Optional.of(onCondition), scan1, scan2);
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(join);
Rule rule = new JoinCommute(true).build();
List<Plan> transform = rule.transform(join, plannerContext);
List<Plan> transform = rule.transform(join, cascadesContext);
Assertions.assertEquals(1, transform.size());
Plan newJoin = transform.get(0);

View File

@ -18,7 +18,7 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
@ -30,12 +30,12 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import mockit.Mocked;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
@ -67,8 +67,7 @@ public class JoinLAsscomProjectTest {
outputs.add(t3Output);
}
private Pair<LogicalJoin, LogicalJoin> testJoinProjectLAsscom(PlannerContext plannerContext,
List<NamedExpression> projects) {
private Pair<LogicalJoin, LogicalJoin> testJoinProjectLAsscom(List<NamedExpression> projects) {
/*
* topJoin newTopJoin
* / \ / \
@ -94,8 +93,9 @@ public class JoinLAsscomProjectTest {
LogicalJoin<LogicalProject<LogicalJoin<LogicalOlapScan, LogicalOlapScan>>, LogicalOlapScan> topJoin
= new LogicalJoin<>(JoinType.INNER_JOIN, Optional.of(topJoinOnCondition), project, scans.get(2));
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(topJoin);
Rule rule = new JoinLAsscomProject().build();
List<Plan> transform = rule.transform(topJoin, plannerContext);
List<Plan> transform = rule.transform(topJoin, cascadesContext);
Assertions.assertEquals(1, transform.size());
Assertions.assertTrue(transform.get(0) instanceof LogicalJoin);
LogicalJoin newTopJoin = (LogicalJoin) transform.get(0);
@ -103,7 +103,7 @@ public class JoinLAsscomProjectTest {
}
@Test
public void testStarJoinProjectLAsscom(@Mocked PlannerContext plannerContext) {
public void testStarJoinProjectLAsscom() {
List<SlotReference> t1 = outputs.get(0);
List<SlotReference> t2 = outputs.get(1);
List<NamedExpression> projects = ImmutableList.of(
@ -113,7 +113,7 @@ public class JoinLAsscomProjectTest {
t2.get(1)
);
Pair<LogicalJoin, LogicalJoin> pair = testJoinProjectLAsscom(plannerContext, projects);
Pair<LogicalJoin, LogicalJoin> pair = testJoinProjectLAsscom(projects);
LogicalJoin oldJoin = pair.first;
LogicalJoin newTopJoin = pair.second;

View File

@ -18,7 +18,7 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -27,11 +27,11 @@ import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.Lists;
import mockit.Mocked;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
@ -62,7 +62,7 @@ public class JoinLAsscomTest {
outputs.add(t3Output);
}
public Pair<LogicalJoin, LogicalJoin> testJoinLAsscom(PlannerContext plannerContext,
public Pair<LogicalJoin, LogicalJoin> testJoinLAsscom(
Expression bottomJoinOnCondition, Expression topJoinOnCondition) {
/*
* topJoin newTopJoin
@ -77,8 +77,9 @@ public class JoinLAsscomTest {
LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> topJoin = new LogicalJoin<>(
JoinType.INNER_JOIN, Optional.of(topJoinOnCondition), bottomJoin, scans.get(2));
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(topJoin);
Rule rule = new JoinLAsscom().build();
List<Plan> transform = rule.transform(topJoin, plannerContext);
List<Plan> transform = rule.transform(topJoin, cascadesContext);
Assertions.assertEquals(1, transform.size());
Assertions.assertTrue(transform.get(0) instanceof LogicalJoin);
LogicalJoin newTopJoin = (LogicalJoin) transform.get(0);
@ -86,7 +87,7 @@ public class JoinLAsscomTest {
}
@Test
public void testStarJoinLAsscom(@Mocked PlannerContext plannerContext) {
public void testStarJoinLAsscom() {
/*
* Star-Join
* t1 -- t2
@ -108,8 +109,7 @@ public class JoinLAsscomTest {
Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0));
Expression topJoinOnCondition = new EqualTo(t1.get(1), t3.get(1));
Pair<LogicalJoin, LogicalJoin> pair = testJoinLAsscom(plannerContext, bottomJoinOnCondition,
topJoinOnCondition);
Pair<LogicalJoin, LogicalJoin> pair = testJoinLAsscom(bottomJoinOnCondition, topJoinOnCondition);
LogicalJoin oldJoin = pair.first;
LogicalJoin newTopJoin = pair.second;
@ -123,7 +123,7 @@ public class JoinLAsscomTest {
}
@Test
public void testChainJoinLAsscom(@Mocked PlannerContext plannerContext) {
public void testChainJoinLAsscom() {
/*
* Chain-Join
* t1 -- t2 -- t3
@ -143,8 +143,7 @@ public class JoinLAsscomTest {
Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0));
Expression topJoinOnCondition = new EqualTo(t2.get(0), t3.get(0));
Pair<LogicalJoin, LogicalJoin> pair = testJoinLAsscom(plannerContext, bottomJoinOnCondition,
topJoinOnCondition);
Pair<LogicalJoin, LogicalJoin> pair = testJoinLAsscom(bottomJoinOnCondition, topJoinOnCondition);
LogicalJoin oldJoin = pair.first;
LogicalJoin newTopJoin = pair.second;

View File

@ -17,13 +17,14 @@
package org.apache.doris.nereids.rules.implementation;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.util.MemoTestUtils;
import mockit.Mocked;
import org.junit.jupiter.api.Assertions;
@ -33,10 +34,11 @@ import java.util.List;
public class LogicalLimitToPhysicalLimitTest {
@Test
public void toPhysicalLimitTest(@Mocked Group group, @Mocked PlannerContext plannerContext) {
public void toPhysicalLimitTest(@Mocked Group group) {
Plan logicalPlan = new LogicalLimit<>(3, 4, new GroupPlan(group));
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(logicalPlan);
Rule rule = new LogicalLimitToPhysicalLimit().build();
List<Plan> physicalPlans = rule.transform(logicalPlan, plannerContext);
List<Plan> physicalPlans = rule.transform(logicalPlan, cascadesContext);
Assertions.assertEquals(1, physicalPlans.size());
Plan impl = physicalPlans.get(0);
Assertions.assertEquals(PlanType.PHYSICAL_LIMIT, impl.getType());

View File

@ -17,9 +17,8 @@
package org.apache.doris.nereids.rules.implementation;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
@ -31,8 +30,8 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
@ -53,14 +52,14 @@ public class LogicalProjectToPhysicalProjectTest {
.put(LogicalSort.class.getName(), (new LogicalSortToPhysicalHeapSort()).build())
.build();
private static PhysicalPlan rewriteLogicalToPhysical(Group group, PlannerContext plannerContext) {
private static PhysicalPlan rewriteLogicalToPhysical(Group group, CascadesContext cascadesContext) {
List<Plan> children = Lists.newArrayList();
for (Group child : group.getLogicalExpression().children()) {
children.add(rewriteLogicalToPhysical(child, plannerContext));
children.add(rewriteLogicalToPhysical(child, cascadesContext));
}
Rule rule = rulesMap.get(group.getLogicalExpression().getPlan().getClass().getName());
List<Plan> transform = rule.transform(group.getLogicalExpression().getPlan(), plannerContext);
List<Plan> transform = rule.transform(group.getLogicalExpression().getPlan(), cascadesContext);
Assertions.assertEquals(1, transform.size());
Assertions.assertTrue(transform.get(0) instanceof PhysicalPlan);
PhysicalPlan implPlanNode = (PhysicalPlan) transform.get(0);
@ -68,20 +67,14 @@ public class LogicalProjectToPhysicalProjectTest {
return (PhysicalPlan) implPlanNode.withChildren(children);
}
public static PhysicalPlan rewriteLogicalToPhysical(LogicalPlan plan) {
PlannerContext plannerContext = new Memo(plan)
.newPlannerContext(new ConnectContext())
.setDefaultJobContext();
return rewriteLogicalToPhysical(plannerContext.getMemo().getRoot(), plannerContext);
}
@Test
public void projectionImplTest() {
LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "a", 0);
LogicalPlan project = new LogicalProject<>(Lists.newArrayList(), scan);
PhysicalPlan physicalProject = rewriteLogicalToPhysical(project);
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(project);
PhysicalPlan physicalProject = rewriteLogicalToPhysical(cascadesContext.getMemo().getRoot(), cascadesContext);
Assertions.assertEquals(PlanType.PHYSICAL_PROJECT, physicalProject.getType());
PhysicalPlan physicalScan = (PhysicalPlan) physicalProject.child(0);
Assertions.assertEquals(PlanType.PHYSICAL_OLAP_SCAN, physicalScan.getType());

View File

@ -17,7 +17,6 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.analyzer.NereidsAnalyzer;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
@ -57,12 +56,10 @@ public class ColumnPruningTest extends TestWithFeService implements PatternMatch
public void testPruneColumns1() {
// TODO: It's inconvenient and less efficient to use planPattern().when(...) to check plan properties.
// Enhance the generated patterns in the future.
new PlanChecker()
.plan(new NereidsAnalyzer(connectContext)
.analyze(
"select id,name,grade from student "
+ "left join score on student.id = score.sid where score.grade > 60"))
.applyTopDown(new ColumnPruning(), connectContext)
PlanChecker.from(connectContext)
.analyze("select id,name,grade from student left join score on student.id = score.sid"
+ " where score.grade > 60")
.applyTopDown(new ColumnPruning())
.matches(
logicalProject(
logicalFilter(
@ -90,13 +87,11 @@ public class ColumnPruningTest extends TestWithFeService implements PatternMatch
@Test
public void testPruneColumns2() {
new PlanChecker()
.plan(new NereidsAnalyzer(connectContext)
.analyze(
"select name,sex,cid,grade "
+ "from student left join score on student.id = score.sid "
+ "where score.grade > 60"))
.applyTopDown(new ColumnPruning(), connectContext)
PlanChecker.from(connectContext)
.analyze("select name,sex,cid,grade "
+ "from student left join score on student.id = score.sid "
+ "where score.grade > 60")
.applyTopDown(new ColumnPruning())
.matches(
logicalProject(
logicalFilter(
@ -124,10 +119,9 @@ public class ColumnPruningTest extends TestWithFeService implements PatternMatch
@Test
public void testPruneColumns3() {
new PlanChecker()
.plan(new NereidsAnalyzer(connectContext)
.analyze("select id,name from student where age > 18"))
.applyTopDown(new ColumnPruning(), connectContext)
PlanChecker.from(connectContext)
.analyze("select id,name from student where age > 18")
.applyTopDown(new ColumnPruning())
.matches(
logicalProject(
logicalFilter(
@ -143,14 +137,13 @@ public class ColumnPruningTest extends TestWithFeService implements PatternMatch
@Test
public void testPruneColumns4() {
new PlanChecker()
.plan(new NereidsAnalyzer(connectContext)
.analyze("select name,cname,grade "
+ "from student left join score "
+ "on student.id = score.sid left join course "
+ "on score.cid = course.cid "
+ "where score.grade > 60"))
.applyTopDown(new ColumnPruning(), connectContext)
PlanChecker.from(connectContext)
.analyze("select name,cname,grade "
+ "from student left join score "
+ "on student.id = score.sid left join course "
+ "on score.cid = course.cid "
+ "where score.grade > 60")
.applyTopDown(new ColumnPruning())
.matches(
logicalProject(
logicalFilter(

View File

@ -17,16 +17,15 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.nereids.util.MemoTestUtils;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
@ -48,13 +47,11 @@ public class MergeConsecutiveFilterTest {
Expression expression3 = new IntegerLiteral(3);
LogicalFilter filter3 = new LogicalFilter(expression3, filter2);
PlannerContext plannerContext = new Memo(filter3)
.newPlannerContext(new ConnectContext())
.setDefaultJobContext();
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(filter3);
List<Rule> rules = Lists.newArrayList(new MergeConsecutiveFilters().build());
plannerContext.bottomUpRewrite(rules);
cascadesContext.bottomUpRewrite(rules);
//check transformed plan
Plan resultPlan = plannerContext.getMemo().copyOut();
Plan resultPlan = cascadesContext.getMemo().copyOut();
System.out.println(resultPlan.treeString());
Assertions.assertTrue(resultPlan instanceof LogicalFilter);
Expression allPredicates = ExpressionUtils.and(expression3,

View File

@ -17,9 +17,8 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
@ -30,7 +29,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.nereids.util.MemoTestUtils;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
@ -52,12 +51,10 @@ public class MergeConsecutiveProjectsTest {
LogicalProject project2 = new LogicalProject(Lists.newArrayList(colA, colB), project1);
LogicalProject project3 = new LogicalProject(Lists.newArrayList(colA), project2);
PlannerContext plannerContext = new Memo(project3)
.newPlannerContext(new ConnectContext())
.setDefaultJobContext();
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(project3);
List<Rule> rules = Lists.newArrayList(new MergeConsecutiveProjects().build());
plannerContext.bottomUpRewrite(rules);
Plan plan = plannerContext.getMemo().copyOut();
cascadesContext.bottomUpRewrite(rules);
Plan plan = cascadesContext.getMemo().copyOut();
System.out.println(plan.treeString());
Assertions.assertTrue(plan instanceof LogicalProject);
Assertions.assertTrue(((LogicalProject<?>) plan).getProjects().equals(Lists.newArrayList(colA)));
@ -96,12 +93,10 @@ public class MergeConsecutiveProjectsTest {
),
project1);
PlannerContext plannerContext = new Memo(project2)
.newPlannerContext(new ConnectContext())
.setDefaultJobContext();
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(project2);
List<Rule> rules = Lists.newArrayList(new MergeConsecutiveProjects().build());
plannerContext.bottomUpRewrite(rules);
Plan plan = plannerContext.getMemo().copyOut();
cascadesContext.bottomUpRewrite(rules);
Plan plan = cascadesContext.getMemo().copyOut();
System.out.println(plan.treeString());
Assertions.assertTrue(plan instanceof LogicalProject);
LogicalProject finalProject = (LogicalProject) plan;

View File

@ -18,7 +18,7 @@
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.analyzer.NereidsAnalyzer;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator;
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
import org.apache.doris.nereids.parser.NereidsParser;
@ -40,6 +40,7 @@ import java.util.List;
import java.util.Optional;
public class AnalyzeSubQueryTest extends TestWithFeService implements PatternMatchSupported {
private final NereidsParser parser = new NereidsParser();
private final List<String> testSql = ImmutableList.of(
@ -87,11 +88,10 @@ public class AnalyzeSubQueryTest extends TestWithFeService implements PatternMat
public void testTranslateCase() throws Exception {
for (String sql : testSql) {
NamedExpressionUtil.clear();
System.out.println("\n\n***** " + sql + " *****\n\n");
PhysicalPlan plan = new NereidsPlanner(connectContext).plan(
StatementContext statementContext = MemoTestUtils.createStatementContext(connectContext, sql);
PhysicalPlan plan = new NereidsPlanner(statementContext).plan(
parser.parseSingle(sql),
PhysicalProperties.ANY,
connectContext
PhysicalProperties.ANY
);
// Just to check whether translate will throw exception
new PhysicalPlanTranslator().translatePlan(plan, new PlanTranslatorContext());
@ -100,78 +100,77 @@ public class AnalyzeSubQueryTest extends TestWithFeService implements PatternMat
@Test
public void testCaseSubQuery() {
FieldChecker projectChecker = new FieldChecker(ImmutableList.of("projects"));
new PlanChecker().plan(new NereidsAnalyzer(connectContext).analyze(testSql.get(0)))
PlanChecker.from(connectContext)
.analyze(testSql.get(0))
.applyTopDown(new EliminateAliasNode())
.matches(
logicalProject(
logicalProject(
logicalProject(
logicalOlapScan().when(o -> true)
).when(projectChecker.check(ImmutableList.of(ImmutableList.of(
new SlotReference(new ExprId(0), "id", new BigIntType(), true, ImmutableList.of("T")),
new SlotReference(new ExprId(1), "score", new BigIntType(), true, ImmutableList.of("T")))))
)
).when(projectChecker.check(ImmutableList.of(ImmutableList.of(
new SlotReference(new ExprId(0), "id", new BigIntType(), true, ImmutableList.of("T2")),
new SlotReference(new ExprId(1), "score", new BigIntType(), true, ImmutableList.of("T2")))))
logicalOlapScan().when(o -> true)
).when(FieldChecker.check("projects", ImmutableList.of(
new SlotReference(new ExprId(0), "id", new BigIntType(), true, ImmutableList.of("T")),
new SlotReference(new ExprId(1), "score", new BigIntType(), true, ImmutableList.of("T"))))
)
).when(FieldChecker.check("projects", ImmutableList.of(
new SlotReference(new ExprId(0), "id", new BigIntType(), true, ImmutableList.of("T2")),
new SlotReference(new ExprId(1), "score", new BigIntType(), true, ImmutableList.of("T2"))))
)
);
}
@Test
public void testCaseMixed() {
FieldChecker projectChecker = new FieldChecker(ImmutableList.of("projects"));
FieldChecker joinChecker = new FieldChecker(ImmutableList.of("joinType", "condition"));
new PlanChecker().plan(new NereidsAnalyzer(connectContext).analyze(testSql.get(1)))
PlanChecker.from(connectContext)
.analyze(testSql.get(1))
.applyTopDown(new EliminateAliasNode())
.matches(
logicalProject(
logicalJoin(
logicalOlapScan(),
logicalProject(
logicalOlapScan()
).when(projectChecker.check(ImmutableList.of(ImmutableList.of(
new SlotReference(new ExprId(0), "id", new BigIntType(), true, ImmutableList.of("TT2")),
new SlotReference(new ExprId(1), "score", new BigIntType(), true, ImmutableList.of("TT2")))))
)
).when(joinChecker.check(ImmutableList.of(
JoinType.INNER_JOIN,
Optional.of(new EqualTo(
new SlotReference(new ExprId(2), "id", new BigIntType(), true, ImmutableList.of("TT1")),
new SlotReference(new ExprId(0), "id", new BigIntType(), true, ImmutableList.of("T"))))))
)
).when(projectChecker.check(ImmutableList.of(ImmutableList.of(
new SlotReference(new ExprId(2), "id", new BigIntType(), true, ImmutableList.of("TT1")),
new SlotReference(new ExprId(3), "score", new BigIntType(), true, ImmutableList.of("TT1")),
new SlotReference(new ExprId(0), "id", new BigIntType(), true, ImmutableList.of("T")),
new SlotReference(new ExprId(1), "score", new BigIntType(), true, ImmutableList.of("T")))))
logicalProject(
logicalJoin(
logicalOlapScan(),
logicalProject(
logicalOlapScan()
).when(FieldChecker.check("projects", ImmutableList.of(
new SlotReference(new ExprId(0), "id", new BigIntType(), true, ImmutableList.of("TT2")),
new SlotReference(new ExprId(1), "score", new BigIntType(), true, ImmutableList.of("TT2"))))
)
)
.when(FieldChecker.check("joinType", JoinType.INNER_JOIN))
.when(FieldChecker.check("condition",
Optional.of(new EqualTo(
new SlotReference(new ExprId(2), "id", new BigIntType(), true, ImmutableList.of("TT1")),
new SlotReference(new ExprId(0), "id", new BigIntType(), true, ImmutableList.of("T")))))
)
).when(FieldChecker.check("projects", ImmutableList.of(
new SlotReference(new ExprId(2), "id", new BigIntType(), true, ImmutableList.of("TT1")),
new SlotReference(new ExprId(3), "score", new BigIntType(), true, ImmutableList.of("TT1")),
new SlotReference(new ExprId(0), "id", new BigIntType(), true, ImmutableList.of("T")),
new SlotReference(new ExprId(1), "score", new BigIntType(), true, ImmutableList.of("T"))))
)
);
}
@Test
public void testCaseJoinSameTable() {
FieldChecker projectChecker = new FieldChecker(ImmutableList.of("projects"));
FieldChecker joinChecker = new FieldChecker(ImmutableList.of("joinType", "condition"));
new PlanChecker().plan(new NereidsAnalyzer(connectContext).analyze(testSql.get(5)))
PlanChecker.from(connectContext)
.analyze(testSql.get(5))
.applyTopDown(new EliminateAliasNode())
.matches(
logicalProject(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
).when(joinChecker.check(ImmutableList.of(
JoinType.INNER_JOIN,
Optional.of(new EqualTo(
new SlotReference(new ExprId(0), "id", new BigIntType(), true, ImmutableList.of("default_cluster:test", "T1")),
new SlotReference(new ExprId(2), "id", new BigIntType(), true, ImmutableList.of("T2"))))))
)
).when(projectChecker.check(ImmutableList.of(ImmutableList.of(
new SlotReference(new ExprId(0), "id", new BigIntType(), true, ImmutableList.of("default_cluster:test", "T1")),
new SlotReference(new ExprId(1), "score", new BigIntType(), true, ImmutableList.of("default_cluster:test", "T1")),
new SlotReference(new ExprId(2), "id", new BigIntType(), true, ImmutableList.of("T2")),
new SlotReference(new ExprId(3), "score", new BigIntType(), true, ImmutableList.of("T2")))))
logicalProject(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
.when(FieldChecker.check("joinType", JoinType.INNER_JOIN))
.when(FieldChecker.check("condition", Optional.of(new EqualTo(
new SlotReference(new ExprId(0), "id", new BigIntType(), true, ImmutableList.of("default_cluster:test", "T1")),
new SlotReference(new ExprId(2), "id", new BigIntType(), true, ImmutableList.of("T2")))))
)
).when(FieldChecker.check("projects", ImmutableList.of(
new SlotReference(new ExprId(0), "id", new BigIntType(), true, ImmutableList.of("default_cluster:test", "T1")),
new SlotReference(new ExprId(1), "score", new BigIntType(), true, ImmutableList.of("default_cluster:test", "T1")),
new SlotReference(new ExprId(2), "id", new BigIntType(), true, ImmutableList.of("T2")),
new SlotReference(new ExprId(3), "score", new BigIntType(), true, ImmutableList.of("T2"))))
)
);
}
}

View File

@ -20,36 +20,23 @@ package org.apache.doris.nereids.util;
import org.junit.jupiter.api.Assertions;
import java.lang.reflect.Field;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.IntStream;
public class FieldChecker {
public final List<String> fields;
public FieldChecker(List<String> fields) {
this.fields = fields;
}
public <T> Predicate<T> check(List<Object> valueList) {
public static <T> Predicate<T> check(String fieldName, Object value) {
return (o) -> {
Assertions.assertEquals(fields.size(), valueList.size());
Class<?> classInfo = o.getClass();
IntStream.range(0, valueList.size()).forEach(i -> {
Field field;
try {
field = classInfo.getDeclaredField(this.fields.get(i));
} catch (NoSuchFieldException e) {
throw new RuntimeException(e);
}
field.setAccessible(true);
try {
Assertions.assertEquals(valueList.get(i), field.get(o));
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
});
Field field;
try {
field = o.getClass().getDeclaredField(fieldName);
} catch (NoSuchFieldException e) {
throw new RuntimeException(e);
}
field.setAccessible(true);
try {
Assertions.assertEquals(value, field.get(o));
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
return true;
};
}

View File

@ -0,0 +1,111 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.util;
import org.apache.doris.analysis.UserIdentity;
import org.apache.doris.catalog.Env;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.OriginStatement;
import org.apache.doris.system.SystemInfoService;
import java.io.IOException;
import java.nio.channels.SocketChannel;
/**
* MemoUtils.
*/
public class MemoTestUtils {
public static ConnectContext createConnectContext() {
return createCtx(UserIdentity.ROOT, "127.0.0.1");
}
public static LogicalPlan parse(String sql) {
return new NereidsParser().parseSingle(sql);
}
public static StatementContext createStatementContext(String sql) {
return createStatementContext(createConnectContext(), sql);
}
public static StatementContext createStatementContext(ConnectContext connectContext, String sql) {
return new StatementContext(connectContext, new OriginStatement(sql, 0));
}
public static CascadesContext createCascadesContext(String sql) {
return createCascadesContext(createConnectContext(), sql);
}
public static CascadesContext createCascadesContext(Plan initPlan) {
return createCascadesContext(createConnectContext(), initPlan);
}
public static CascadesContext createCascadesContext(ConnectContext connectContext, String sql) {
StatementContext statementCtx = createStatementContext(connectContext, sql);
LogicalPlan initPlan = parse(sql);
return CascadesContext.newContext(statementCtx, initPlan);
}
public static CascadesContext createCascadesContext(StatementContext statementContext, String sql) {
LogicalPlan initPlan = new NereidsParser().parseSingle(sql);
return CascadesContext.newContext(statementContext, initPlan);
}
public static CascadesContext createCascadesContext(ConnectContext connectContext, Plan initPlan) {
StatementContext statementCtx = createStatementContext(connectContext, "");
return CascadesContext.newContext(statementCtx, initPlan);
}
public static CascadesContext createCascadesContext(StatementContext statementContext, Plan initPlan) {
return CascadesContext.newContext(statementContext, initPlan);
}
public static LogicalPlan analyze(String sql) {
CascadesContext cascadesContext = createCascadesContext(sql);
cascadesContext.newAnalyzer().analyze();
return (LogicalPlan) cascadesContext.getMemo().copyOut();
}
/**
* create test connection context.
* @param user connect user
* @param host connect host
* @return ConnectContext
* @throws IOException exception
*/
public static ConnectContext createCtx(UserIdentity user, String host) {
try {
SocketChannel channel = SocketChannel.open();
ConnectContext ctx = new ConnectContext(channel);
ctx.setCluster(SystemInfoService.DEFAULT_CLUSTER);
ctx.setCurrentUserIdentity(user);
ctx.setQualifiedUser(user.getQualifiedUser());
ctx.setRemoteIP(host);
ctx.setEnv(Env.getCurrentEnv());
ctx.setThreadLocalInfo();
return ctx;
} catch (Throwable t) {
throw new IllegalStateException("can not create test connect context", t);
}
}
}

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.pattern.GroupExpressionMatching;
import org.apache.doris.nereids.pattern.PatternDescriptor;
@ -30,38 +31,56 @@ import org.junit.jupiter.api.Assertions;
* Utility to apply rules to plan and check output plan matches the expected pattern.
*/
public class PlanChecker {
private ConnectContext connectContext;
private CascadesContext cascadesContext;
private Plan inputPlan;
private Memo memo;
public PlanChecker plan(Plan plan) {
this.inputPlan = plan;
public PlanChecker(ConnectContext connectContext) {
this.connectContext = connectContext;
}
public PlanChecker(CascadesContext cascadesContext) {
this.connectContext = cascadesContext.getConnectContext();
this.cascadesContext = cascadesContext;
}
public PlanChecker analyze(String sql) {
this.cascadesContext = MemoTestUtils.createCascadesContext(connectContext, sql);
this.cascadesContext.newAnalyzer().analyze();
return this;
}
public PlanChecker analyze(Plan plan) {
this.cascadesContext = MemoTestUtils.createCascadesContext(connectContext, plan);
this.cascadesContext.newAnalyzer().analyze();
return this;
}
public PlanChecker applyTopDown(RuleFactory rule) {
return applyTopDown(rule, new ConnectContext());
}
public PlanChecker applyTopDown(RuleFactory rule, ConnectContext connectContext) {
memo = PlanRewriter.topDownRewriteMemo(inputPlan, connectContext, rule);
cascadesContext.topDownRewrite(rule);
return this;
}
public PlanChecker applyBottomUp(RuleFactory rule) {
return applyBottomUp(rule);
}
public PlanChecker applyBottomUp(RuleFactory rule, ConnectContext connectContext) {
memo = PlanRewriter.bottomUpRewriteMemo(inputPlan, connectContext, rule);
cascadesContext.bottomUpRewrite(rule);
return this;
}
public void matches(PatternDescriptor<? extends Plan> patternDesc) {
Memo memo = cascadesContext.getMemo();
GroupExpressionMatching matchResult = new GroupExpressionMatching(patternDesc.pattern,
memo.getRoot().getLogicalExpression());
Assertions.assertTrue(matchResult.iterator().hasNext(), () ->
"pattern not match, plan :\n" + memo.getRoot().getLogicalExpression().getPlan().treeString() + "\n"
);
}
public static PlanChecker from(ConnectContext connectContext) {
return new PlanChecker(connectContext);
}
public static PlanChecker from(ConnectContext connectContext, Plan initPlan) {
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(connectContext, initPlan);
return new PlanChecker(cascadesContext);
}
}

View File

@ -17,11 +17,13 @@
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleFactory;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.OriginStatement;
/**
* Utility to copy plan into {@link Memo} and apply rewrite rules.
@ -37,16 +39,14 @@ public class PlanRewriter {
public static Memo bottomUpRewriteMemo(Plan plan, ConnectContext connectContext, RuleFactory... rules) {
return new Memo(plan)
.newPlannerContext(connectContext)
.setDefaultJobContext()
.newCascadesContext(new StatementContext(connectContext, new OriginStatement("", 0)))
.topDownRewrite(rules)
.getMemo();
}
public static Memo bottomUpRewriteMemo(Plan plan, ConnectContext connectContext, Rule... rules) {
return new Memo(plan)
.newPlannerContext(connectContext)
.setDefaultJobContext()
.newCascadesContext(new StatementContext(connectContext, new OriginStatement("", 0)))
.topDownRewrite(rules)
.getMemo();
}
@ -61,16 +61,14 @@ public class PlanRewriter {
public static Memo topDownRewriteMemo(Plan plan, ConnectContext connectContext, RuleFactory... rules) {
return new Memo(plan)
.newPlannerContext(connectContext)
.setDefaultJobContext()
.newCascadesContext(new StatementContext(connectContext, new OriginStatement("", 0)))
.topDownRewrite(rules)
.getMemo();
}
public static Memo topDownRewriteMemo(Plan plan, ConnectContext connectContext, Rule... rules) {
return new Memo(plan)
.newPlannerContext(connectContext)
.setDefaultJobContext()
.newCascadesContext(new StatementContext(connectContext, new OriginStatement("", 0)))
.topDownRewrite(rules)
.getMemo();
}

View File

@ -41,6 +41,10 @@ import org.apache.doris.common.Config;
import org.apache.doris.common.DdlException;
import org.apache.doris.common.FeConstants;
import org.apache.doris.common.util.SqlParserUtils;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.planner.Planner;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.OriginStatement;
@ -132,6 +136,22 @@ public abstract class TestWithFeService {
return createCtx(UserIdentity.ROOT, "127.0.0.1");
}
protected StatementContext createStatementCtx(String sql) {
return new StatementContext(connectContext, new OriginStatement(sql, 0));
}
protected CascadesContext createCascadesContext(String sql) {
StatementContext statementCtx = createStatementCtx(sql);
LogicalPlan initPlan = new NereidsParser().parseSingle(sql);
return CascadesContext.newContext(statementCtx, initPlan);
}
public LogicalPlan analyze(String sql) {
CascadesContext cascadesContext = createCascadesContext(sql);
cascadesContext.newAnalyzer().analyze();
return (LogicalPlan) cascadesContext.getMemo().copyOut();
}
protected ConnectContext createCtx(UserIdentity user, String host) throws IOException {
SocketChannel channel = SocketChannel.open();
ConnectContext ctx = new ConnectContext(channel);