[feature-wip](nereids) Make nereids more compatible with spark-sql syntax. (#27231)

Thanks to pr #21855 for providing a wonderful reference.

It may be very difficult and **costly** to implement **a comprehensive logical plan adapter**; often there are only **small syntax variations** between Doris and some other engines (such as Hive/Spark), so we can simply **focus on** the **differences** here.

This pr mainly **focuses** on the **syntax differences between doris and spark-sql**. For instance, it performs some function transformations and overrides some syntax validations.

- add a dialect named `spark_sql`
- move method `NereidsParser#parseSQLWithDialect` to `TrinoParser`
- extract some `FnCallTransformer`/`FnCallTransformers` classes, so we can reuse the logic about the function transformers
- allow derived tables without alias when we set dialect to `spark_sql`(legacy and nereids parser are both supported)
- add some function transformers for hive/spark built-in functions

### Test case (from our online doris cluster)

- Test derived table without alias

```sql
MySQL [(none)]> show variables like '%dialect%';
+---------------+-------+---------------+---------+
| Variable_name | Value | Default_Value | Changed |
+---------------+-------+---------------+---------+
| sql_dialect   | spark_sql  | doris         | 1       |
+---------------+-------+---------------+---------+
1 row in set (0.01 sec)

MySQL [(none)]> select * from (select 1);
+------+
| 1    |
+------+
|    1 |
+------+
1 row in set (0.03 sec)

MySQL [(none)]> select __auto_generated_subquery_name.a from (select 1 as a);
+------+
| a    |
+------+
|    1 |
+------+
1 row in set (0.03 sec)

MySQL [(none)]> set sql_dialect=doris;
Query OK, 0 rows affected (0.02 sec)

MySQL [(none)]> select * from (select 1);
ERROR 1248 (42000): errCode = 2, detailMessage = Every derived table must have its own alias
MySQL [(none)]> 
```

- Test spark-sql/hive built-in functions

```sql
MySQL [(none)]> show global functions;
Empty set (0.01 sec)

MySQL [(none)]> show variables like '%dialect%';
+---------------+-------+---------------+---------+
| Variable_name | Value | Default_Value | Changed |
+---------------+-------+---------------+---------+
| sql_dialect   | spark_sql  | doris         | 1       |
+---------------+-------+---------------+---------+
1 row in set (0.01 sec)

MySQL [(none)]> select get_json_object('{"a":"b"}', '$.a');
+----------------------------------+
| json_extract('{"a":"b"}', '$.a') |
+----------------------------------+
| "b"                              |
+----------------------------------+
1 row in set (0.04 sec)

MySQL [(none)]> select split("a b c", " ");
+-------------------------------+
| split_by_string('a b c', ' ') |
+-------------------------------+
| ["a", "b", "c"]               |
+-------------------------------+
1 row in set (1.17 sec)
```
This commit is contained in:
Xiangyu Wang
2023-12-11 11:16:53 +08:00
committed by GitHub
parent e1587537bc
commit f2fd66ad3b
20 changed files with 516 additions and 161 deletions

View File

@ -27,6 +27,9 @@ import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.ErrorCode;
import org.apache.doris.common.ErrorReport;
import org.apache.doris.common.UserException;
import org.apache.doris.nereids.parser.ParseDialect;
import org.apache.doris.nereids.parser.spark.SparkSql3LogicalPlanBuilder;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.rewrite.ExprRewriter;
import org.apache.doris.thrift.TNullSide;
@ -194,7 +197,13 @@ public class InlineViewRef extends TableRef {
}
if (view == null && !hasExplicitAlias()) {
ErrorReport.reportAnalysisException(ErrorCode.ERR_DERIVED_MUST_HAVE_ALIAS);
String dialect = ConnectContext.get().getSessionVariable().getSqlDialect();
ParseDialect.Dialect sqlDialect = ParseDialect.Dialect.getByName(dialect);
if (ParseDialect.Dialect.SPARK_SQL != sqlDialect) {
ErrorReport.reportAnalysisException(ErrorCode.ERR_DERIVED_MUST_HAVE_ALIAS);
}
hasExplicitAlias = true;
aliases = new String[] { SparkSql3LogicalPlanBuilder.DEFAULT_TABLE_ALIAS };
}
// Analyze the inline view query statement with its own analyzer

View File

@ -17,7 +17,7 @@
package org.apache.doris.nereids.analyzer;
import org.apache.doris.nereids.parser.trino.TrinoFnCallTransformer.PlaceholderCollector;
import org.apache.doris.nereids.parser.CommonFnCallTransformer.PlaceholderCollector;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;

View File

@ -15,9 +15,8 @@
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.parser.trino;
package org.apache.doris.nereids.parser;
import org.apache.doris.nereids.parser.ParserContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;

View File

@ -0,0 +1,110 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.parser;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.Iterables;
import java.util.List;
import java.util.stream.Collectors;
/**
* The abstract holder for {@link AbstractFnCallTransformer},
* and supply transform facade ability.
*/
public abstract class AbstractFnCallTransformers {
private final ImmutableListMultimap<String, AbstractFnCallTransformer> transformerMap;
private final ImmutableListMultimap<String, AbstractFnCallTransformer> complexTransformerMap;
private final ImmutableListMultimap.Builder<String, AbstractFnCallTransformer> transformerBuilder =
ImmutableListMultimap.builder();
private final ImmutableListMultimap.Builder<String, AbstractFnCallTransformer> complexTransformerBuilder =
ImmutableListMultimap.builder();
protected AbstractFnCallTransformers() {
registerTransformers();
transformerMap = transformerBuilder.build();
registerComplexTransformers();
// build complex transformer map in the end
complexTransformerMap = complexTransformerBuilder.build();
}
/**
* Function transform facade
*/
public Function transform(String sourceFnName, List<Expression> sourceFnTransformedArguments,
ParserContext context) {
List<AbstractFnCallTransformer> transformers = getTransformers(sourceFnName);
return doTransform(transformers, sourceFnName, sourceFnTransformedArguments, context);
}
private Function doTransform(List<AbstractFnCallTransformer> transformers,
String sourceFnName,
List<Expression> sourceFnTransformedArguments,
ParserContext context) {
for (AbstractFnCallTransformer transformer : transformers) {
if (transformer.check(sourceFnName, sourceFnTransformedArguments, context)) {
Function transformedFunction =
transformer.transform(sourceFnName, sourceFnTransformedArguments, context);
if (transformedFunction == null) {
continue;
}
return transformedFunction;
}
}
return null;
}
protected void doRegister(
String sourceFnNme,
int sourceFnArgumentsNum,
String targetFnName,
List<? extends Expression> targetFnArguments,
boolean variableArgument) {
List<Expression> castedTargetFnArguments = targetFnArguments
.stream()
.map(each -> (Expression) each)
.collect(Collectors.toList());
transformerBuilder.put(sourceFnNme, new CommonFnCallTransformer(new UnboundFunction(
targetFnName, castedTargetFnArguments), variableArgument, sourceFnArgumentsNum));
}
protected void doRegister(
String sourceFnNme,
AbstractFnCallTransformer transformer) {
complexTransformerBuilder.put(sourceFnNme, transformer);
}
private List<AbstractFnCallTransformer> getTransformers(String sourceFnName) {
ImmutableList<AbstractFnCallTransformer> fnCallTransformers =
transformerMap.get(sourceFnName);
ImmutableList<AbstractFnCallTransformer> complexFnCallTransformers =
complexTransformerMap.get(sourceFnName);
return ImmutableList.copyOf(Iterables.concat(fnCallTransformers, complexFnCallTransformers));
}
protected abstract void registerTransformers();
protected abstract void registerComplexTransformers();
}

View File

@ -15,11 +15,10 @@
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.parser.trino;
package org.apache.doris.nereids.parser;
import org.apache.doris.nereids.analyzer.PlaceholderExpression;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.parser.ParserContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
@ -31,7 +30,7 @@ import java.util.stream.Collectors;
/**
* Trino function transformer
*/
public class TrinoFnCallTransformer extends AbstractFnCallTransformer {
public class CommonFnCallTransformer extends AbstractFnCallTransformer {
private final UnboundFunction targetFunction;
private final List<PlaceholderExpression> targetArguments;
private final boolean variableArgument;
@ -40,9 +39,9 @@ public class TrinoFnCallTransformer extends AbstractFnCallTransformer {
/**
* Trino function transformer, mostly this handle common function.
*/
public TrinoFnCallTransformer(UnboundFunction targetFunction,
boolean variableArgument,
int sourceArgumentsNum) {
public CommonFnCallTransformer(UnboundFunction targetFunction,
boolean variableArgument,
int sourceArgumentsNum) {
this.targetFunction = targetFunction;
this.variableArgument = variableArgument;
this.sourceArgumentsNum = sourceArgumentsNum;

View File

@ -894,7 +894,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
/**
* process lateral view, add a {@link org.apache.doris.nereids.trees.plans.logical.LogicalGenerate} on plan.
*/
private LogicalPlan withGenerate(LogicalPlan plan, LateralViewContext ctx) {
protected LogicalPlan withGenerate(LogicalPlan plan, LateralViewContext ctx) {
if (ctx.LATERAL() == null) {
return plan;
}

View File

@ -31,7 +31,7 @@ import java.math.BigInteger;
/**
* Logical plan builder assistant for buildIn dialect and other dialect.
* The same logical in {@link org.apache.doris.nereids.parser.LogicalPlanBuilder}
* and {@link org.apache.doris.nereids.parser.trino.LogicalPlanTrinoBuilder} can be
* and {@link org.apache.doris.nereids.parser.trino.TrinoLogicalPlanBuilder} can be
* extracted to here.
*/
public class LogicalPlanBuilderAssistant {

View File

@ -22,9 +22,8 @@ import org.apache.doris.common.Pair;
import org.apache.doris.nereids.DorisLexer;
import org.apache.doris.nereids.DorisParser;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.exceptions.UnsupportedDialectException;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.parser.trino.LogicalPlanTrinoBuilder;
import org.apache.doris.nereids.parser.spark.SparkSql3LogicalPlanBuilder;
import org.apache.doris.nereids.parser.trino.TrinoParser;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
@ -37,14 +36,14 @@ import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.atn.PredictionMode;
import org.antlr.v4.runtime.misc.ParseCancellationException;
import org.apache.commons.collections.CollectionUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import javax.annotation.Nullable;
/**
* Sql parser, convert sql DSL to logical plan.
@ -59,7 +58,19 @@ public class NereidsParser {
* see <a href="https://dev.mysql.com/doc/internals/en/com-set-option.html">docs</a> for more information.
*/
public List<StatementBase> parseSQL(String originStr) {
List<Pair<LogicalPlan, StatementContext>> logicalPlans = parseMultiple(originStr);
return parseSQL(originStr, (LogicalPlanBuilder) null);
}
/**
* ParseSQL with dialect.
*/
public List<StatementBase> parseSQL(String sql, SessionVariable sessionVariable) {
@Nullable ParseDialect.Dialect sqlDialect = ParseDialect.Dialect.getByName(sessionVariable.getSqlDialect());
return parseSQLWithDialect(sql, sqlDialect, sessionVariable);
}
private List<StatementBase> parseSQL(String originStr, @Nullable LogicalPlanBuilder logicalPlanBuilder) {
List<Pair<LogicalPlan, StatementContext>> logicalPlans = parseMultiple(originStr, logicalPlanBuilder);
List<StatementBase> statementBases = Lists.newArrayList();
for (Pair<LogicalPlan, StatementContext> parsedPlanToContext : logicalPlans) {
statementBases.add(new LogicalPlanAdapter(parsedPlanToContext.first, parsedPlanToContext.second));
@ -67,37 +78,23 @@ public class NereidsParser {
return statementBases;
}
/**
* ParseSQL with dialect.
*/
public List<StatementBase> parseSQL(String sql, SessionVariable sessionVariable) {
if (ParseDialect.TRINO_395.getDialect().getDialectName()
.equalsIgnoreCase(sessionVariable.getSqlDialect())) {
return parseSQLWithDialect(sql, sessionVariable);
} else {
return parseSQL(sql);
}
}
private List<StatementBase> parseSQLWithDialect(String sql,
@Nullable ParseDialect.Dialect sqlDialect,
SessionVariable sessionVariable) {
switch (sqlDialect) {
case TRINO:
final List<StatementBase> logicalPlans = TrinoParser.parse(sql, sessionVariable);
if (CollectionUtils.isEmpty(logicalPlans)) {
return parseSQL(sql);
}
return logicalPlans;
private List<StatementBase> parseSQLWithDialect(String sql, SessionVariable sessionVariable) {
final List<StatementBase> logicalPlans = new ArrayList<>();
try {
io.trino.sql.parser.StatementSplitter splitter = new io.trino.sql.parser.StatementSplitter(sql);
ParserContext parserContext = new ParserContext(ParseDialect.TRINO_395);
StatementContext statementContext = new StatementContext();
for (io.trino.sql.parser.StatementSplitter.Statement statement : splitter.getCompleteStatements()) {
Object parsedPlan = parseSingleWithDialect(statement.statement(), parserContext);
logicalPlans.add(parsedPlan == null
? null : new LogicalPlanAdapter((LogicalPlan) parsedPlan, statementContext));
}
} catch (io.trino.sql.parser.ParsingException | UnsupportedDialectException e) {
LOG.debug("Failed to parse logical plan from trino, sql is :{}", sql, e);
return parseSQL(sql);
case SPARK_SQL:
return parseSQL(sql, new SparkSql3LogicalPlanBuilder());
default:
return parseSQL(sql);
}
if (logicalPlans.isEmpty() || logicalPlans.stream().anyMatch(Objects::isNull)) {
return parseSQL(sql);
}
return logicalPlans;
}
/**
@ -107,11 +104,26 @@ public class NereidsParser {
* @return logical plan
*/
public LogicalPlan parseSingle(String sql) {
return parse(sql, DorisParser::singleStatement);
return parseSingle(sql, null);
}
/**
* parse sql DSL string.
*
* @param sql sql string
* @return logical plan
*/
public LogicalPlan parseSingle(String sql, @Nullable LogicalPlanBuilder logicalPlanBuilder) {
return parse(sql, logicalPlanBuilder, DorisParser::singleStatement);
}
public List<Pair<LogicalPlan, StatementContext>> parseMultiple(String sql) {
return parse(sql, DorisParser::multiStatements);
return parseMultiple(sql, null);
}
public List<Pair<LogicalPlan, StatementContext>> parseMultiple(String sql,
@Nullable LogicalPlanBuilder logicalPlanBuilder) {
return parse(sql, logicalPlanBuilder, DorisParser::multiStatements);
}
public Expression parseExpression(String expression) {
@ -127,28 +139,15 @@ public class NereidsParser {
}
private <T> T parse(String sql, Function<DorisParser, ParserRuleContext> parseFunction) {
ParserRuleContext tree = toAst(sql, parseFunction);
LogicalPlanBuilder logicalPlanBuilder = new LogicalPlanBuilder();
return (T) logicalPlanBuilder.visit(tree);
return parse(sql, null, parseFunction);
}
/**
* Parse dialect sql.
*
* @param sql sql string
* @param parserContext parse context
* @return logical plan
*/
public <T> T parseSingleWithDialect(String sql, ParserContext parserContext) {
if (ParseDialect.TRINO_395.equals(parserContext.getParserDialect())) {
io.trino.sql.tree.Statement statement = TrinoParser.parse(sql);
return (T) new LogicalPlanTrinoBuilder().visit(statement, parserContext);
} else {
LOG.debug("Failed to parse logical plan, the dialect name is {}, version is {}",
parserContext.getParserDialect().getDialect().getDialectName(),
parserContext.getParserDialect().getVersion());
throw new UnsupportedDialectException(parserContext.getParserDialect());
}
private <T> T parse(String sql, @Nullable LogicalPlanBuilder logicalPlanBuilder,
Function<DorisParser, ParserRuleContext> parseFunction) {
ParserRuleContext tree = toAst(sql, parseFunction);
LogicalPlanBuilder realLogicalPlanBuilder = logicalPlanBuilder == null
? new LogicalPlanBuilder() : logicalPlanBuilder;
return (T) realLogicalPlanBuilder.visit(tree);
}
private ParserRuleContext toAst(String sql, Function<DorisParser, ParserRuleContext> parseFunction) {

View File

@ -17,6 +17,8 @@
package org.apache.doris.nereids.parser;
import javax.annotation.Nullable;
/**
* ParseDialect enum, maybe support other dialect.
*/
@ -29,7 +31,11 @@ public enum ParseDialect {
/**
* Doris parser and it's version is 2.0.0.
*/
DORIS_2_ALL(Dialect.DORIS, Version.DORIS_2_ALL);
DORIS_2_ALL(Dialect.DORIS, Version.DORIS_2_ALL),
/**
* Spark parser and it's version is 3.x.
*/
SPARK_SQL_3_ALL(Dialect.SPARK_SQL, Version.SPARK_SQL_3_ALL);
private final Dialect dialect;
private final Version version;
@ -58,7 +64,11 @@ public enum ParseDialect {
/**
* Doris parser and it's version is 2.0.0.
*/
DORIS_2_ALL("2.*");
DORIS_2_ALL("2.*"),
/**
* Spark sql parser and it's version is 3.x.
*/
SPARK_SQL_3_ALL("3.*");
private final String version;
Version(String version) {
@ -81,7 +91,11 @@ public enum ParseDialect {
/**
* Doris parser dialect
*/
DORIS("doris");
DORIS("doris"),
/**
* Spark sql parser dialect
*/
SPARK_SQL("spark_sql");
private String dialectName;
@ -96,7 +110,7 @@ public enum ParseDialect {
/**
* Get dialect by name
*/
public static Dialect getByName(String dialectName) {
public static @Nullable Dialect getByName(String dialectName) {
if (dialectName == null) {
return null;
}

View File

@ -0,0 +1,71 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.parser.spark;
import org.apache.doris.nereids.analyzer.PlaceholderExpression;
import org.apache.doris.nereids.parser.AbstractFnCallTransformer;
import org.apache.doris.nereids.parser.AbstractFnCallTransformers;
import org.apache.doris.nereids.trees.expressions.Expression;
import com.google.common.collect.Lists;
/**
* The builder and factory for spark-sql 3.x {@link AbstractFnCallTransformer},
* and supply transform facade ability.
*/
public class SparkSql3FnCallTransformers extends AbstractFnCallTransformers {
private SparkSql3FnCallTransformers() {
}
@Override
protected void registerTransformers() {
doRegister("get_json_object", 2, "json_extract",
Lists.newArrayList(
PlaceholderExpression.of(Expression.class, 1),
PlaceholderExpression.of(Expression.class, 2)), true);
doRegister("get_json_object", 2, "json_extract",
Lists.newArrayList(
PlaceholderExpression.of(Expression.class, 1),
PlaceholderExpression.of(Expression.class, 2)), false);
doRegister("split", 2, "split_by_string",
Lists.newArrayList(
PlaceholderExpression.of(Expression.class, 1),
PlaceholderExpression.of(Expression.class, 2)), true);
doRegister("split", 2, "split_by_string",
Lists.newArrayList(
PlaceholderExpression.of(Expression.class, 1),
PlaceholderExpression.of(Expression.class, 2)), false);
// TODO: add other function transformer
}
@Override
protected void registerComplexTransformers() {
// TODO: add other complex function transformer
}
static class SingletonHolder {
private static final SparkSql3FnCallTransformers INSTANCE = new SparkSql3FnCallTransformers();
}
public static SparkSql3FnCallTransformers getSingleton() {
return SingletonHolder.INSTANCE;
}
}

View File

@ -0,0 +1,88 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.parser.spark;
import org.apache.doris.nereids.DorisParser;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.exceptions.ParseException;
import org.apache.doris.nereids.parser.LogicalPlanBuilder;
import org.apache.doris.nereids.parser.ParseDialect;
import org.apache.doris.nereids.parser.ParserContext;
import org.apache.doris.nereids.parser.ParserUtils;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.commons.lang3.StringUtils;
/**
* Extends from {@link org.apache.doris.nereids.parser.LogicalPlanBuilder},
* just focus on the difference between these query syntax.
*/
public class SparkSql3LogicalPlanBuilder extends LogicalPlanBuilder {
// use a default alias name if not exists, keep the same name with spark-sql
public static final String DEFAULT_TABLE_ALIAS = "__auto_generated_subquery_name";
private final ParserContext parserContext;
public SparkSql3LogicalPlanBuilder() {
this.parserContext = new ParserContext(ParseDialect.SPARK_SQL_3_ALL);
}
@Override
public LogicalPlan visitAliasedQuery(DorisParser.AliasedQueryContext ctx) {
LogicalPlan plan = withTableAlias(visitQuery(ctx.query()), ctx.tableAlias());
for (DorisParser.LateralViewContext lateralViewContext : ctx.lateralView()) {
plan = withGenerate(plan, lateralViewContext);
}
return plan;
}
@Override
public Expression visitFunctionCall(DorisParser.FunctionCallContext ctx) {
Expression expression = super.visitFunctionCall(ctx);
if (!(expression instanceof UnboundFunction)) {
return expression;
}
UnboundFunction sourceFunction = (UnboundFunction) expression;
Function transformedFunction = SparkSql3FnCallTransformers.getSingleton().transform(
sourceFunction.getName(),
sourceFunction.getArguments(),
this.parserContext
);
if (transformedFunction == null) {
return expression;
}
return transformedFunction;
}
private LogicalPlan withTableAlias(LogicalPlan plan, DorisParser.TableAliasContext ctx) {
if (ctx.strictIdentifier() == null) {
return plan;
}
return ParserUtils.withOrigin(ctx.strictIdentifier(), () -> {
String alias = StringUtils.isEmpty(ctx.strictIdentifier().getText())
? DEFAULT_TABLE_ALIAS : ctx.strictIdentifier().getText();
if (null != ctx.identifierList()) {
throw new ParseException("Do not implemented", ctx);
}
return new LogicalSubQueryAlias<>(alias, plan);
});
}
}

View File

@ -17,6 +17,8 @@
package org.apache.doris.nereids.parser.trino;
import org.apache.doris.nereids.parser.AbstractFnCallTransformer;
/**
* Trino complex function transformer
*/

View File

@ -18,113 +18,45 @@
package org.apache.doris.nereids.parser.trino;
import org.apache.doris.nereids.analyzer.PlaceholderExpression;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.parser.ParserContext;
import org.apache.doris.nereids.parser.AbstractFnCallTransformer;
import org.apache.doris.nereids.parser.AbstractFnCallTransformers;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.stream.Collectors;
/**
* The builder and factory for {@link org.apache.doris.nereids.parser.trino.TrinoFnCallTransformer},
* The builder and factory for trino {@link AbstractFnCallTransformer},
* and supply transform facade ability.
*/
public class TrinoFnCallTransformers {
private static ImmutableListMultimap<String, AbstractFnCallTransformer> TRANSFORMER_MAP;
private static ImmutableListMultimap<String, AbstractFnCallTransformer> COMPLEX_TRANSFORMER_MAP;
private static final ImmutableListMultimap.Builder<String, AbstractFnCallTransformer> transformerBuilder =
ImmutableListMultimap.builder();
private static final ImmutableListMultimap.Builder<String, AbstractFnCallTransformer> complexTransformerBuilder =
ImmutableListMultimap.builder();
static {
registerTransformers();
registerComplexTransformers();
}
public class TrinoFnCallTransformers extends AbstractFnCallTransformers {
private TrinoFnCallTransformers() {
}
/**
* Function transform facade
*/
public static Function transform(String sourceFnName, List<Expression> sourceFnTransformedArguments,
ParserContext context) {
List<AbstractFnCallTransformer> transformers = getTransformers(sourceFnName);
return doTransform(transformers, sourceFnName, sourceFnTransformedArguments, context);
}
private static Function doTransform(List<AbstractFnCallTransformer> transformers,
String sourceFnName,
List<Expression> sourceFnTransformedArguments,
ParserContext context) {
for (AbstractFnCallTransformer transformer : transformers) {
if (transformer.check(sourceFnName, sourceFnTransformedArguments, context)) {
Function transformedFunction =
transformer.transform(sourceFnName, sourceFnTransformedArguments, context);
if (transformedFunction == null) {
continue;
}
return transformedFunction;
}
}
return null;
}
private static List<AbstractFnCallTransformer> getTransformers(String sourceFnName) {
ImmutableList<AbstractFnCallTransformer> fnCallTransformers =
TRANSFORMER_MAP.get(sourceFnName);
ImmutableList<AbstractFnCallTransformer> complexFnCallTransformers =
COMPLEX_TRANSFORMER_MAP.get(sourceFnName);
return ImmutableList.copyOf(Iterables.concat(fnCallTransformers, complexFnCallTransformers));
}
private static void registerTransformers() {
@Override
protected void registerTransformers() {
registerStringFunctionTransformer();
// TODO: add other function transformer
// build transformer map in the end
TRANSFORMER_MAP = transformerBuilder.build();
}
private static void registerComplexTransformers() {
@Override
protected void registerComplexTransformers() {
DateDiffFnCallTransformer dateDiffFnCallTransformer = new DateDiffFnCallTransformer();
doRegister(dateDiffFnCallTransformer.getSourceFnName(), dateDiffFnCallTransformer);
// TODO: add other complex function transformer
// build complex transformer map in the end
COMPLEX_TRANSFORMER_MAP = complexTransformerBuilder.build();
}
private static void registerStringFunctionTransformer() {
protected void registerStringFunctionTransformer() {
doRegister("codepoint", 1, "ascii",
Lists.newArrayList(PlaceholderExpression.of(Expression.class, 1)), false);
// TODO: add other string function transformer
}
private static void doRegister(
String sourceFnNme,
int sourceFnArgumentsNum,
String targetFnName,
List<? extends Expression> targetFnArguments,
boolean variableArgument) {
List<Expression> castedTargetFnArguments = targetFnArguments
.stream()
.map(each -> (Expression) each)
.collect(Collectors.toList());
transformerBuilder.put(sourceFnNme, new TrinoFnCallTransformer(new UnboundFunction(
targetFnName, castedTargetFnArguments), variableArgument, sourceFnArgumentsNum));
static class SingletonHolder {
private static final TrinoFnCallTransformers INSTANCE = new TrinoFnCallTransformers();
}
private static void doRegister(
String sourceFnNme,
AbstractFnCallTransformer transformer) {
complexTransformerBuilder.put(sourceFnNme, transformer);
public static TrinoFnCallTransformers getSingleton() {
return SingletonHolder.INSTANCE;
}
}

View File

@ -55,7 +55,7 @@ import java.util.stream.Collectors;
* The actually planBuilder for Trino SQL to Doris logical plan.
* It depends on {@link io.trino.sql.tree.AstVisitor}
*/
public class LogicalPlanTrinoBuilder extends io.trino.sql.tree.AstVisitor<Object, ParserContext> {
public class TrinoLogicalPlanBuilder extends io.trino.sql.tree.AstVisitor<Object, ParserContext> {
public Object visit(io.trino.sql.tree.Node node, ParserContext context) {
return this.process(node, context);
@ -145,7 +145,7 @@ public class LogicalPlanTrinoBuilder extends io.trino.sql.tree.AstVisitor<Object
protected Function visitFunctionCall(io.trino.sql.tree.FunctionCall node, ParserContext context) {
List<Expression> exprs = visit(node.getArguments(), context, Expression.class);
Function transformedFn =
TrinoFnCallTransformers.transform(node.getName().toString(), exprs, context);
TrinoFnCallTransformers.getSingleton().transform(node.getName().toString(), exprs, context);
if (transformedFn == null) {
transformedFn = new UnboundFunction(node.getName().toString(), exprs);

View File

@ -17,16 +17,74 @@
package org.apache.doris.nereids.parser.trino;
import org.apache.doris.analysis.StatementBase;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.exceptions.UnsupportedDialectException;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.parser.ParseDialect;
import org.apache.doris.nereids.parser.ParserContext;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.qe.SessionVariable;
import com.google.common.base.Preconditions;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import javax.annotation.Nullable;
/**
* Trino Parser, depends on 395 trino-parser, and 4.9.3 antlr-runtime
*/
public class TrinoParser {
public static final Logger LOG = LogManager.getLogger(TrinoParser.class);
private static final io.trino.sql.parser.ParsingOptions PARSING_OPTIONS =
new io.trino.sql.parser.ParsingOptions(
io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL);
public static io.trino.sql.tree.Statement parse(String query) {
/**
* Parse with trino syntax, return null if parse failed
*/
public static @Nullable List<StatementBase> parse(String sql, SessionVariable sessionVariable) {
final List<StatementBase> logicalPlans = new ArrayList<>();
try {
io.trino.sql.parser.StatementSplitter splitter = new io.trino.sql.parser.StatementSplitter(sql);
ParserContext parserContext = new ParserContext(ParseDialect.TRINO_395);
StatementContext statementContext = new StatementContext();
for (io.trino.sql.parser.StatementSplitter.Statement statement : splitter.getCompleteStatements()) {
Object parsedPlan = parseSingle(statement.statement(), parserContext);
logicalPlans.add(parsedPlan == null
? null : new LogicalPlanAdapter((LogicalPlan) parsedPlan, statementContext));
}
if (logicalPlans.isEmpty() || logicalPlans.stream().anyMatch(Objects::isNull)) {
return null;
}
return logicalPlans;
} catch (io.trino.sql.parser.ParsingException | UnsupportedDialectException e) {
LOG.debug("Failed to parse logical plan from trino, sql is :{}", sql, e);
return null;
}
}
private static io.trino.sql.tree.Statement parse(String sql) {
io.trino.sql.parser.SqlParser sqlParser = new io.trino.sql.parser.SqlParser();
return sqlParser.createStatement(query, PARSING_OPTIONS);
return sqlParser.createStatement(sql, PARSING_OPTIONS);
}
/**
* Parse trino dialect sql.
*
* @param sql sql string
* @param parserContext parse context
* @return logical plan
*/
public static <T> T parseSingle(String sql, ParserContext parserContext) {
Preconditions.checkArgument(parserContext.getParserDialect() == ParseDialect.TRINO_395);
io.trino.sql.tree.Statement statement = TrinoParser.parse(sql);
return (T) new TrinoLogicalPlanBuilder().visit(statement, parserContext);
}
}

View File

@ -188,7 +188,7 @@ public abstract class ConnectProcessor {
// Nereids do not support prepare and execute now, so forbid prepare command, only process query command
if (mysqlCommand == MysqlCommand.COM_QUERY && ctx.getSessionVariable().isEnableNereidsPlanner()) {
try {
stmts = new NereidsParser().parseSQL(originStmt);
stmts = new NereidsParser().parseSQL(originStmt, ctx.getSessionVariable());
} catch (NotSupportedException e) {
// Parse sql failed, audit it and return
handleQueryException(e, originStmt, null, null);

View File

@ -598,7 +598,7 @@ public class StmtExecutor {
}
List<StatementBase> statements;
try {
statements = new NereidsParser().parseSQL(originStmt.originStmt);
statements = new NereidsParser().parseSQL(originStmt.originStmt, context.getSessionVariable());
} catch (Exception e) {
throw new ParseException("Nereids parse failed. " + e.getMessage());
}