[fix](Nerids) fix error when the view has lambda functions (#25067)

1. To ensure compatibility with the original optimizer, expose the non-lambda signature of highorder function externally.
2. fix some bugs in toSql function in the original optimizer
This commit is contained in:
谢健
2023-10-08 15:45:24 +08:00
committed by GitHub
parent 541f48a754
commit 3a45001447
11 changed files with 207 additions and 31 deletions

View File

@ -273,7 +273,14 @@ public class LambdaFunctionCallExpr extends FunctionCallExpr {
@Override
public String toSqlImpl() {
StringBuilder sb = new StringBuilder();
sb.append(getFnName().getFunction());
String fnName = getFnName().getFunction();
if (fn != null) {
// `array_last` will be replaced with `element_at` function after analysis.
// At this moment, using the name `array_last` would generate invalid SQL.
fnName = fn.getFunctionName().getFunction();
}
sb.append(fnName);
sb.append("(");
int childSize = children.size();
Expr lastExpr = getChild(childSize - 1);
@ -295,8 +302,12 @@ public class LambdaFunctionCallExpr extends FunctionCallExpr {
// and some functions is only implement as a normal array function;
// but also want use as lambda function, select array_sortby(x->x,['b','a','c']);
// so we convert to: array_sortby(array('b', 'a', 'c'), array_map(x -> `x`, array('b', 'a', 'c')))
if (lastIsLambdaExpr == false) {
sb.append(", ");
if (!lastIsLambdaExpr) {
if (childSize > 1) {
// some functions don't have lambda expr, so don't need to add ","
// such as array_exists(array_map(x->x>3, [1,2,3,6,34,3,11]))
sb.append(", ");
}
sb.append(lastExpr.toSql());
}
sb.append(")");

View File

@ -119,6 +119,7 @@ public interface ComputeSignature extends FunctionTrait, ImplicitCastInputTypes
/** use processor to process computeSignature */
static boolean processComplexType(DataType signatureType, DataType realType,
BiFunction<DataType, DataType, Boolean> processor) {
if (signatureType instanceof ArrayType && realType instanceof ArrayType) {
return processor.apply(((ArrayType) signatureType).getItemType(),
((ArrayType) realType).getItemType());

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
@ -49,11 +48,7 @@ public class ArrayCount extends ScalarFunction
* array_count(lambda, a1, ...) = array_count(array_map(lambda, a1, ...))
*/
public ArrayCount(Expression arg) {
super("array_count", new ArrayMap(arg));
if (!(arg instanceof Lambda)) {
throw new AnalysisException(
String.format("The 1st arg of %s must be lambda but is %s", getName(), arg));
}
super("array_count", arg instanceof Lambda ? new ArrayMap(arg) : arg);
}
/**

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
@ -48,11 +47,7 @@ public class ArrayExists extends ScalarFunction
* array_exists(lambda, a1, ...) = array_exists(array_map(lambda, a1, ...))
*/
public ArrayExists(Expression arg) {
super("array_exists", new ArrayMap(arg));
if (!(arg instanceof Lambda)) {
throw new AnalysisException(
String.format("The 1st arg of %s must be lambda but is %s", getName(), arg));
}
super("array_exists", arg instanceof Lambda ? new ArrayMap(arg) : arg);
}
/**

View File

@ -56,6 +56,10 @@ public class ArrayFilter extends ScalarFunction
}
}
public ArrayFilter(Expression arg1, Expression arg2) {
super("array_filter", arg1, arg2);
}
@Override
public ArrayFilter withChildren(List<Expression> children) {
return new ArrayFilter(children);

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
@ -49,11 +48,7 @@ public class ArrayFirstIndex extends ScalarFunction
* array_first_index(lambda, a1, ...) = array_first_index(array_map(lambda, a1, ...))
*/
public ArrayFirstIndex(Expression arg) {
super("array_first_index", new ArrayMap(arg));
if (!(arg instanceof Lambda)) {
throw new AnalysisException(
String.format("The 1st arg of %s must be lambda but is %s", getName(), arg));
}
super("array_first_index", arg instanceof Lambda ? new ArrayMap(arg) : arg);
}
/**

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
@ -49,11 +48,7 @@ public class ArrayLastIndex extends ScalarFunction
* array_last_index(lambda, a1, ...) = array_last_index(array_map(lambda, a1, ...))
*/
public ArrayLastIndex(Expression arg) {
super("array_last_index", new ArrayMap(arg));
if (!(arg instanceof Lambda)) {
throw new AnalysisException(
String.format("The 1st arg of %s must be lambda but is %s", getName(), arg));
}
super("array_last_index", arg instanceof Lambda ? new ArrayMap(arg) : arg);
}
/**

View File

@ -20,7 +20,6 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
@ -33,7 +32,7 @@ import java.util.List;
* ScalarFunction 'array_sortby'.
*/
public class ArraySortBy extends ScalarFunction
implements HighOrderFunction, PropagateNullable {
implements HighOrderFunction {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0).args(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX),
@ -56,6 +55,10 @@ public class ArraySortBy extends ScalarFunction
}
}
public ArraySortBy(Expression arg1, Expression arg2) {
super("array_sortby", arg1, arg2);
}
@Override
public ArraySortBy withChildren(List<Expression> children) {
return new ArraySortBy(children);
@ -70,4 +73,9 @@ public class ArraySortBy extends ScalarFunction
public List<FunctionSignature> getImplSignature() {
return SIGNATURES;
}
@Override
public boolean nullable() {
return child(0).nullable();
}
}

View File

@ -146,9 +146,10 @@ public class TypeCoercionUtils {
* Return Optional.empty() if we cannot do implicit cast.
*/
public static Optional<DataType> implicitCast(DataType input, DataType expected) {
if (input instanceof ArrayType && expected instanceof ArrayType) {
if ((input instanceof ArrayType || input instanceof NullType) && expected instanceof ArrayType) {
Optional<DataType> itemType = implicitCast(
((ArrayType) input).getItemType(), ((ArrayType) expected).getItemType());
input instanceof ArrayType ? ((ArrayType) input).getItemType() : input,
((ArrayType) expected).getItemType());
return itemType.map(ArrayType::of);
} else if (input instanceof MapType && expected instanceof MapType) {
Optional<DataType> keyType = implicitCast(