[feature-wip](nereids) Support some spark-sql built-in functions when set dialect=spark_sql (#28531)

This commit is contained in:
Xiangyu Wang
2023-12-30 00:10:35 +08:00
committed by GitHub
parent 445f72b395
commit 8407490053
9 changed files with 266 additions and 98 deletions

View File

@ -23,9 +23,11 @@ import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
/**
* Expression placeHolder, the expression in PlaceHolderExpression will be collected by
@ -33,15 +35,25 @@ import java.util.Objects;
* @see PlaceholderCollector
*/
public class PlaceholderExpression extends Expression implements AlwaysNotNullable {
private final Class<? extends Expression> delegateClazz;
private final ImmutableSet<Class<? extends Expression>> delegateClazzSet;
/**
* 1 based
* start from 1, set the index of this placeholderExpression in sourceFnTransformedArguments
* this placeholderExpression will be replaced later
*/
private final int position;
public PlaceholderExpression(List<Expression> children, Class<? extends Expression> delegateClazz, int position) {
super(children);
this.delegateClazz = Objects.requireNonNull(delegateClazz, "delegateClazz should not be null");
this.delegateClazzSet = ImmutableSet.of(
Objects.requireNonNull(delegateClazz, "delegateClazz should not be null"));
this.position = position;
}
/**
 * Creates a placeholder that records several candidate expression classes.
 *
 * @param children child expressions handed to the super constructor
 * @param delegateClazzSet expression classes recorded for this placeholder; copied into an immutable set
 * @param position 1-based position stored for this placeholder
 * @throws NullPointerException if {@code delegateClazzSet} is null
 */
public PlaceholderExpression(List<Expression> children,
        Set<Class<? extends Expression>> delegateClazzSet, int position) {
    super(children);
    // fail with an explicit message, matching the null check of the single-class constructor
    this.delegateClazzSet = ImmutableSet.copyOf(
            Objects.requireNonNull(delegateClazzSet, "delegateClazzSet should not be null"));
    this.position = position;
}
@ -49,13 +61,18 @@ public class PlaceholderExpression extends Expression implements AlwaysNotNullab
return new PlaceholderExpression(ImmutableList.of(), delegateClazz, position);
}
/**
 * Static factory for a placeholder without children.
 *
 * @param delegateClazzSet expression classes recorded for the placeholder
 * @param position 1-based position stored for the placeholder
 * @return a new placeholder built from an empty child list and the given arguments
 */
public static PlaceholderExpression of(Set<Class<? extends Expression>> delegateClazzSet, int position) {
    List<Expression> noChildren = ImmutableList.of();
    return new PlaceholderExpression(noChildren, delegateClazzSet, position);
}
/**
 * Dispatches this placeholder to the given visitor.
 *
 * <p>The placeholder-specific hook is invoked first and its result is discarded; the value
 * returned to the caller is the one produced by the generic {@code visit} call.
 *
 * @param visitor visitor to dispatch to
 * @param context context object passed through to both visitor calls
 * @return the result of {@code visitor.visit(this, context)}
 */
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
// invoked only for its effect on the visitor; the return value is intentionally ignored
visitor.visitPlaceholderExpression(this, context);
// NOTE(review): both calls receive the same node, so a visitor may see it twice — confirm this is intended
return visitor.visit(this, context);
}
public Class<? extends Expression> getDelegateClazz() {
return delegateClazz;
public Set<Class<? extends Expression>> getDelegateClazzSet() {
return delegateClazzSet;
}
public int getPosition() {
@ -74,11 +91,11 @@ public class PlaceholderExpression extends Expression implements AlwaysNotNullab
return false;
}
PlaceholderExpression that = (PlaceholderExpression) o;
return position == that.position && Objects.equals(delegateClazz, that.delegateClazz);
return position == that.position && Objects.equals(delegateClazzSet, that.delegateClazzSet);
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), delegateClazz, position);
return Objects.hash(super.hashCode(), delegateClazzSet, position);
}
}

View File

@ -77,17 +77,15 @@ public abstract class AbstractFnCallTransformers {
protected void doRegister(
String sourceFnNme,
int sourceFnArgumentsNum,
String targetFnName,
List<? extends Expression> targetFnArguments,
boolean variableArgument) {
List<? extends Expression> targetFnArguments) {
List<Expression> castedTargetFnArguments = targetFnArguments
.stream()
.map(each -> (Expression) each)
.collect(Collectors.toList());
transformerBuilder.put(sourceFnNme, new CommonFnCallTransformer(new UnboundFunction(
targetFnName, castedTargetFnArguments), variableArgument, sourceFnArgumentsNum));
targetFnName, castedTargetFnArguments)));
}
protected void doRegister(

View File

@ -23,56 +23,76 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
/**
* Trino function transformer
* Common function transformer,
* can transform functions which the size and type of target arguments are both the same with the source function,
* or source function is a variable-arguments function.
*/
public class CommonFnCallTransformer extends AbstractFnCallTransformer {
private final UnboundFunction targetFunction;
private final List<PlaceholderExpression> targetArguments;
private final boolean variableArgument;
private final int sourceArgumentsNum;
// true means the arguments of this function is dynamic, for example:
// - named_struct('f1', 1, 'f2', 'a', 'f3', "abc")
// - struct(1, 'a', 'abc');
private final boolean variableArguments;
/**
* Trino function transformer, mostly this handle common function.
* Common function transformer, mostly this handle common function.
*/
public CommonFnCallTransformer(UnboundFunction targetFunction,
boolean variableArgument,
int sourceArgumentsNum) {
public CommonFnCallTransformer(UnboundFunction targetFunction, boolean variableArguments) {
this.targetFunction = targetFunction;
this.variableArgument = variableArgument;
this.sourceArgumentsNum = sourceArgumentsNum;
PlaceholderCollector placeHolderCollector = new PlaceholderCollector(variableArgument);
PlaceholderCollector placeHolderCollector = new PlaceholderCollector();
placeHolderCollector.visit(targetFunction, null);
this.targetArguments = placeHolderCollector.getPlaceholderExpressions();
this.variableArguments = variableArguments;
}
/**
 * Creates a transformer whose {@code variableArguments} flag is {@code false}.
 *
 * @param targetFunction target function template; the placeholder expressions found in it
 *                       become the target arguments of this transformer
 */
public CommonFnCallTransformer(UnboundFunction targetFunction) {
    // this overload never treats the source call as a variable-arguments function
    this.variableArguments = false;
    this.targetFunction = targetFunction;
    // walk the template once and keep the placeholders the collector gathered
    PlaceholderCollector collector = new PlaceholderCollector();
    collector.visit(targetFunction, null);
    this.targetArguments = collector.getPlaceholderExpressions();
}
@Override
protected boolean check(String sourceFnName,
List<Expression> sourceFnTransformedArguments,
ParserContext context) {
// if variableArguments=true, we can not recognize if the type of all arguments is valid or not,
// because:
// 1. the argument size is not sure
// 2. there are some functions which can accept different types of arguments,
// for example: struct(1, 'a', 'abc')
// so just return true here.
if (variableArguments) {
return true;
}
List<Class<? extends Expression>> sourceFnTransformedArgClazz = sourceFnTransformedArguments.stream()
.map(Expression::getClass)
.collect(Collectors.toList());
if (variableArgument) {
if (targetArguments.isEmpty()) {
return false;
}
Class<? extends Expression> targetArgumentClazz = targetArguments.get(0).getDelegateClazz();
for (Expression argument : sourceFnTransformedArguments) {
if (!targetArgumentClazz.isAssignableFrom(argument.getClass())) {
return false;
}
}
}
if (sourceFnTransformedArguments.size() != sourceArgumentsNum) {
if (sourceFnTransformedArguments.size() != targetArguments.size()) {
return false;
}
for (int i = 0; i < targetArguments.size(); i++) {
if (!targetArguments.get(i).getDelegateClazz().isAssignableFrom(sourceFnTransformedArgClazz.get(i))) {
for (PlaceholderExpression targetArgument : targetArguments) {
// replace the arguments of target function by the position of target argument
int position = targetArgument.getPosition();
Class<? extends Expression> sourceArgClazz = sourceFnTransformedArgClazz.get(position - 1);
boolean valid = false;
for (Class<? extends Expression> targetArgClazz : targetArgument.getDelegateClazzSet()) {
if (targetArgClazz.isAssignableFrom(sourceArgClazz)) {
valid = true;
break;
}
}
if (!valid) {
return false;
}
}
@ -83,7 +103,16 @@ public class CommonFnCallTransformer extends AbstractFnCallTransformer {
protected Function transform(String sourceFnName,
List<Expression> sourceFnTransformedArguments,
ParserContext context) {
return targetFunction.withChildren(sourceFnTransformedArguments);
if (variableArguments) {
// not support adjust the order of arguments when variableArguments=true
return targetFunction.withChildren(sourceFnTransformedArguments);
}
List<Expression> sourceFnTransformedArgumentsInorder = Lists.newArrayList();
for (PlaceholderExpression placeholderExpression : targetArguments) {
Expression expression = sourceFnTransformedArguments.get(placeholderExpression.getPosition() - 1);
sourceFnTransformedArgumentsInorder.add(expression);
}
return targetFunction.withChildren(sourceFnTransformedArgumentsInorder);
}
/**
@ -93,20 +122,12 @@ public class CommonFnCallTransformer extends AbstractFnCallTransformer {
public static final class PlaceholderCollector extends DefaultExpressionVisitor<Void, Void> {
private final List<PlaceholderExpression> placeholderExpressions = new ArrayList<>();
private final boolean variableArgument;
public PlaceholderCollector(boolean variableArgument) {
this.variableArgument = variableArgument;
}
public PlaceholderCollector() {}
@Override
public Void visitPlaceholderExpression(PlaceholderExpression placeholderExpression, Void context) {
if (variableArgument) {
placeholderExpressions.add(placeholderExpression);
return null;
}
placeholderExpressions.set(placeholderExpression.getPosition() - 1, placeholderExpression);
placeholderExpressions.add(placeholderExpression);
return null;
}

View File

@ -15,14 +15,12 @@
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.parser.trino;
import org.apache.doris.nereids.parser.AbstractFnCallTransformer;
package org.apache.doris.nereids.parser;
/**
* Complex function transformer, identified by the name of its source function
*/
public abstract class ComplexTrinoFnCallTransformer extends AbstractFnCallTransformer {
public abstract class ComplexFnCallTransformer extends AbstractFnCallTransformer {
protected abstract String getSourceFnName();
}

View File

@ -0,0 +1,81 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.parser.spark;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.parser.ComplexFnCallTransformer;
import org.apache.doris.nereids.parser.ParserContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.List;
/**
 * Complex function transformer for the spark-sql {@code trunc} function.
 *
 * <p>A call named {@code trunc} (case-insensitive) with exactly two arguments is rewritten to an
 * unbound {@code date_trunc} call. When the second argument is a varchar literal holding one of
 * the spark-sql aliases for the year unit (YEAR/YYYY/YY) or the month unit (MONTH/MON/MM), it is
 * replaced by the literal {@code "YEAR"} or {@code "MONTH"}; any other second argument is
 * forwarded unchanged.
 */
public class DateTruncFnCallTransformer extends ComplexFnCallTransformer {

    // name of the function every accepted call is rewritten to
    private static final String TARGET_FN_NAME = "date_trunc";

    // reference: https://spark.apache.org/docs/latest/api/sql/index.html#trunc
    // spark-sql support YEAR/YYYY/YY for year, support MONTH/MON/MM for month
    private static final ImmutableSet<String> YEAR = ImmutableSet.<String>builder()
            .add("YEAR")
            .add("YYYY")
            .add("YY")
            .build();

    private static final ImmutableSet<String> MONTH = ImmutableSet.<String>builder()
            .add("MONTH")
            .add("MON")
            .add("MM")
            .build();

    /**
     * @return the source function name handled by this transformer, {@code "trunc"}
     */
    @Override
    public String getSourceFnName() {
        return "trunc";
    }

    /**
     * Accepts a call only when its name matches {@link #getSourceFnName()} (ignoring case)
     * and it carries exactly two arguments.
     */
    @Override
    protected boolean check(String sourceFnName, List<Expression> sourceFnTransformedArguments,
            ParserContext context) {
        return getSourceFnName().equalsIgnoreCase(sourceFnName) && (sourceFnTransformedArguments.size() == 2);
    }

    /**
     * Rewrites the accepted call to {@code date_trunc}.
     *
     * @param sourceFnName name of the source function (not used by the rewrite itself)
     * @param sourceFnTransformedArguments the two arguments of the source call
     * @param context parser context (not used by the rewrite itself)
     * @return an unbound {@code date_trunc} call whose first argument is the first source argument
     */
    @Override
    protected Function transform(String sourceFnName, List<Expression> sourceFnTransformedArguments,
            ParserContext context) {
        Expression dateArgument = sourceFnTransformedArguments.get(0);
        Expression fmtArgument = sourceFnTransformedArguments.get(1);
        // Only a varchar literal can be inspected for unit aliases; any other expression
        // is forwarded as-is instead of failing with a ClassCastException.
        if (fmtArgument instanceof VarcharLiteral) {
            String fmt = ((VarcharLiteral) fmtArgument).getValue().toUpperCase();
            if (YEAR.contains(fmt)) {
                return dateTrunc(dateArgument, new VarcharLiteral("YEAR"));
            }
            if (MONTH.contains(fmt)) {
                return dateTrunc(dateArgument, new VarcharLiteral("MONTH"));
            }
        }
        return dateTrunc(dateArgument, fmtArgument);
    }

    // builds the unbound date_trunc call from the two given arguments
    private static UnboundFunction dateTrunc(Expression dateArgument, Expression unitArgument) {
        return new UnboundFunction(TARGET_FN_NAME, ImmutableList.of(dateArgument, unitArgument));
    }
}

View File

@ -18,15 +18,13 @@
package org.apache.doris.nereids.parser.spark;
import org.apache.doris.nereids.analyzer.PlaceholderExpression;
import org.apache.doris.nereids.parser.AbstractFnCallTransformer;
import org.apache.doris.nereids.parser.AbstractFnCallTransformers;
import org.apache.doris.nereids.trees.expressions.Expression;
import com.google.common.collect.Lists;
/**
* The builder and factory for spark-sql 3.x {@link AbstractFnCallTransformer},
* and supply transform facade ability.
* The builder and factory for spark-sql 3.x FnCallTransformers, supply transform facade ability.
*/
public class SparkSql3FnCallTransformers extends AbstractFnCallTransformers {
@ -35,32 +33,56 @@ public class SparkSql3FnCallTransformers extends AbstractFnCallTransformers {
@Override
protected void registerTransformers() {
doRegister("get_json_object", 2, "json_extract",
Lists.newArrayList(
PlaceholderExpression.of(Expression.class, 1),
PlaceholderExpression.of(Expression.class, 2)), true);
doRegister("get_json_object", 2, "json_extract",
Lists.newArrayList(
PlaceholderExpression.of(Expression.class, 1),
PlaceholderExpression.of(Expression.class, 2)), false);
doRegister("split", 2, "split_by_string",
Lists.newArrayList(
PlaceholderExpression.of(Expression.class, 1),
PlaceholderExpression.of(Expression.class, 2)), true);
doRegister("split", 2, "split_by_string",
Lists.newArrayList(
PlaceholderExpression.of(Expression.class, 1),
PlaceholderExpression.of(Expression.class, 2)), false);
// register json functions
registerJsonFunctionTransformers();
// register string functions
registerStringFunctionTransformers();
// register date functions
registerDateFunctionTransformers();
// register numeric functions
registerNumericFunctionTransformers();
// TODO: add other function transformer
}
/**
 * Registers the transformers that need custom rewriting logic,
 * each under the source function name it reports.
 */
@Override
protected void registerComplexTransformers() {
    DateTruncFnCallTransformer truncTransformer = new DateTruncFnCallTransformer();
    String truncSourceFnName = truncTransformer.getSourceFnName();
    doRegister(truncSourceFnName, truncTransformer);
    // TODO: add other complex function transformer
}
/**
 * Registers the json function mappings: {@code get_json_object} is mapped to
 * {@code json_extract}, keeping both arguments in their original positions.
 */
private void registerJsonFunctionTransformers() {
    PlaceholderExpression firstArg = PlaceholderExpression.of(Expression.class, 1);
    PlaceholderExpression secondArg = PlaceholderExpression.of(Expression.class, 2);
    doRegister("get_json_object", "json_extract", Lists.newArrayList(firstArg, secondArg));
}
/**
 * Registers the string function mappings: {@code split} is mapped to
 * {@code split_by_string}, keeping both arguments in their original positions.
 */
private void registerStringFunctionTransformers() {
    PlaceholderExpression firstArg = PlaceholderExpression.of(Expression.class, 1);
    PlaceholderExpression secondArg = PlaceholderExpression.of(Expression.class, 2);
    doRegister("split", "split_by_string", Lists.newArrayList(firstArg, secondArg));
}
/**
 * Registers the date function mappings: the two-argument {@code to_date} is mapped to
 * {@code str_to_date}, keeping both arguments in their original positions.
 */
private void registerDateFunctionTransformers() {
    // spark-sql offers to_date(date_str, fmt) while doris only offers to_date(date_str);
    // str_to_date(str, fmt) is used as the compatible replacement. It supports the
    // following three formats, which cover the main situations:
    // 1. yyyyMMdd
    // 2. yyyy-MM-dd
    // 3. yyyy-MM-dd HH:mm:ss
    PlaceholderExpression firstArg = PlaceholderExpression.of(Expression.class, 1);
    PlaceholderExpression secondArg = PlaceholderExpression.of(Expression.class, 2);
    doRegister("to_date", "str_to_date", Lists.newArrayList(firstArg, secondArg));
}
/**
 * Registers the numeric function mappings: the one-argument {@code mean} is mapped to {@code avg}.
 */
private void registerNumericFunctionTransformers() {
    PlaceholderExpression onlyArg = PlaceholderExpression.of(Expression.class, 1);
    doRegister("mean", "avg", Lists.newArrayList(onlyArg));
}
/**
 * Holder whose static field keeps one {@link SparkSql3FnCallTransformers} instance,
 * constructed when this nested class is initialized.
 * NOTE(review): presumably read through an accessor outside this view — confirm.
 */
static class SingletonHolder {
private static final SparkSql3FnCallTransformers INSTANCE = new SparkSql3FnCallTransformers();
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.parser.trino;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.parser.ComplexFnCallTransformer;
import org.apache.doris.nereids.parser.ParserContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
@ -30,7 +31,7 @@ import java.util.List;
/**
* DateDiff complex function transformer
*/
public class DateDiffFnCallTransformer extends ComplexTrinoFnCallTransformer {
public class DateDiffFnCallTransformer extends ComplexFnCallTransformer {
private static final String SECOND = "second";
private static final String HOUR = "hour";
@ -45,15 +46,12 @@ public class DateDiffFnCallTransformer extends ComplexTrinoFnCallTransformer {
@Override
protected boolean check(String sourceFnName, List<Expression> sourceFnTransformedArguments,
ParserContext context) {
return getSourceFnName().equalsIgnoreCase(sourceFnName);
return getSourceFnName().equalsIgnoreCase(sourceFnName) && (sourceFnTransformedArguments.size() == 3);
}
@Override
protected Function transform(String sourceFnName, List<Expression> sourceFnTransformedArguments,
ParserContext context) {
if (sourceFnTransformedArguments.size() != 3) {
return null;
}
VarcharLiteral diffGranularity = (VarcharLiteral) sourceFnTransformedArguments.get(0);
if (SECOND.equals(diffGranularity.getValue())) {
return new UnboundFunction(

View File

@ -18,14 +18,13 @@
package org.apache.doris.nereids.parser.trino;
import org.apache.doris.nereids.analyzer.PlaceholderExpression;
import org.apache.doris.nereids.parser.AbstractFnCallTransformer;
import org.apache.doris.nereids.parser.AbstractFnCallTransformers;
import org.apache.doris.nereids.trees.expressions.Expression;
import com.google.common.collect.Lists;
/**
* The builder and factory for trino {@link AbstractFnCallTransformer},
* The builder and factory for trino function call transformers,
* and supply transform facade ability.
*/
public class TrinoFnCallTransformers extends AbstractFnCallTransformers {
@ -47,8 +46,8 @@ public class TrinoFnCallTransformers extends AbstractFnCallTransformers {
}
protected void registerStringFunctionTransformer() {
doRegister("codepoint", 1, "ascii",
Lists.newArrayList(PlaceholderExpression.of(Expression.class, 1)), false);
doRegister("codepoint", "ascii",
Lists.newArrayList(PlaceholderExpression.of(Expression.class, 1)));
// TODO: add other string function transformer
}