[feature-wip](nereids) Support some spark-sql built-in functions when set dialect=spark_sql (#28531)
This commit is contained in:
@ -23,9 +23,11 @@ import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Expression placeHolder, the expression in PlaceHolderExpression will be collected by
|
||||
@ -33,15 +35,25 @@ import java.util.Objects;
|
||||
* @see PlaceholderCollector
|
||||
*/
|
||||
public class PlaceholderExpression extends Expression implements AlwaysNotNullable {
|
||||
private final Class<? extends Expression> delegateClazz;
|
||||
|
||||
private final ImmutableSet<Class<? extends Expression>> delegateClazzSet;
|
||||
/**
|
||||
* 1 based
|
||||
* the index starts from 1 and identifies which entry of sourceFnTransformedArguments this placeholderExpression stands for;
|
||||
* this placeholderExpression will be replaced later
|
||||
*/
|
||||
private final int position;
|
||||
|
||||
public PlaceholderExpression(List<Expression> children, Class<? extends Expression> delegateClazz, int position) {
|
||||
super(children);
|
||||
this.delegateClazz = Objects.requireNonNull(delegateClazz, "delegateClazz should not be null");
|
||||
this.delegateClazzSet = ImmutableSet.of(
|
||||
Objects.requireNonNull(delegateClazz, "delegateClazz should not be null"));
|
||||
this.position = position;
|
||||
}
|
||||
|
||||
/**
 * Creates a placeholder that accepts any one of several expression classes.
 *
 * @param children child expressions passed to the super constructor
 * @param delegateClazzSet the expression classes this placeholder accepts; copied into an immutable set
 * @param position 1-based position of this placeholder
 * @throws NullPointerException if {@code delegateClazzSet} is null
 */
public PlaceholderExpression(List<Expression> children,
        Set<Class<? extends Expression>> delegateClazzSet, int position) {
    super(children);
    // Fail with an explicit message, consistent with the single-class constructor,
    // instead of relying on the unlabelled NPE raised inside ImmutableSet.copyOf.
    this.delegateClazzSet = ImmutableSet.copyOf(
            Objects.requireNonNull(delegateClazzSet, "delegateClazzSet should not be null"));
    this.position = position;
}
|
||||
|
||||
@ -49,13 +61,18 @@ public class PlaceholderExpression extends Expression implements AlwaysNotNullab
|
||||
return new PlaceholderExpression(ImmutableList.of(), delegateClazz, position);
|
||||
}
|
||||
|
||||
public static PlaceholderExpression of(Set<Class<? extends Expression>> delegateClazzSet, int position) {
|
||||
return new PlaceholderExpression(ImmutableList.of(), delegateClazzSet, position);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
visitor.visitPlaceholderExpression(this, context);
|
||||
return visitor.visit(this, context);
|
||||
}
|
||||
|
||||
public Class<? extends Expression> getDelegateClazz() {
|
||||
return delegateClazz;
|
||||
public Set<Class<? extends Expression>> getDelegateClazzSet() {
|
||||
return delegateClazzSet;
|
||||
}
|
||||
|
||||
public int getPosition() {
|
||||
@ -74,11 +91,11 @@ public class PlaceholderExpression extends Expression implements AlwaysNotNullab
|
||||
return false;
|
||||
}
|
||||
PlaceholderExpression that = (PlaceholderExpression) o;
|
||||
return position == that.position && Objects.equals(delegateClazz, that.delegateClazz);
|
||||
return position == that.position && Objects.equals(delegateClazzSet, that.delegateClazzSet);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(super.hashCode(), delegateClazz, position);
|
||||
return Objects.hash(super.hashCode(), delegateClazzSet, position);
|
||||
}
|
||||
}
|
||||
|
||||
@ -77,17 +77,15 @@ public abstract class AbstractFnCallTransformers {
|
||||
|
||||
protected void doRegister(
|
||||
String sourceFnNme,
|
||||
int sourceFnArgumentsNum,
|
||||
String targetFnName,
|
||||
List<? extends Expression> targetFnArguments,
|
||||
boolean variableArgument) {
|
||||
List<? extends Expression> targetFnArguments) {
|
||||
|
||||
List<Expression> castedTargetFnArguments = targetFnArguments
|
||||
.stream()
|
||||
.map(each -> (Expression) each)
|
||||
.collect(Collectors.toList());
|
||||
transformerBuilder.put(sourceFnNme, new CommonFnCallTransformer(new UnboundFunction(
|
||||
targetFnName, castedTargetFnArguments), variableArgument, sourceFnArgumentsNum));
|
||||
targetFnName, castedTargetFnArguments)));
|
||||
}
|
||||
|
||||
protected void doRegister(
|
||||
|
||||
@ -23,56 +23,76 @@ import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Function;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Trino function transformer
|
||||
* Common function transformer,
|
||||
* it can transform functions whose target arguments have the same number and types as those of the source function,
|
||||
* or source function is a variable-arguments function.
|
||||
*/
|
||||
public class CommonFnCallTransformer extends AbstractFnCallTransformer {
|
||||
private final UnboundFunction targetFunction;
|
||||
private final List<PlaceholderExpression> targetArguments;
|
||||
private final boolean variableArgument;
|
||||
private final int sourceArgumentsNum;
|
||||
|
||||
// true means the arguments of this function are dynamic, for example:
|
||||
// - named_struct('f1', 1, 'f2', 'a', 'f3', "abc")
|
||||
// - struct(1, 'a', 'abc');
|
||||
private final boolean variableArguments;
|
||||
|
||||
/**
|
||||
* Trino function transformer, mostly this handle common function.
|
||||
* Common function transformer, which mostly handles common functions.
|
||||
*/
|
||||
public CommonFnCallTransformer(UnboundFunction targetFunction,
|
||||
boolean variableArgument,
|
||||
int sourceArgumentsNum) {
|
||||
public CommonFnCallTransformer(UnboundFunction targetFunction, boolean variableArguments) {
|
||||
this.targetFunction = targetFunction;
|
||||
this.variableArgument = variableArgument;
|
||||
this.sourceArgumentsNum = sourceArgumentsNum;
|
||||
PlaceholderCollector placeHolderCollector = new PlaceholderCollector(variableArgument);
|
||||
PlaceholderCollector placeHolderCollector = new PlaceholderCollector();
|
||||
placeHolderCollector.visit(targetFunction, null);
|
||||
this.targetArguments = placeHolderCollector.getPlaceholderExpressions();
|
||||
this.variableArguments = variableArguments;
|
||||
}
|
||||
|
||||
/**
 * Creates a transformer for a target function with a fixed argument list.
 * Equivalent to calling the two-argument constructor with {@code variableArguments = false}.
 *
 * @param targetFunction the target function whose placeholder arguments are collected
 */
public CommonFnCallTransformer(UnboundFunction targetFunction) {
    // Delegate instead of repeating the placeholder-collection logic of the
    // (targetFunction, variableArguments) constructor.
    this(targetFunction, false);
}
|
||||
|
||||
@Override
|
||||
protected boolean check(String sourceFnName,
|
||||
List<Expression> sourceFnTransformedArguments,
|
||||
ParserContext context) {
|
||||
// if variableArguments=true, we can not recognize if the type of all arguments is valid or not,
|
||||
// because:
|
||||
// 1. the argument size is not sure
|
||||
// 2. there are some functions which can accept different types of arguments,
|
||||
// for example: struct(1, 'a', 'abc')
|
||||
// so just return true here.
|
||||
if (variableArguments) {
|
||||
return true;
|
||||
}
|
||||
List<Class<? extends Expression>> sourceFnTransformedArgClazz = sourceFnTransformedArguments.stream()
|
||||
.map(Expression::getClass)
|
||||
.collect(Collectors.toList());
|
||||
if (variableArgument) {
|
||||
if (targetArguments.isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
Class<? extends Expression> targetArgumentClazz = targetArguments.get(0).getDelegateClazz();
|
||||
for (Expression argument : sourceFnTransformedArguments) {
|
||||
if (!targetArgumentClazz.isAssignableFrom(argument.getClass())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (sourceFnTransformedArguments.size() != sourceArgumentsNum) {
|
||||
if (sourceFnTransformedArguments.size() != targetArguments.size()) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < targetArguments.size(); i++) {
|
||||
if (!targetArguments.get(i).getDelegateClazz().isAssignableFrom(sourceFnTransformedArgClazz.get(i))) {
|
||||
for (PlaceholderExpression targetArgument : targetArguments) {
|
||||
// replace the arguments of target function by the position of target argument
|
||||
int position = targetArgument.getPosition();
|
||||
Class<? extends Expression> sourceArgClazz = sourceFnTransformedArgClazz.get(position - 1);
|
||||
boolean valid = false;
|
||||
for (Class<? extends Expression> targetArgClazz : targetArgument.getDelegateClazzSet()) {
|
||||
if (targetArgClazz.isAssignableFrom(sourceArgClazz)) {
|
||||
valid = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!valid) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -83,7 +103,16 @@ public class CommonFnCallTransformer extends AbstractFnCallTransformer {
|
||||
protected Function transform(String sourceFnName,
|
||||
List<Expression> sourceFnTransformedArguments,
|
||||
ParserContext context) {
|
||||
return targetFunction.withChildren(sourceFnTransformedArguments);
|
||||
if (variableArguments) {
|
||||
// adjusting the order of arguments is not supported when variableArguments=true
|
||||
return targetFunction.withChildren(sourceFnTransformedArguments);
|
||||
}
|
||||
List<Expression> sourceFnTransformedArgumentsInorder = Lists.newArrayList();
|
||||
for (PlaceholderExpression placeholderExpression : targetArguments) {
|
||||
Expression expression = sourceFnTransformedArguments.get(placeholderExpression.getPosition() - 1);
|
||||
sourceFnTransformedArgumentsInorder.add(expression);
|
||||
}
|
||||
return targetFunction.withChildren(sourceFnTransformedArgumentsInorder);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -93,20 +122,12 @@ public class CommonFnCallTransformer extends AbstractFnCallTransformer {
|
||||
public static final class PlaceholderCollector extends DefaultExpressionVisitor<Void, Void> {
|
||||
|
||||
private final List<PlaceholderExpression> placeholderExpressions = new ArrayList<>();
|
||||
private final boolean variableArgument;
|
||||
|
||||
public PlaceholderCollector(boolean variableArgument) {
|
||||
this.variableArgument = variableArgument;
|
||||
}
|
||||
public PlaceholderCollector() {}
|
||||
|
||||
@Override
|
||||
public Void visitPlaceholderExpression(PlaceholderExpression placeholderExpression, Void context) {
|
||||
|
||||
if (variableArgument) {
|
||||
placeholderExpressions.add(placeholderExpression);
|
||||
return null;
|
||||
}
|
||||
placeholderExpressions.set(placeholderExpression.getPosition() - 1, placeholderExpression);
|
||||
placeholderExpressions.add(placeholderExpression);
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@ -15,14 +15,12 @@
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.parser.trino;
|
||||
|
||||
import org.apache.doris.nereids.parser.AbstractFnCallTransformer;
|
||||
package org.apache.doris.nereids.parser;
|
||||
|
||||
/**
|
||||
* Trino complex function transformer
|
||||
*/
|
||||
public abstract class ComplexTrinoFnCallTransformer extends AbstractFnCallTransformer {
|
||||
public abstract class ComplexFnCallTransformer extends AbstractFnCallTransformer {
|
||||
|
||||
protected abstract String getSourceFnName();
|
||||
}
|
||||
@ -0,0 +1,81 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.parser.spark;
|
||||
|
||||
import org.apache.doris.nereids.analyzer.UnboundFunction;
|
||||
import org.apache.doris.nereids.parser.ComplexFnCallTransformer;
|
||||
import org.apache.doris.nereids.parser.ParserContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Function;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* DateTrunc complex function transformer
|
||||
*/
|
||||
public class DateTruncFnCallTransformer extends ComplexFnCallTransformer {
|
||||
|
||||
// reference: https://spark.apache.org/docs/latest/api/sql/index.html#trunc
|
||||
// spark-sql support YEAR/YYYY/YY for year, support MONTH/MON/MM for month
|
||||
private static final ImmutableSet<String> YEAR = ImmutableSet.<String>builder()
|
||||
.add("YEAR")
|
||||
.add("YYYY")
|
||||
.add("YY")
|
||||
.build();
|
||||
|
||||
private static final ImmutableSet<String> MONTH = ImmutableSet.<String>builder()
|
||||
.add("MONTH")
|
||||
.add("MON")
|
||||
.add("MM")
|
||||
.build();
|
||||
|
||||
@Override
|
||||
public String getSourceFnName() {
|
||||
return "trunc";
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean check(String sourceFnName, List<Expression> sourceFnTransformedArguments,
|
||||
ParserContext context) {
|
||||
return getSourceFnName().equalsIgnoreCase(sourceFnName) && (sourceFnTransformedArguments.size() == 2);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Function transform(String sourceFnName, List<Expression> sourceFnTransformedArguments,
|
||||
ParserContext context) {
|
||||
VarcharLiteral fmtLiteral = (VarcharLiteral) sourceFnTransformedArguments.get(1);
|
||||
if (YEAR.contains(fmtLiteral.getValue().toUpperCase())) {
|
||||
return new UnboundFunction(
|
||||
"date_trunc",
|
||||
ImmutableList.of(sourceFnTransformedArguments.get(0), new VarcharLiteral("YEAR")));
|
||||
}
|
||||
if (MONTH.contains(fmtLiteral.getValue().toUpperCase())) {
|
||||
return new UnboundFunction(
|
||||
"date_trunc",
|
||||
ImmutableList.of(sourceFnTransformedArguments.get(0), new VarcharLiteral("MONTH")));
|
||||
}
|
||||
|
||||
return new UnboundFunction(
|
||||
"date_trunc",
|
||||
ImmutableList.of(sourceFnTransformedArguments.get(0), sourceFnTransformedArguments.get(1)));
|
||||
}
|
||||
}
|
||||
@ -18,15 +18,13 @@
|
||||
package org.apache.doris.nereids.parser.spark;
|
||||
|
||||
import org.apache.doris.nereids.analyzer.PlaceholderExpression;
|
||||
import org.apache.doris.nereids.parser.AbstractFnCallTransformer;
|
||||
import org.apache.doris.nereids.parser.AbstractFnCallTransformers;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
/**
|
||||
* The builder and factory for spark-sql 3.x {@link AbstractFnCallTransformer},
|
||||
* and supply transform facade ability.
|
||||
* The builder and factory for spark-sql 3.x FnCallTransformers, supply transform facade ability.
|
||||
*/
|
||||
public class SparkSql3FnCallTransformers extends AbstractFnCallTransformers {
|
||||
|
||||
@ -35,32 +33,56 @@ public class SparkSql3FnCallTransformers extends AbstractFnCallTransformers {
|
||||
|
||||
@Override
|
||||
protected void registerTransformers() {
|
||||
doRegister("get_json_object", 2, "json_extract",
|
||||
Lists.newArrayList(
|
||||
PlaceholderExpression.of(Expression.class, 1),
|
||||
PlaceholderExpression.of(Expression.class, 2)), true);
|
||||
|
||||
doRegister("get_json_object", 2, "json_extract",
|
||||
Lists.newArrayList(
|
||||
PlaceholderExpression.of(Expression.class, 1),
|
||||
PlaceholderExpression.of(Expression.class, 2)), false);
|
||||
|
||||
doRegister("split", 2, "split_by_string",
|
||||
Lists.newArrayList(
|
||||
PlaceholderExpression.of(Expression.class, 1),
|
||||
PlaceholderExpression.of(Expression.class, 2)), true);
|
||||
doRegister("split", 2, "split_by_string",
|
||||
Lists.newArrayList(
|
||||
PlaceholderExpression.of(Expression.class, 1),
|
||||
PlaceholderExpression.of(Expression.class, 2)), false);
|
||||
// register json functions
|
||||
registerJsonFunctionTransformers();
|
||||
// register string functions
|
||||
registerStringFunctionTransformers();
|
||||
// register date functions
|
||||
registerDateFunctionTransformers();
|
||||
// register numeric functions
|
||||
registerNumericFunctionTransformers();
|
||||
// TODO: add other function transformer
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void registerComplexTransformers() {
|
||||
DateTruncFnCallTransformer dateTruncFnCallTransformer = new DateTruncFnCallTransformer();
|
||||
doRegister(dateTruncFnCallTransformer.getSourceFnName(), dateTruncFnCallTransformer);
|
||||
// TODO: add other complex function transformer
|
||||
}
|
||||
|
||||
private void registerJsonFunctionTransformers() {
|
||||
doRegister("get_json_object", "json_extract",
|
||||
Lists.newArrayList(
|
||||
PlaceholderExpression.of(Expression.class, 1),
|
||||
PlaceholderExpression.of(Expression.class, 2)));
|
||||
}
|
||||
|
||||
private void registerStringFunctionTransformers() {
|
||||
doRegister("split", "split_by_string",
|
||||
Lists.newArrayList(
|
||||
PlaceholderExpression.of(Expression.class, 1),
|
||||
PlaceholderExpression.of(Expression.class, 2)));
|
||||
}
|
||||
|
||||
private void registerDateFunctionTransformers() {
|
||||
// spark-sql support to_date(date_str, fmt) function but doris only support to_date(date_str)
|
||||
// here try to compat with this situation by using str_to_date(str, fmt),
|
||||
// this function support the following three formats which can handle the mainly situations:
|
||||
// 1. yyyyMMdd
|
||||
// 2. yyyy-MM-dd
|
||||
// 3. yyyy-MM-dd HH:mm:ss
|
||||
doRegister("to_date", "str_to_date",
|
||||
Lists.newArrayList(
|
||||
PlaceholderExpression.of(Expression.class, 1),
|
||||
PlaceholderExpression.of(Expression.class, 2)));
|
||||
}
|
||||
|
||||
private void registerNumericFunctionTransformers() {
|
||||
doRegister("mean", "avg",
|
||||
Lists.newArrayList(PlaceholderExpression.of(Expression.class, 1)));
|
||||
}
|
||||
|
||||
// Nested holder whose only member is a SparkSql3FnCallTransformers instance stored in a
// private static final field.
// NOTE(review): presumably read by an accessor elsewhere in the enclosing class - confirm.
static class SingletonHolder {
|
||||
private static final SparkSql3FnCallTransformers INSTANCE = new SparkSql3FnCallTransformers();
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.parser.trino;
|
||||
|
||||
import org.apache.doris.nereids.analyzer.UnboundFunction;
|
||||
import org.apache.doris.nereids.parser.ComplexFnCallTransformer;
|
||||
import org.apache.doris.nereids.parser.ParserContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Function;
|
||||
@ -30,7 +31,7 @@ import java.util.List;
|
||||
/**
|
||||
* DateDiff complex function transformer
|
||||
*/
|
||||
public class DateDiffFnCallTransformer extends ComplexTrinoFnCallTransformer {
|
||||
public class DateDiffFnCallTransformer extends ComplexFnCallTransformer {
|
||||
|
||||
private static final String SECOND = "second";
|
||||
private static final String HOUR = "hour";
|
||||
@ -45,15 +46,12 @@ public class DateDiffFnCallTransformer extends ComplexTrinoFnCallTransformer {
|
||||
@Override
|
||||
protected boolean check(String sourceFnName, List<Expression> sourceFnTransformedArguments,
|
||||
ParserContext context) {
|
||||
return getSourceFnName().equalsIgnoreCase(sourceFnName);
|
||||
return getSourceFnName().equalsIgnoreCase(sourceFnName) && (sourceFnTransformedArguments.size() == 3);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Function transform(String sourceFnName, List<Expression> sourceFnTransformedArguments,
|
||||
ParserContext context) {
|
||||
if (sourceFnTransformedArguments.size() != 3) {
|
||||
return null;
|
||||
}
|
||||
VarcharLiteral diffGranularity = (VarcharLiteral) sourceFnTransformedArguments.get(0);
|
||||
if (SECOND.equals(diffGranularity.getValue())) {
|
||||
return new UnboundFunction(
|
||||
|
||||
@ -18,14 +18,13 @@
|
||||
package org.apache.doris.nereids.parser.trino;
|
||||
|
||||
import org.apache.doris.nereids.analyzer.PlaceholderExpression;
|
||||
import org.apache.doris.nereids.parser.AbstractFnCallTransformer;
|
||||
import org.apache.doris.nereids.parser.AbstractFnCallTransformers;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
/**
|
||||
* The builder and factory for trino {@link AbstractFnCallTransformer},
|
||||
* The builder and factory for trino function call transformers,
|
||||
* and supply transform facade ability.
|
||||
*/
|
||||
public class TrinoFnCallTransformers extends AbstractFnCallTransformers {
|
||||
@ -47,8 +46,8 @@ public class TrinoFnCallTransformers extends AbstractFnCallTransformers {
|
||||
}
|
||||
|
||||
protected void registerStringFunctionTransformer() {
|
||||
doRegister("codepoint", 1, "ascii",
|
||||
Lists.newArrayList(PlaceholderExpression.of(Expression.class, 1)), false);
|
||||
doRegister("codepoint", "ascii",
|
||||
Lists.newArrayList(PlaceholderExpression.of(Expression.class, 1)));
|
||||
// TODO: add other string function transformer
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user