[feature](nereids) support common table expression (#12742)

Support common table expression(CTE) in Nereids:
- Only inline CTEs are implemented for now, which means the logical plan of a CTE is copied to every place where it is referenced;
- If the name of a CTE is the same as that of an existing table or view, the CTE takes precedence;
This commit is contained in:
Fy
2022-11-02 23:41:53 +08:00
committed by GitHub
parent b83744d2f6
commit e021705053
27 changed files with 1226 additions and 74 deletions

View File

@ -49,8 +49,8 @@ singleStatement
;
statement
: query #statementDefault
| (EXPLAIN | DESC | DESCRIBE) level=(VERBOSE | GRAPH)? query #explain
: cte? query #statementDefault
| (EXPLAIN | DESC | DESCRIBE) level=(VERBOSE | GRAPH)? query #explain
;
// -----------------Query-----------------
@ -76,6 +76,18 @@ querySpecification
havingClause? #regularQuerySpecification
;
cte
: WITH aliasQuery (COMMA aliasQuery)*
;
aliasQuery
: identifier columnAliases? AS LEFT_PAREN query RIGHT_PAREN
;
columnAliases
: LEFT_PAREN identifier (COMMA identifier)* RIGHT_PAREN
;
selectClause
: SELECT selectHint? namedExpressionSeq
;

View File

@ -56,7 +56,6 @@ public class CascadesContext {
private JobContext currentJobContext;
// subqueryExprIsAnalyzed: whether the subquery has been analyzed.
private Map<SubqueryExpr, Boolean> subqueryExprIsAnalyzed;
private RuntimeFilterContext runtimeFilterContext;
/**

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids;
import org.apache.doris.analysis.StatementBase;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.nereids.rules.analysis.CTEContext;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.qe.ConnectContext;
@ -39,12 +40,19 @@ public class StatementContext {
private StatementBase parsedStatement;
private CTEContext cteContext;
/**
 * Creates an empty statement context.
 *
 * <p>The CTE context is initialised eagerly so that {@link #getCteContext()} never returns
 * {@code null} for a context built through this constructor; analysis rules dereference it
 * without a null check. The connect context and origin statement are left unset.
 */
public StatementContext() {
    this.cteContext = new CTEContext();
}
/**
 * Creates a statement context for the given connection and origin statement,
 * paired with a newly created, empty {@link CTEContext}.
 */
public StatementContext(ConnectContext connectContext, OriginStatement originStatement) {
this(connectContext, originStatement, new CTEContext());
}
/**
 * Creates a statement context that stores the given connection, origin statement
 * and CTE context as-is (no copies are made and no validation is performed).
 */
public StatementContext(ConnectContext connectContext, OriginStatement originStatement, CTEContext cteContext) {
this.connectContext = connectContext;
this.originStatement = originStatement;
this.cteContext = cteContext;
}
public void setConnectContext(ConnectContext connectContext) {
@ -75,6 +83,14 @@ public class StatementContext {
return relationIdGenerator.getNextId();
}
/**
 * Returns the CTE context currently attached to this statement context.
 */
public CTEContext getCteContext() {
return cteContext;
}
/**
 * Replaces the CTE context attached to this statement context with the given one.
 */
public void setCteContext(CTEContext cteContext) {
this.cteContext = cteContext;
}
public void setParsedStatement(StatementBase parsedStatement) {
this.parsedStatement = parsedStatement;
}

View File

@ -41,8 +41,8 @@ public class NereidsAnalyzer {
}
public NereidsAnalyzer(CascadesContext cascadesContext, Optional<Scope> outerScope) {
this.cascadesContext = Objects.requireNonNull(cascadesContext, "cascadesContext can not be null");
this.outerScope = Objects.requireNonNull(outerScope, "outerScope can not be null");
this.cascadesContext = Objects.requireNonNull(cascadesContext, "cascadesContext cannot be null");
this.outerScope = Objects.requireNonNull(outerScope, "outerScope cannot be null");
}
/**
@ -64,4 +64,5 @@ public class NereidsAnalyzer {
public Optional<Scope> getOuterScope() {
return outerScope;
}
}

View File

@ -28,14 +28,23 @@ import org.apache.doris.nereids.types.DataType;
import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Optional;
/**
* Expression for unbound alias.
*/
public class UnboundAlias extends NamedExpression implements UnaryExpression, Unbound, PropagateNullable {
private Optional<String> alias;
public UnboundAlias(Expression child) {
super(child);
this.alias = Optional.empty();
}
public UnboundAlias(Expression child, String alias) {
super(child);
this.alias = Optional.of(alias);
}
@Override
@ -43,9 +52,21 @@ public class UnboundAlias extends NamedExpression implements UnaryExpression, Un
return child().getDataType();
}
/**
 * Renders the child wrapped in parentheses, followed by {@code " AS <alias>"}
 * when an explicit alias is present.
 */
@Override
public String toSql() {
    String wrappedChild = "(" + child() + ")";
    return alias
            .map(name -> wrappedChild + " AS " + name)
            .orElse(wrappedChild);
}
@Override
public String toString() {
return "UnboundAlias(" + child() + ")";
StringBuilder stringBuilder = new StringBuilder();
stringBuilder.append("UnboundAlias(" + child() + ")");
alias.ifPresent(name -> stringBuilder.append(" AS " + name));
return stringBuilder.toString();
}
@Override
@ -58,4 +79,8 @@ public class UnboundAlias extends NamedExpression implements UnaryExpression, Un
Preconditions.checkArgument(children.size() == 1);
return new UnboundAlias(children.get(0));
}
public Optional<String> getAlias() {
return alias;
}
}

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.analysis.BindFunction;
import org.apache.doris.nereids.rules.analysis.BindRelation;
import org.apache.doris.nereids.rules.analysis.BindSlotReference;
import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate;
import org.apache.doris.nereids.rules.analysis.RegisterCTE;
import org.apache.doris.nereids.rules.analysis.ResolveHaving;
import org.apache.doris.nereids.rules.analysis.Scope;
@ -42,6 +43,9 @@ public class AnalyzeRulesJob extends BatchRulesJob {
public AnalyzeRulesJob(CascadesContext cascadesContext, Optional<Scope> scope) {
super(cascadesContext);
rulesJob.addAll(ImmutableList.of(
bottomUpBatch(ImmutableList.of(
new RegisterCTE()
)),
bottomUpBatch(ImmutableList.of(
new BindRelation(),
new BindSlotReference(scope),

View File

@ -120,16 +120,16 @@ public class GroupExpression {
*/
public void replaceChild(Group originChild, Group newChild) {
originChild.removeParentExpression(this);
List<Group> groups = Lists.newArrayListWithCapacity(this.children.size());
ImmutableList.Builder<Group> groupBuilder = ImmutableList.builderWithExpectedSize(arity());
for (int i = 0; i < children.size(); i++) {
if (children.get(i) == originChild) {
groups.add(newChild);
groupBuilder.add(newChild);
newChild.addParentExpression(this);
} else {
groups.add(child(i));
groupBuilder.add(child(i));
}
}
children = ImmutableList.copyOf(groups);
newChild.addParentExpression(this);
this.children = groupBuilder.build();
}
public void setChild(int index, Group group) {

View File

@ -21,6 +21,7 @@ import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.DorisParser;
import org.apache.doris.nereids.DorisParser.AggClauseContext;
import org.apache.doris.nereids.DorisParser.AliasQueryContext;
import org.apache.doris.nereids.DorisParser.AliasedQueryContext;
import org.apache.doris.nereids.DorisParser.AliasedRelationContext;
import org.apache.doris.nereids.DorisParser.ArithmeticBinaryContext;
@ -28,6 +29,7 @@ import org.apache.doris.nereids.DorisParser.ArithmeticUnaryContext;
import org.apache.doris.nereids.DorisParser.BooleanLiteralContext;
import org.apache.doris.nereids.DorisParser.ColumnReferenceContext;
import org.apache.doris.nereids.DorisParser.ComparisonContext;
import org.apache.doris.nereids.DorisParser.CteContext;
import org.apache.doris.nereids.DorisParser.DecimalLiteralContext;
import org.apache.doris.nereids.DorisParser.DereferenceContext;
import org.apache.doris.nereids.DorisParser.ExistContext;
@ -64,6 +66,7 @@ import org.apache.doris.nereids.DorisParser.SingleStatementContext;
import org.apache.doris.nereids.DorisParser.SortClauseContext;
import org.apache.doris.nereids.DorisParser.SortItemContext;
import org.apache.doris.nereids.DorisParser.StarContext;
import org.apache.doris.nereids.DorisParser.StatementDefaultContext;
import org.apache.doris.nereids.DorisParser.StringLiteralContext;
import org.apache.doris.nereids.DorisParser.SubqueryExpressionContext;
import org.apache.doris.nereids.DorisParser.TableAliasContext;
@ -84,7 +87,6 @@ import org.apache.doris.nereids.exceptions.ParseException;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.properties.SelectHint;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Between;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
@ -131,6 +133,7 @@ import org.apache.doris.nereids.trees.plans.commands.Command;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand.ExplainLevel;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTE;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
@ -195,6 +198,12 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
return ParserUtils.withOrigin(ctx, () -> (LogicalPlan) visit(ctx.statement()));
}
/**
 * Builds the plan of a default statement: the query plan itself, wrapped by
 * {@link #withCte} when the statement starts with a WITH clause.
 */
@Override
public LogicalPlan visitStatementDefault(StatementDefaultContext ctx) {
    LogicalPlan queryPlan = visitQuery(ctx.query());
    if (ctx.cte() != null) {
        return withCte(ctx.cte(), queryPlan);
    }
    return queryPlan;
}
/**
* Visit multi-statements.
*/
@ -216,6 +225,30 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
* Plan parsing
* ******************************************************************************************** */
/**
 * Visits every alias query of the WITH clause and wraps the given plan in a LogicalCTE node
 * that carries the resulting alias queries.
 */
public LogicalPlan withCte(CteContext ctx, LogicalPlan plan) {
    List<LogicalSubQueryAlias> aliasQueries = visit(ctx.aliasQuery(), LogicalSubQueryAlias.class);
    return new LogicalCTE<>(aliasQueries, plan);
}
/**
 * Processes one alias query of a CTE: its name, its optional column aliases and its query.
 */
@Override
public LogicalSubQueryAlias visitAliasQuery(AliasQueryContext ctx) {
    return ParserUtils.withOrigin(ctx, () -> {
        LogicalPlan queryPlan = plan(ctx.query());
        // Column aliases are optional in the grammar; absent aliases become Optional.empty().
        Optional<List<String>> columnAliases = Optional.empty();
        if (ctx.columnAliases() != null) {
            List<String> aliasNames = ctx.columnAliases().identifier().stream()
                    .map(identifier -> identifier.getText())
                    .collect(ImmutableList.toImmutableList());
            columnAliases = Optional.of(aliasNames);
        }
        return new LogicalSubQueryAlias(ctx.identifier().getText(), columnAliases, queryPlan);
    });
}
@Override
public Command visitExplain(ExplainContext ctx) {
LogicalPlan logicalPlan = plan(ctx.query());
@ -315,7 +348,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
return ParserUtils.withOrigin(ctx, () -> {
Expression expression = getExpression(ctx.expression());
if (ctx.name != null) {
return new Alias(expression, ctx.name.getText());
return new UnboundAlias(expression, ctx.name.getText());
} else {
return expression;
}

View File

@ -38,7 +38,7 @@ public class PlanPreprocessors {
public LogicalPlan process(LogicalPlan logicalPlan) {
LogicalPlan resultPlan = logicalPlan;
for (PlanPreprocessor processor : getProcessors()) {
resultPlan = (LogicalPlan) logicalPlan.accept(processor, statementContext);
resultPlan = (LogicalPlan) resultPlan.accept(processor, statementContext);
}
return resultPlan;
}

View File

@ -50,6 +50,7 @@ public enum RuleType {
RESOLVE_PROJECT_ALIAS(RuleTypeClass.REWRITE),
RESOLVE_AGGREGATE_ALIAS(RuleTypeClass.REWRITE),
PROJECT_TO_GLOBAL_AGGREGATE(RuleTypeClass.REWRITE),
REGISTER_CTE(RuleTypeClass.REWRITE),
// check analysis rule
CHECK_ANALYSIS(RuleTypeClass.CHECK),

View File

@ -42,6 +42,7 @@ import java.util.List;
* Rule to bind relations in query plan.
*/
public class BindRelation extends OneAnalysisRuleFactory {
@Override
public Rule build() {
return unboundRelation().thenApply(ctx -> {
@ -49,7 +50,7 @@ public class BindRelation extends OneAnalysisRuleFactory {
switch (nameParts.size()) {
case 1: { // table
// Use current database name from catalog.
return bindWithCurrentDb(ctx.cascadesContext, nameParts);
return bindWithCurrentDb(ctx.cascadesContext, nameParts.get(0));
}
case 2: { // db.table
// Use database name from table name parts.
@ -73,9 +74,15 @@ public class BindRelation extends OneAnalysisRuleFactory {
}
}
private LogicalPlan bindWithCurrentDb(CascadesContext cascadesContext, List<String> nameParts) {
private LogicalPlan bindWithCurrentDb(CascadesContext cascadesContext, String tableName) {
// check if it is a CTE's name
CTEContext cteContext = cascadesContext.getStatementContext().getCteContext();
if (cteContext.containsCTE(tableName)) {
return new LogicalSubQueryAlias<>(tableName, cteContext.getAnalyzedCTEPlan(tableName));
}
String dbName = cascadesContext.getConnectContext().getDatabase();
Table table = getTable(dbName, nameParts.get(0), cascadesContext.getConnectContext().getEnv());
Table table = getTable(dbName, tableName, cascadesContext.getConnectContext().getEnv());
// TODO: should generate different Scan sub class according to table's type
if (table.getType() == TableType.OLAP) {
return new LogicalOlapScan(cascadesContext.getStatementContext().getNextRelationId(),

View File

@ -68,14 +68,15 @@ import java.util.stream.Stream;
* BindSlotReference.
*/
public class BindSlotReference implements AnalysisRuleFactory {
private final Optional<Scope> outerScope;
public BindSlotReference() {
this(Optional.empty());
}
public BindSlotReference(Optional<Scope> outputScope) {
this.outerScope = Objects.requireNonNull(outputScope, "outerScope can not be null");
public BindSlotReference(Optional<Scope> outerScope) {
this.outerScope = Objects.requireNonNull(outerScope, "outerScope cannot be null");
}
private Scope toScope(List<Slot> slots) {
@ -217,6 +218,9 @@ public class BindSlotReference implements AnalysisRuleFactory {
@Override
public Expression visitUnboundAlias(UnboundAlias unboundAlias, PlannerContext context) {
Expression child = unboundAlias.child().accept(this, context);
if (unboundAlias.getAlias().isPresent()) {
return new Alias(child, unboundAlias.getAlias().get());
}
if (child instanceof NamedExpression) {
return new Alias(child, ((NamedExpression) child).getName());
} else {

View File

@ -0,0 +1,64 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import java.util.HashMap;
import java.util.Map;
/**
 * Context used to register and look up CTEs during analysis.
 */
public class CTEContext {

    // For each CTE name we keep both the initial and the analyzed LogicalPlan of its query.
    // The initial LogicalPlan is used to inline a CTE if it is referenced by another CTE,
    // and the analyzed LogicalPlan is used if it is referenced by the main query.
    private final Map<String, LogicalPlan> initialCtePlans = new HashMap<>();
    private final Map<String, LogicalPlan> analyzedCtePlans = new HashMap<>();

    public CTEContext() {
    }

    /**
     * Checks whether a CTE with the given name has been registered,
     * i.e. whether an initial plan has been stored under that name.
     */
    public boolean containsCTE(String cteName) {
        return initialCtePlans.containsKey(cteName);
    }

    /**
     * Returns the initial plan stored under the given name, or null if there is none.
     */
    public LogicalPlan getInitialCTEPlan(String cteName) {
        return initialCtePlans.get(cteName);
    }

    /**
     * Returns the analyzed plan stored under the given name, or null if there is none.
     */
    public LogicalPlan getAnalyzedCTEPlan(String cteName) {
        return analyzedCtePlans.get(cteName);
    }

    /**
     * Stores (or overwrites) the initial plan of the given CTE.
     */
    public void putInitialPlan(String cteName, LogicalPlan plan) {
        initialCtePlans.put(cteName, plan);
    }

    /**
     * Stores (or overwrites) the analyzed plan of the given CTE.
     */
    public void putAnalyzedPlan(String cteName, LogicalPlan plan) {
        analyzedCtePlans.put(cteName, plan);
    }
}

View File

@ -0,0 +1,169 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.analyzer.UnboundAlias;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTE;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
 * Register CTE, includes checking columnAliases, checking CTE name, analyzing each CTE and storing the
 * analyzed logicalPlan of the CTE's query in CTEContext;
 * A LogicalProject node will be added to the root of the initial logicalPlan if there exist columnAliases.
 * Node LogicalCTE will be eliminated after registering.
 */
public class RegisterCTE extends OneAnalysisRuleFactory {

    @Override
    public Rule build() {
        return logicalCTE().thenApply(ctx -> {
            LogicalCTE<GroupPlan> logicalCTE = ctx.root;
            register(logicalCTE.getAliasQueries(), ctx.statementContext);
            // the LogicalCTE node itself is dropped; only its child (the main query) remains
            return (LogicalPlan) logicalCTE.child();
        }).toRule(RuleType.REGISTER_CTE);
    }

    /**
     * Registers the given alias queries, in order, in the statement's CTEContext.
     * For each of them both the initial (CTE-inlined) plan and the analyzed plan are stored.
     *
     * @throws AnalysisException if a CTE name has already been registered
     */
    private void register(List<LogicalSubQueryAlias> aliasQueryList, StatementContext statementContext) {
        CTEContext cteContext = statementContext.getCteContext();
        for (LogicalSubQueryAlias<LogicalPlan> aliasQuery : aliasQueryList) {
            String cteName = aliasQuery.getAlias();
            if (cteContext.containsCTE(cteName)) {
                throw new AnalysisException("CTE name [" + cteName + "] cannot be used more than once.");
            }

            // inline CTE's initialPlan if it is referenced by another CTE
            LogicalPlan plan = aliasQuery.child();
            plan = (LogicalPlan) new CTEVisitor().inlineCTE(cteContext, plan);
            cteContext.putInitialPlan(cteName, plan);

            // analyze CTE's initialPlan
            CascadesContext cascadesContext = new Memo(plan).newCascadesContext(statementContext);
            cascadesContext.newAnalyzer().analyze();
            LogicalPlan analyzedPlan = (LogicalPlan) cascadesContext.getMemo().copyOut(false);

            if (aliasQuery.getColumnAliases().isPresent()) {
                analyzedPlan = withColumnAliases(analyzedPlan, aliasQuery, cteContext);
            }
            cteContext.putAnalyzedPlan(cteName, analyzedPlan);
        }
    }

    /**
     * Applies the column aliases of a CTE: validates them, replaces the stored initial plan with one that is
     * wrapped in an unbound LogicalProject, and returns the analyzed plan wrapped in a bound LogicalProject.
     * Output slots beyond the number of aliases keep their original names.
     */
    private LogicalPlan withColumnAliases(LogicalPlan analyzedPlan,
            LogicalSubQueryAlias<LogicalPlan> aliasQuery, CTEContext cteContext) {
        List<Slot> outputSlots = analyzedPlan.getOutput();
        List<String> columnAliases = aliasQuery.getColumnAliases().get();

        checkColumnAlias(aliasQuery, outputSlots);

        // if this CTE has columnAlias, we should add an extra LogicalProject to both its initialPlan and
        // analyzedPlan, which is used to store columnAlias

        // projects for initialPlan
        List<NamedExpression> unboundProjects = IntStream.range(0, outputSlots.size())
                .mapToObj(i -> i >= columnAliases.size()
                    ? new UnboundSlot(outputSlots.get(i).getName())
                    : new UnboundAlias(new UnboundSlot(outputSlots.get(i).getName()), columnAliases.get(i)))
                .collect(Collectors.toList());

        String name = aliasQuery.getAlias();
        LogicalPlan initialPlan = cteContext.getInitialCTEPlan(name);
        cteContext.putInitialPlan(name, new LogicalProject<>(unboundProjects, initialPlan));

        // projects for analyzedPlan
        List<NamedExpression> boundedProjects = IntStream.range(0, outputSlots.size())
                .mapToObj(i -> i >= columnAliases.size()
                    ? outputSlots.get(i)
                    : new Alias(outputSlots.get(i), columnAliases.get(i)))
                .collect(Collectors.toList());
        return new LogicalProject<>(boundedProjects, analyzedPlan);
    }

    /**
     * Checks the number and the uniqueness (case-insensitive) of a CTE's column aliases.
     *
     * @throws AnalysisException if more aliases than output columns are given, or if two aliases are equal
     *         ignoring case
     */
    private void checkColumnAlias(LogicalSubQueryAlias<LogicalPlan> aliasQuery, List<Slot> outputSlots) {
        List<String> columnAlias = aliasQuery.getColumnAliases().get();
        // if the size of columnAlias is smaller than outputSlots' size, we will replace the corresponding number
        // of front slots with columnAlias.
        if (columnAlias.size() > outputSlots.size()) {
            // outputSlots are the columns the CTE returns; columnAlias are the labels the user specified
            throw new AnalysisException("CTE [" + aliasQuery.getAlias() + "] returns " + outputSlots.size()
                    + " columns, but " + columnAlias.size() + " labels were specified. "
                    + "The number of column labels must "
                    + "be smaller or equal to the number of returned columns.");
        }

        // column alias cannot be used more than once; the comparison is case-insensitive, so the
        // lower-cased form must be both what is looked up and what is remembered
        Set<String> names = new HashSet<>();
        for (String alias : columnAlias) {
            String lowerCaseAlias = alias.toLowerCase();
            if (!names.add(lowerCaseAlias)) {
                throw new AnalysisException("Duplicated CTE column alias: [" + lowerCaseAlias
                        + "] in CTE [" + aliasQuery.getAlias() + "]");
            }
        }
    }

    /**
     * Rewriter that replaces single-part unbound relations naming an already registered CTE
     * with a LogicalSubQueryAlias over that CTE's initial plan.
     */
    private static class CTEVisitor extends DefaultPlanRewriter<CTEContext> {
        @Override
        public LogicalPlan visitUnboundRelation(UnboundRelation unboundRelation, CTEContext cteContext) {
            // only an unqualified name (a single name part) can refer to a CTE
            if (unboundRelation.getNameParts().size() != 1) {
                return unboundRelation;
            }
            String name = unboundRelation.getTableName();
            if (cteContext.containsCTE(name)) {
                return new LogicalSubQueryAlias<>(name, cteContext.getInitialCTEPlan(name));
            }
            return unboundRelation;
        }

        public Plan inlineCTE(CTEContext cteContext, LogicalPlan ctePlan) {
            return ctePlan.accept(this, cteContext);
        }
    }
}

View File

@ -24,6 +24,7 @@ public enum PlanType {
UNKNOWN,
// logical plan
LOGICAL_CTE,
LOGICAL_SUBQUERY_ALIAS,
LOGICAL_UNBOUND_ONE_ROW_RELATION,
LOGICAL_EMPTY_RELATION,

View File

@ -0,0 +1,116 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans.logical;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
/**
 * Logical node for CTE: holds the alias queries of a WITH clause, with the main query as its only child.
 */
public class LogicalCTE<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TYPE> {

    private final List<LogicalSubQueryAlias> aliasQueries;

    public LogicalCTE(List<LogicalSubQueryAlias> aliasQueries, CHILD_TYPE child) {
        this(aliasQueries, Optional.empty(), Optional.empty(), child);
    }

    public LogicalCTE(List<LogicalSubQueryAlias> aliasQueries, Optional<GroupExpression> groupExpression,
            Optional<LogicalProperties> logicalProperties, CHILD_TYPE child) {
        super(PlanType.LOGICAL_CTE, groupExpression, logicalProperties, child);
        // equals/hashCode/toString all dereference the list, so reject null up front
        this.aliasQueries = Objects.requireNonNull(aliasQueries, "aliasQueries cannot be null");
    }

    public List<LogicalSubQueryAlias> getAliasQueries() {
        return aliasQueries;
    }

    /**
     * In fact, the action of LogicalCTE is to store and register with clauses, and this logical node will be
     * eliminated immediately after finishing the process of with-clause registry; This process is executed before
     * all the other analyze and optimize rules, so this function will not be called.
     */
    @Override
    public List<Slot> computeOutput() {
        return child().getOutput();
    }

    @Override
    public String toString() {
        return Utils.toSqlString("LogicalCTE",
                "aliasQueries", aliasQueries
        );
    }

    /**
     * Two LogicalCTE nodes are equal when their alias queries are equal; the child is not compared here.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        LogicalCTE<?> that = (LogicalCTE<?>) o;
        return aliasQueries.equals(that.aliasQueries);
    }

    @Override
    public int hashCode() {
        return Objects.hash(aliasQueries);
    }

    /**
     * Returns a copy of this node with the given single child.
     *
     * @throws IllegalArgumentException if {@code children} does not contain exactly one plan,
     *         or if this node has no alias queries
     */
    @Override
    public Plan withChildren(List<Plan> children) {
        // this is a unary node: validate the argument that is actually indexed below
        Preconditions.checkArgument(children.size() == 1);
        Preconditions.checkArgument(aliasQueries.size() > 0);
        return new LogicalCTE<>(aliasQueries, children.get(0));
    }

    @Override
    public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
        return visitor.visitLogicalCTE(this, context);
    }

    @Override
    public List<Expression> getExpressions() {
        return Collections.emptyList();
    }

    @Override
    public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
        return new LogicalCTE<>(aliasQueries, groupExpression, Optional.of(getLogicalProperties()), child());
    }

    @Override
    public Plan withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
        return new LogicalCTE<>(aliasQueries, Optional.empty(), logicalProperties, child());
    }
}

View File

@ -169,13 +169,13 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
public String toString() {
return Utils.toSqlString("LogicalJoin",
"type", joinType,
"hashJoinCondition", hashJoinConjuncts,
"otherJoinCondition", otherJoinConjuncts
"hashJoinConjuncts", hashJoinConjuncts,
"otherJoinConjuncts", otherJoinConjuncts
);
}
// TODO:
// 1. consider the order of conjucts in otherJoinConjuncts and hashJoinConditions
// 1. consider the order of conjuncts in otherJoinConjuncts and hashJoinConjuncts
@Override
public boolean equals(Object o) {
if (this == o) {

View File

@ -28,6 +28,7 @@ import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import org.apache.commons.lang.StringUtils;
import java.util.List;
import java.util.Objects;
@ -42,14 +43,22 @@ import java.util.stream.Collectors;
public class LogicalSubQueryAlias<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TYPE> {
private final String alias;
private final Optional<List<String>> columnAliases;
public LogicalSubQueryAlias(String tableAlias, CHILD_TYPE child) {
this(tableAlias, Optional.empty(), Optional.empty(), child);
this(tableAlias, Optional.empty(), Optional.empty(), Optional.empty(), child);
}
public LogicalSubQueryAlias(String tableAlias, Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, CHILD_TYPE child) {
public LogicalSubQueryAlias(String tableAlias, Optional<List<String>> columnAliases, CHILD_TYPE child) {
this(tableAlias, columnAliases, Optional.empty(), Optional.empty(), child);
}
public LogicalSubQueryAlias(String tableAlias, Optional<List<String>> columnAliases,
Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, CHILD_TYPE child) {
super(PlanType.LOGICAL_SUBQUERY_ALIAS, groupExpression, logicalProperties, child);
this.alias = tableAlias;
this.columnAliases = columnAliases;
}
@Override
@ -63,8 +72,18 @@ public class LogicalSubQueryAlias<CHILD_TYPE extends Plan> extends LogicalUnary<
return alias;
}
public Optional<List<String>> getColumnAliases() {
return columnAliases;
}
@Override
public String toString() {
if (columnAliases.isPresent()) {
return Utils.toSqlString("LogicalSubQueryAlias",
"alias", alias,
"columnAliases", StringUtils.join(columnAliases.get(), ",")
);
}
return Utils.toSqlString("LogicalSubQueryAlias",
"alias", alias
);
@ -105,11 +124,13 @@ public class LogicalSubQueryAlias<CHILD_TYPE extends Plan> extends LogicalUnary<
@Override
public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalSubQueryAlias<>(alias, groupExpression, Optional.of(getLogicalProperties()), child());
return new LogicalSubQueryAlias<>(alias, columnAliases, groupExpression,
Optional.of(getLogicalProperties()), child());
}
@Override
public Plan withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
return new LogicalSubQueryAlias<>(alias, Optional.empty(), logicalProperties, child());
return new LogicalSubQueryAlias<>(alias, columnAliases, Optional.empty(),
logicalProperties, child());
}
}

View File

@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.plans.commands.ExplainCommand;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTE;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
@ -83,6 +84,10 @@ public abstract class PlanVisitor<R, C> {
// Logical plans
// *******************************
public R visitLogicalCTE(LogicalCTE<? extends Plan> cte, C context) {
return visit(cte, context);
}
public R visitSubQueryAlias(LogicalSubQueryAlias<? extends Plan> alias, C context) {
return visit(alias, context);
}

View File

@ -93,6 +93,7 @@ import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.rules.analysis.CTEContext;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.planner.OriginalPlanner;
import org.apache.doris.planner.Planner;
@ -211,6 +212,7 @@ public class StmtExecutor implements ProfileWriter {
this.statementContext.setConnectContext(ctx);
this.statementContext.setOriginStatement(originStmt);
this.statementContext.setParsedStatement(parsedStmt);
this.statementContext.setCteContext(new CTEContext());
} else {
this.statementContext = new StatementContext(ctx, originStmt);
this.statementContext.setParsedStatement(parsedStmt);

View File

@ -102,10 +102,10 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING a1 > 0";
a1 = new SlotReference(
new ExprId(2), "a1", TinyIntType.INSTANCE, true,
new ExprId(1), "a1", TinyIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test_having", "t1")
);
Alias value = new Alias(new ExprId(0), a1, "value");
Alias value = new Alias(new ExprId(3), a1, "value");
PlanChecker.from(connectContext).analyze(sql)
.applyBottomUp(new ExpressionRewrite(TypeCoercion.INSTANCE))
.matchesFromRoot(
@ -189,14 +189,14 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
sql = "SELECT a1, SUM(a2) as value FROM t1 GROUP BY a1 HAVING SUM(a2) > 0";
a1 = new SlotReference(
new ExprId(2), "a1", TinyIntType.INSTANCE, true,
new ExprId(1), "a1", TinyIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test_having", "t1")
);
a2 = new SlotReference(
new ExprId(3), "a2", TinyIntType.INSTANCE, true,
new ExprId(2), "a2", TinyIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test_having", "t1")
);
Alias value = new Alias(new ExprId(0), new Sum(a2), "value");
Alias value = new Alias(new ExprId(3), new Sum(a2), "value");
PlanChecker.from(connectContext).analyze(sql)
.matchesFromRoot(
logicalProject(
@ -349,24 +349,24 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
+ "FROM t1, t2 WHERE t1.pk = t2.pk GROUP BY t1.pk, t1.pk + 1\n"
+ "HAVING t1.pk > 0 AND COUNT(a1) + 1 > 0 AND SUM(a1 + a2) + 1 > 0 AND v1 + 1 > 0 AND v1 > 0";
SlotReference pk = new SlotReference(
new ExprId(1), "pk", TinyIntType.INSTANCE, true,
new ExprId(0), "pk", TinyIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test_having", "t1")
);
SlotReference a1 = new SlotReference(
new ExprId(2), "a1", TinyIntType.INSTANCE, true,
new ExprId(1), "a1", TinyIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test_having", "t1")
);
SlotReference a2 = new SlotReference(
new ExprId(3), "a2", TinyIntType.INSTANCE, true,
new ExprId(2), "a2", TinyIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test_having", "t1")
);
Alias pk1 = new Alias(new ExprId(7), new Add(pk, Literal.of((byte) 1)), "(pk + 1)");
Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((short) 1)), "((pk + 1) + 1)");
Alias pk2 = new Alias(new ExprId(9), new Add(pk, Literal.of((byte) 2)), "(pk + 2)");
Alias sumA1 = new Alias(new ExprId(10), new Sum(a1), "SUM(a1)");
Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1), Literal.of(1L)), "(COUNT(a1) + 1)");
Alias sumA1A2 = new Alias(new ExprId(12), new Sum(new Add(a1, a2)), "SUM((a1 + a2))");
Alias v1 = new Alias(new ExprId(0), new Count(a2), "v1");
Alias pk1 = new Alias(new ExprId(6), new Add(pk, Literal.of((byte) 1)), "(pk + 1)");
Alias pk11 = new Alias(new ExprId(7), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((short) 1)), "((pk + 1) + 1)");
Alias pk2 = new Alias(new ExprId(8), new Add(pk, Literal.of((byte) 2)), "(pk + 2)");
Alias sumA1 = new Alias(new ExprId(9), new Sum(a1), "SUM(a1)");
Alias countA11 = new Alias(new ExprId(10), new Add(new Count(a1), Literal.of(1L)), "(COUNT(a1) + 1)");
Alias sumA1A2 = new Alias(new ExprId(11), new Sum(new Add(a1, a2)), "SUM((a1 + a2))");
Alias v1 = new Alias(new ExprId(12), new Count(a2), "v1");
PlanChecker.from(connectContext).analyze(sql)
.matchesFromRoot(
logicalProject(

View File

@ -26,8 +26,10 @@ import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand.ExplainLevel;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTE;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
@ -36,6 +38,7 @@ import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Optional;
import java.util.Set;
public class NereidsParserTest extends ParserTestBase {
@ -77,6 +80,29 @@ public class NereidsParserTest extends ParserTestBase {
);
}
/**
 * Verifies that statements starting with a WITH clause are parsed into a {@link LogicalCTE}
 * root node carrying one alias query per CTE, and that an explicit column alias list is kept.
 */
@Test
public void testParseCTE() {
    NereidsParser nereidsParser = new NereidsParser();
    LogicalPlan logicalPlan;

    // A single CTE.
    String cteSql1 = "with t1 as (select s_suppkey from supplier where s_suppkey < 10) select * from t1";
    logicalPlan = nereidsParser.parseSingle(cteSql1);
    Assertions.assertEquals(PlanType.LOGICAL_CTE, logicalPlan.getType());
    Assertions.assertEquals(1, ((LogicalCTE<?>) logicalPlan).getAliasQueries().size());

    // Two CTEs, the second one referencing the first.
    String cteSql2 = "with t1 as (select s_suppkey from supplier), t2 as (select s_suppkey from t1) select * from t2";
    logicalPlan = nereidsParser.parseSingle(cteSql2);
    Assertions.assertEquals(PlanType.LOGICAL_CTE, logicalPlan.getType());
    Assertions.assertEquals(2, ((LogicalCTE<?>) logicalPlan).getAliasQueries().size());

    // A CTE that declares column aliases.
    String cteSql3 = "with t1 (key, name) as (select s_suppkey, s_name from supplier) select * from t1";
    logicalPlan = nereidsParser.parseSingle(cteSql3);
    Assertions.assertEquals(PlanType.LOGICAL_CTE, logicalPlan.getType());
    Assertions.assertEquals(1, ((LogicalCTE<?>) logicalPlan).getAliasQueries().size());
    Optional<List<String>> columnAliases = ((LogicalCTE<?>) logicalPlan).getAliasQueries().get(0).getColumnAliases();
    // Fail with a readable assertion instead of NoSuchElementException when the aliases are missing.
    Assertions.assertTrue(columnAliases.isPresent(), "column aliases of t1 should be present");
    Assertions.assertEquals(2, columnAliases.get().size());
}
@Test
public void testExplainNormal() {
String sql = "explain select `AD``D` from t1 where a = 1";

View File

@ -73,7 +73,7 @@ public class PlanToStringTest {
new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList()))),
left, right);
Assertions.assertTrue(plan.toString().matches(
"LogicalJoin \\( type=INNER_JOIN, hashJoinCondition=\\[\\(a#\\d+ = b#\\d+\\)], otherJoinCondition=\\[] \\)"));
"LogicalJoin \\( type=INNER_JOIN, hashJoinConjuncts=\\[\\(a#\\d+ = b#\\d+\\)], otherJoinConjuncts=\\[] \\)"));
}
@Test

View File

@ -372,20 +372,19 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte
logicalProject(
logicalFilter()
).when(FieldChecker.check("projects", ImmutableList.of(
new Alias(new ExprId(0),
new SlotReference(new ExprId(6), "v1", BigIntType.INSTANCE,
true,
ImmutableList.of("default_cluster:test", "t7")), "aa")
new Alias(new ExprId(7), new SlotReference(new ExprId(5), "v1", BigIntType.INSTANCE,
true,
ImmutableList.of("default_cluster:test", "t7")), "aa")
)))
).when(FieldChecker.check("outputExpressions", ImmutableList.of(
new Alias(new ExprId(8),
new Max(new SlotReference(new ExprId(0), "aa", BigIntType.INSTANCE,
true,
ImmutableList.of("t2"))), "max(aa)")
)))
new Alias(new ExprId(8),
new Max(new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE,
true,
ImmutableList.of("t2"))), "max(aa)")
)))
.when(FieldChecker.check("groupByExpressions", ImmutableList.of()))
).when(FieldChecker.check("correlationSlot", ImmutableList.of(
new SlotReference(new ExprId(2), "k2", BigIntType.INSTANCE, true,
new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t6")))))
);
}
@ -401,23 +400,17 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte
logicalAggregate(
logicalFilter(
logicalProject().when(FieldChecker.check("projects", ImmutableList.of(
new Alias(new ExprId(0), new SlotReference(new ExprId(6), "v1",
BigIntType.INSTANCE, true,
new Alias(new ExprId(7), new SlotReference(new ExprId(5), "v1", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t7")), "aa"),
new SlotReference(new ExprId(3), "k1", BigIntType.INSTANCE,
false,
new SlotReference(new ExprId(2), "k1", BigIntType.INSTANCE, false,
ImmutableList.of("default_cluster:test", "t7")),
new SlotReference(new ExprId(4), "k2", new VarcharType(128),
true,
new SlotReference(new ExprId(3), "k2", new VarcharType(128), true,
ImmutableList.of("default_cluster:test", "t7")),
new SlotReference(new ExprId(5), "k3", BigIntType.INSTANCE,
true,
new SlotReference(new ExprId(4), "k3", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t7")),
new SlotReference(new ExprId(6), "v1", BigIntType.INSTANCE,
true,
new SlotReference(new ExprId(5), "v1", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t7")),
new SlotReference(new ExprId(7), "v2", BigIntType.INSTANCE,
true,
new SlotReference(new ExprId(6), "v2", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t7"))
)))
)
@ -438,14 +431,12 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte
logicalAggregate(
logicalProject()
).when(FieldChecker.check("outputExpressions", ImmutableList.of(
new Alias(new ExprId(8),
new Max(new SlotReference(new ExprId(0), "aa", BigIntType.INSTANCE,
true,
ImmutableList.of("t2"))), "max(aa)"),
new SlotReference(new ExprId(7), "v2", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t7")))))
new Alias(new ExprId(8), new Max(new SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE, true,
ImmutableList.of("t2"))), "max(aa)"),
new SlotReference(new ExprId(6), "v2", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t7")))))
.when(FieldChecker.check("groupByExpressions", ImmutableList.of(
new SlotReference(new ExprId(7), "v2", BigIntType.INSTANCE, true,
new SlotReference(new ExprId(6), "v2", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t7"))
)))
)
@ -467,9 +458,9 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte
)
).when(FieldChecker.check("joinType", JoinType.LEFT_OUTER_JOIN))
.when(FieldChecker.check("otherJoinConjuncts", ImmutableList.of(
new EqualTo(new SlotReference(new ExprId(2), "k2", BigIntType.INSTANCE, true,
new EqualTo(new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t6")),
new SlotReference(new ExprId(7), "v2", BigIntType.INSTANCE, true,
new SlotReference(new ExprId(6), "v2", BigIntType.INSTANCE, true,
ImmutableList.of("default_cluster:test", "t7")))
)))
);

View File

@ -0,0 +1,336 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.analyzer.UnboundAlias;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.datasets.ssb.SSBUtils;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator;
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.rules.analysis.CTEContext;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
import org.apache.doris.nereids.rules.rewrite.logical.InApplyToJoin;
import org.apache.doris.nereids.rules.rewrite.logical.PushApplyUnderFilter;
import org.apache.doris.nereids.rules.rewrite.logical.PushApplyUnderProject;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import mockit.Mock;
import mockit.MockUp;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
public class RegisterCTETest extends TestWithFeService implements PatternMatchSupported {
// Parser used by testTranslateCase to turn each SQL string into a logical plan.
private final NereidsParser parser = new NereidsParser();
// Note: in every query below the last CTE's closing ')' is concatenated directly with the
// following "SELECT" (no whitespace in between).
// sql1: two CTEs where cte2 reads from cte1; the main query selects from both.
private final String sql1 = "WITH cte1 AS (SELECT s_suppkey FROM supplier WHERE s_suppkey < 5), "
+ "cte2 AS (SELECT s_suppkey FROM cte1 WHERE s_suppkey < 3)"
+ "SELECT * FROM cte1, cte2";
// sql2: CTEs with column alias lists; cte1 lists one alias (skey) for a two-column select.
private final String sql2 = "WITH cte1 (skey) AS (SELECT s_suppkey, s_nation FROM supplier WHERE s_suppkey < 5), "
+ "cte2 (sk2) AS (SELECT skey FROM cte1 WHERE skey < 3)"
+ "SELECT * FROM cte1, cte2";
// sql3: cte1 feeds an aggregation, cte2 is referenced from an IN sub-query inside HAVING.
private final String sql3 = "WITH cte1 AS (SELECT * FROM supplier), "
+ "cte2 AS (SELECT * FROM supplier WHERE s_region in (\"ASIA\", \"AFRICA\"))"
+ "SELECT s_region, count(*) FROM cte1 GROUP BY s_region HAVING s_region in (SELECT s_region FROM cte2)";
// sql4: the CTE body aliases its column (sk) and the two CTEs are joined on that alias.
private final String sql4 = "WITH cte1 AS (SELECT s_suppkey AS sk FROM supplier WHERE s_suppkey < 5), "
+ "cte2 AS (SELECT sk FROM cte1 WHERE sk < 3)"
+ "SELECT * FROM cte1 JOIN cte2 ON cte1.sk = cte2.sk";
// sql5: CTE names V1/V2 are the same as the views created in runBeforeAll.
private final String sql5 = "WITH V1 AS (SELECT s_suppkey FROM supplier), "
+ "V2 AS (SELECT s_suppkey FROM V1)"
+ "SELECT * FROM V2";
// sql6: the same CTE is referenced twice under two different table aliases.
private final String sql6 = "WITH cte1 AS (SELECT s_suppkey FROM supplier)"
+ "SELECT * FROM cte1 AS t1, cte1 AS t2";
// All queries above; iterated by testTranslateCase.
private final List<String> testSql = ImmutableList.of(
sql1, sql2, sql3, sql4, sql5, sql6
);
/**
 * One-time fixture setup: creates and selects the "test" database, creates the tables via
 * {@code SSBUtils.createTables}, and creates two views named V1 and V2 over {@code part}.
 * The views share their names with the CTEs declared in sql5, which is what
 * testCTEWithAnExistedTableOrViewName relies on.
 */
@Override
protected void runBeforeAll() throws Exception {
createDatabase("test");
useDatabase("test");
SSBUtils.createTables(this);
createView("CREATE VIEW V1 AS SELECT * FROM part");
createView("CREATE VIEW V2 AS SELECT * FROM part");
}
/**
 * Calls {@code NamedExpressionUtil.clear()} before every test.
 * NOTE(review): presumably this resets ExprId generation so that the hard-coded ExprId values
 * asserted in the tests are reproducible - confirm against NamedExpressionUtil.
 */
@Override
protected void runBeforeEach() throws Exception {
NamedExpressionUtil.clear();
}
/**
 * Analyzes the given SQL with {@link PlanChecker} and returns the {@link CTEContext} held by
 * the resulting statement context.
 *
 * @param sql the query to analyze
 * @return the CTE context obtained from the cascades context's statement context after analysis
 */
private CTEContext getCTEContextAfterRegisterCTE(String sql) {
return PlanChecker.from(connectContext)
.analyze(sql)
.getCascadesContext().getStatementContext().getCteContext();
}
/* ********************************************************************************************
* Test CTE
* ******************************************************************************************** */
/**
 * Smoke test: plans and translates every query in {@code testSql}. There are no assertions;
 * the test passes as long as neither planning nor translation throws.
 */
@Test
public void testTranslateCase() throws Exception {
// Replace RuleSet.getExplorationRules so that it returns only the AggregateDisassemble rule.
new MockUp<RuleSet>() {
@Mock
public List<Rule> getExplorationRules() {
return Lists.newArrayList(new AggregateDisassemble().build());
}
};
for (String sql : testSql) {
// Cleared again per query, in addition to the per-test call in runBeforeEach.
NamedExpressionUtil.clear();
StatementContext statementContext = MemoTestUtils.createStatementContext(connectContext, sql);
PhysicalPlan plan = new NereidsPlanner(statementContext).plan(
parser.parseSingle(sql),
PhysicalProperties.ANY
);
// The translated result is discarded; this only checks that translation does not throw.
new PhysicalPlanTranslator().translatePlan(plan, new PlanTranslatorContext());
}
}
/**
 * Analyzes sql1 and verifies that both CTEs end up in the {@link CTEContext}, and that the
 * initial plan stored for cte2 has the shape
 * project - filter - sub-query alias - project - filter - unbound relation.
 */
@Test
public void testCTERegister() {
    CTEContext registered = getCTEContextAfterRegisterCTE(sql1);
    Assertions.assertTrue(registered.containsCTE("cte1"));
    Assertions.assertTrue(registered.containsCTE("cte2"));

    LogicalPlan initialPlanOfCte2 = registered.getInitialCTEPlan("cte2");
    PlanChecker.from(connectContext, initialPlanOfCte2)
            .matchesFromRoot(
                    logicalProject(
                            logicalFilter(
                                    logicalSubQueryAlias(
                                            logicalProject(
                                                    logicalFilter(unboundRelation())))))
            );
}
/**
 * Analyzes sql2 and verifies how the column alias list of cte1 is applied:
 * in the initial plan the first projected column is wrapped in an unbound alias named "skey"
 * while the second column is left untouched; in the analyzed plan the same projection is bound
 * to the slots of the supplier table.
 */
@Test
public void testCTERegisterWithColumnAlias() {
    CTEContext cteContext = getCTEContextAfterRegisterCTE(sql2);

    Assertions.assertTrue(cteContext.containsCTE("cte1")
            && cteContext.containsCTE("cte2"));

    // check initial plan
    LogicalPlan cte1InitialPlan = cteContext.getInitialCTEPlan("cte1");

    List<NamedExpression> targetProjects = new ArrayList<>();
    targetProjects.add(new UnboundAlias(new UnboundSlot("s_suppkey"), "skey"));
    targetProjects.add(new UnboundSlot("s_nation"));

    PlanChecker.from(connectContext, cte1InitialPlan)
            .matches(
                    logicalProject(
                    ).when(FieldChecker.check("projects", targetProjects))
            );

    // check analyzed plan
    LogicalPlan cte1AnalyzedPlan = cteContext.getAnalyzedCTEPlan("cte1");

    // The expected slots use the same qualifier ("default_cluster:test") and the same
    // s_suppkey type (IntegerType) as the other tests in this class.
    targetProjects = new ArrayList<>();
    targetProjects.add(new Alias(new ExprId(7),
            new SlotReference(new ExprId(0), "s_suppkey", IntegerType.INSTANCE,
                    false, ImmutableList.of("default_cluster:test", "supplier")), "skey"));
    targetProjects.add(new SlotReference(new ExprId(4), "s_nation", VarcharType.INSTANCE,
            false, ImmutableList.of("default_cluster:test", "supplier")));

    PlanChecker.from(connectContext, cte1AnalyzedPlan)
            .matches(
                    logicalProject(
                    ).when(FieldChecker.check("projects", targetProjects))
            );
}
/**
 * Analyzes sql3 (a CTE used by the aggregation and another CTE used by an IN sub-query in
 * HAVING), applies the PushApplyUnderProject, PushApplyUnderFilter and InApplyToJoin rules
 * bottom-up, and checks that the result contains a project over a LEFT_SEMI_JOIN whose left
 * child is the aggregate and whose join conjunct equates s_region of cte1 and cte2.
 */
@Test
public void testCTEInHavingAndSubquery() {
// Expected slots; the ExprId values are hard-coded and therefore tied to the analysis order.
SlotReference region1 = new SlotReference(new ExprId(5), "s_region", VarcharType.INSTANCE,
false, ImmutableList.of("cte1"));
SlotReference region2 = new SlotReference(new ExprId(12), "s_region", VarcharType.INSTANCE,
false, ImmutableList.of("cte2"));
// The slot and the alias below deliberately share ExprId(14): the slot is expected in the
// top project, the alias in the aggregate's output expressions.
SlotReference count = new SlotReference(new ExprId(14), "count()", BigIntType.INSTANCE,
false, ImmutableList.of());
Alias countAlias = new Alias(new ExprId(14), new Count(), "count()");
PlanChecker.from(connectContext)
.analyze(sql3)
.applyBottomUp(new PushApplyUnderProject())
.applyBottomUp(new PushApplyUnderFilter())
.applyBottomUp(new InApplyToJoin())
.matches(
logicalProject(
logicalJoin(
logicalAggregate()
.when(FieldChecker.check("outputExpressions", ImmutableList.of(region1, countAlias)))
.when(FieldChecker.check("groupByExpressions", ImmutableList.of(region1))),
any()
).when(FieldChecker.check("joinType", JoinType.LEFT_SEMI_JOIN))
.when(FieldChecker.check("otherJoinConjuncts", ImmutableList.of(
new EqualTo(region1, region2)
)))
).when(FieldChecker.check("projects", ImmutableList.of(region1, count)))
);
}
/**
 * Analyzes sql4 (CTE bodies that alias their column as "sk", joined on that alias) and checks
 * that the plan contains a project over an INNER_JOIN of two projects, with the join conjunct
 * and the top-level projection expressed in terms of the "sk" slots of cte1 and cte2.
 */
@Test
public void testCTEWithAlias() {
// skInCTE1 and skAlias share ExprId(7): the slot refers to the output of the alias.
SlotReference skInCTE1 = new SlotReference(new ExprId(7), "sk", IntegerType.INSTANCE,
false, ImmutableList.of("cte1"));
SlotReference skInCTE2 = new SlotReference(new ExprId(15), "sk", IntegerType.INSTANCE,
false, ImmutableList.of("cte2"));
Alias skAlias = new Alias(new ExprId(7),
new SlotReference(new ExprId(0), "s_suppkey", IntegerType.INSTANCE,
false, ImmutableList.of("default_cluster:test", "supplier")), "sk");
PlanChecker.from(connectContext)
.analyze(sql4)
.matches(
logicalProject(
logicalJoin(
logicalProject().when(FieldChecker.check("projects", ImmutableList.of(skAlias))),
logicalProject().when(FieldChecker.check("projects", ImmutableList.of(skInCTE2)))
).when(FieldChecker.check("joinType", JoinType.INNER_JOIN))
.when(FieldChecker.check("otherJoinConjuncts", ImmutableList.of(
new EqualTo(skInCTE1, skInCTE2)
)))
).when(FieldChecker.check("projects", ImmutableList.of(skInCTE1, skInCTE2)))
);
}
/**
 * Analyzes sql5, whose CTEs are named V1 and V2 like the views created in runBeforeAll, and
 * checks that the plan is three nested projects over s_suppkey (supplier, then V1, then V2),
 * i.e. the column comes from the CTE definitions, which select from supplier, whereas the
 * views select from part.
 */
@Test
public void testCTEWithAnExistedTableOrViewName() {
// The three expected slots share ExprId(7) and differ only in their qualifier.
SlotReference suppkeyInV1 = new SlotReference(new ExprId(7), "s_suppkey", IntegerType.INSTANCE,
false, ImmutableList.of("V1"));
SlotReference suppkeyInV2 = new SlotReference(new ExprId(7), "s_suppkey", IntegerType.INSTANCE,
false, ImmutableList.of("V2"));
SlotReference suppkeyInSupplier = new SlotReference(new ExprId(7), "s_suppkey", IntegerType.INSTANCE,
false, ImmutableList.of("default_cluster:test", "supplier"));
PlanChecker.from(connectContext)
.analyze(sql5)
.matches(
logicalProject(
logicalProject(
logicalProject()
.when(FieldChecker.check("projects", ImmutableList.of(suppkeyInSupplier)))
).when(FieldChecker.check("projects", ImmutableList.of(suppkeyInV1)))
).when(FieldChecker.check("projects", ImmutableList.of(suppkeyInV2)))
);
}
/* ********************************************************************************************
* Test CTE Exceptions
* ******************************************************************************************** */
/**
 * A column alias list containing both "a1" and "A1" must be rejected with an
 * {@link AnalysisException} that reports the duplicated alias.
 */
@Test
public void testCTEExceptionOfDuplicatedColumnAlias() {
    final String query = "WITH cte1 (a1, A1) AS (SELECT * FROM supplier)"
            + "SELECT * FROM cte1";

    AnalysisException thrown = Assertions.assertThrows(
            AnalysisException.class,
            () -> PlanChecker.from(connectContext).checkPlannerResult(query),
            "Not throw expected exception.");

    String message = thrown.getMessage();
    Assertions.assertTrue(message.contains("Duplicated CTE column alias: [a1] in CTE [cte1]"));
}
/**
 * A CTE that declares two column aliases but selects a single column must be rejected with an
 * {@link AnalysisException}.
 * NOTE(review): the asserted message says "returns 2 columns, but 1 labels", which reads as
 * swapped relative to the query (2 aliases, 1 selected column) - confirm the wording at the
 * place where the exception is thrown.
 */
@Test
public void testCTEExceptionOfColumnAliasSize() {
    String sql = "WITH cte1 (a1, a2) AS "
            + "(SELECT s_suppkey FROM supplier)"
            + "SELECT * FROM cte1";

    AnalysisException exception = Assertions.assertThrows(AnalysisException.class, () -> {
        PlanChecker.from(connectContext).checkPlannerResult(sql);
    }, "Not throw expected exception.");
    Assertions.assertTrue(exception.getMessage().contains("CTE [cte1] returns 2 columns, "
            + "but 1 labels were specified."));
}
/**
 * A CTE that references another CTE declared after it must fail: cte1 reads from cte2 before
 * cte2 is declared, and the expected error reports that cte2 does not exist in the database.
 */
@Test
public void testCTEExceptionOfReferenceInWrongOrder() {
    final String query = "WITH cte1 AS (SELECT * FROM cte2), "
            + "cte2 AS (SELECT * FROM supplier)"
            + "SELECT * FROM cte1, cte2";

    RuntimeException thrown = Assertions.assertThrows(
            RuntimeException.class,
            () -> PlanChecker.from(connectContext).checkPlannerResult(query),
            "Not throw expected exception.");

    String message = thrown.getMessage();
    Assertions.assertTrue(message.contains("[cte2] does not exist in database"));
}
/**
 * An invalid CTE must be reported even when the main query never references it: cte1 reads
 * from a table that does not exist while the main query only selects from supplier.
 */
@Test
public void testCTEExceptionOfErrorInUnusedCTE() {
    final String query = "WITH cte1 AS (SELECT * FROM not_existed_table)"
            + "SELECT * FROM supplier";

    RuntimeException thrown = Assertions.assertThrows(
            RuntimeException.class,
            () -> PlanChecker.from(connectContext).checkPlannerResult(query),
            "Not throw expected exception.");

    String message = thrown.getMessage();
    Assertions.assertTrue(message.contains("[not_existed_table] does not exist in database"));
}
/**
 * Declaring two CTEs with the same name in one WITH clause must be rejected during analysis
 * with an {@link AnalysisException}.
 */
@Test
public void testCTEExceptionOfDuplicatedCTEName() {
    final String query = "WITH cte1 AS (SELECT * FROM supplier), "
            + "cte1 AS (SELECT * FROM part)"
            + "SELECT * FROM cte1";

    AnalysisException thrown = Assertions.assertThrows(
            AnalysisException.class,
            () -> PlanChecker.from(connectContext).analyze(query),
            "Not throw expected exception.");

    String message = thrown.getMessage();
    Assertions.assertTrue(message.contains("[cte1] cannot be used more than once"));
}
}