From b91bce8a62e3bb4e4653b4c6fa7aa57042899305 Mon Sep 17 00:00:00 2001 From: morrySnow <101034200+morrySnow@users.noreply.github.com> Date: Wed, 11 Oct 2023 10:35:06 +0800 Subject: [PATCH] [feature](Nereids) add array distance functions (#25196) - l1_distance - l2_distance - cosine_distance - inner_product --- .../doris/catalog/BuiltinScalarFunctions.java | 8 + .../functions/scalar/CosineDistance.java | 89 ++ .../functions/scalar/InnerProduct.java | 89 ++ .../functions/scalar/L1Distance.java | 89 ++ .../functions/scalar/L2Distance.java | 89 ++ .../visitor/ScalarFunctionVisitor.java | 52 +- .../scalar_function/Array.out | 812 ++++++++++++++++++ .../scalar_function/Array.groovy | 64 ++ 8 files changed, 1276 insertions(+), 16 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CosineDistance.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InnerProduct.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L1Distance.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L2Distance.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java index 789b7f760e..05d9671e1e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java @@ -108,6 +108,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Conv; import org.apache.doris.nereids.trees.expressions.functions.scalar.ConvertTo; import org.apache.doris.nereids.trees.expressions.functions.scalar.ConvertTz; import org.apache.doris.nereids.trees.expressions.functions.scalar.Cos; +import org.apache.doris.nereids.trees.expressions.functions.scalar.CosineDistance; import org.apache.doris.nereids.trees.expressions.functions.scalar.CountEqual; import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateMap; import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateNamedStruct; @@ -180,6 +181,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursDiff; import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursSub; import org.apache.doris.nereids.trees.expressions.functions.scalar.If; import org.apache.doris.nereids.trees.expressions.functions.scalar.Initcap; +import org.apache.doris.nereids.trees.expressions.functions.scalar.InnerProduct; import org.apache.doris.nereids.trees.expressions.functions.scalar.Instr; import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonArray; import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonContains; @@ -213,6 +215,8 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbParseNul import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbParseNullableErrorToValue; import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbType; import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbValid; +import org.apache.doris.nereids.trees.expressions.functions.scalar.L1Distance; +import org.apache.doris.nereids.trees.expressions.functions.scalar.L2Distance; import org.apache.doris.nereids.trees.expressions.functions.scalar.LastDay; import org.apache.doris.nereids.trees.expressions.functions.scalar.Least; import org.apache.doris.nereids.trees.expressions.functions.scalar.Left; @@ -491,6 +495,7 @@ public class BuiltinScalarFunctions implements FunctionHelper { scalar(ConvertTo.class, "convert_to"), scalar(ConvertTz.class, "convert_tz"), scalar(Cos.class, "cos"), + scalar(CosineDistance.class, "cosine_distance"), scalar(CountEqual.class, "countequal"), scalar(CreateMap.class, "map"), scalar(CreateStruct.class, "struct"), @@ -560,6 +565,7 @@ public class BuiltinScalarFunctions implements FunctionHelper { scalar(HoursSub.class, "hours_sub"), scalar(If.class, "if"), scalar(Initcap.class, "initcap"), + scalar(InnerProduct.class, "inner_product"), scalar(Instr.class, "instr"), scalar(JsonArray.class, "json_array"), scalar(JsonObject.class, "json_object"), @@ -614,6 +620,8 @@ public class BuiltinScalarFunctions implements FunctionHelper { scalar(JsonbType.class, "jsonb_type"), scalar(JsonLength.class, "json_length"), scalar(JsonContains.class, "json_contains"), + scalar(L1Distance.class, "l1_distance"), + scalar(L2Distance.class, "l2_distance"), scalar(LastDay.class, "last_day"), scalar(Least.class, "least"), scalar(Left.class, "left"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CosineDistance.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CosineDistance.java new file mode 100644 index 0000000000..a87fde5eb7 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CosineDistance.java @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.scalar; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; +import org.apache.doris.nereids.trees.expressions.functions.ComputePrecisionForArrayItemAgg; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * cosine_distance function + */ +public class CosineDistance extends ScalarFunction implements ExplicitlyCastableSignature, + ComputePrecisionForArrayItemAgg, UnaryExpression, AlwaysNullable { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(TinyIntType.INSTANCE), ArrayType.of(TinyIntType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(SmallIntType.INSTANCE), ArrayType.of(SmallIntType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(IntegerType.INSTANCE), ArrayType.of(IntegerType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(BigIntType.INSTANCE), ArrayType.of(BigIntType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(LargeIntType.INSTANCE), ArrayType.of(LargeIntType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(FloatType.INSTANCE), ArrayType.of(FloatType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(DoubleType.INSTANCE), ArrayType.of(DoubleType.INSTANCE)) + ); + + /** + * constructor with 1 argument. + */ + public CosineDistance(Expression arg0, Expression arg1) { + super("cosine_distance", arg0, arg1); + } + + /** + * withChildren. + */ + @Override + public CosineDistance withChildren(List children) { + Preconditions.checkArgument(children.size() == 2); + return new CosineDistance(children.get(0), children.get(1)); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitCosineDistance(this, context); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InnerProduct.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InnerProduct.java new file mode 100644 index 0000000000..672a53859e --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InnerProduct.java @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.scalar; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; +import org.apache.doris.nereids.trees.expressions.functions.ComputePrecisionForArrayItemAgg; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * inner_product function + */ +public class InnerProduct extends ScalarFunction implements ExplicitlyCastableSignature, + ComputePrecisionForArrayItemAgg, UnaryExpression, AlwaysNullable { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(TinyIntType.INSTANCE), ArrayType.of(TinyIntType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(SmallIntType.INSTANCE), ArrayType.of(SmallIntType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(IntegerType.INSTANCE), ArrayType.of(IntegerType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(BigIntType.INSTANCE), ArrayType.of(BigIntType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(LargeIntType.INSTANCE), ArrayType.of(LargeIntType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(FloatType.INSTANCE), ArrayType.of(FloatType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(DoubleType.INSTANCE), ArrayType.of(DoubleType.INSTANCE)) + ); + + /** + * constructor with 1 argument. + */ + public InnerProduct(Expression arg0, Expression arg1) { + super("inner_product", arg0, arg1); + } + + /** + * withChildren. + */ + @Override + public InnerProduct withChildren(List children) { + Preconditions.checkArgument(children.size() == 2); + return new InnerProduct(children.get(0), children.get(1)); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitInnerProduct(this, context); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L1Distance.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L1Distance.java new file mode 100644 index 0000000000..c583a7bb4f --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L1Distance.java @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.scalar; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; +import org.apache.doris.nereids.trees.expressions.functions.ComputePrecisionForArrayItemAgg; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * l1_distance function + */ +public class L1Distance extends ScalarFunction implements ExplicitlyCastableSignature, + ComputePrecisionForArrayItemAgg, UnaryExpression, AlwaysNullable { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(TinyIntType.INSTANCE), ArrayType.of(TinyIntType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(SmallIntType.INSTANCE), ArrayType.of(SmallIntType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(IntegerType.INSTANCE), ArrayType.of(IntegerType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(BigIntType.INSTANCE), ArrayType.of(BigIntType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(LargeIntType.INSTANCE), ArrayType.of(LargeIntType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(FloatType.INSTANCE), ArrayType.of(FloatType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(DoubleType.INSTANCE), ArrayType.of(DoubleType.INSTANCE)) + ); + + /** + * constructor with 1 argument. + */ + public L1Distance(Expression arg0, Expression arg1) { + super("l1_distance", arg0, arg1); + } + + /** + * withChildren. + */ + @Override + public L1Distance withChildren(List children) { + Preconditions.checkArgument(children.size() == 2); + return new L1Distance(children.get(0), children.get(1)); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitL1Distance(this, context); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L2Distance.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L2Distance.java new file mode 100644 index 0000000000..64159c41bc --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L2Distance.java @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.scalar; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; +import org.apache.doris.nereids.trees.expressions.functions.ComputePrecisionForArrayItemAgg; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * l2_distance function + */ +public class L2Distance extends ScalarFunction implements ExplicitlyCastableSignature, + ComputePrecisionForArrayItemAgg, UnaryExpression, AlwaysNullable { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(TinyIntType.INSTANCE), ArrayType.of(TinyIntType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(SmallIntType.INSTANCE), ArrayType.of(SmallIntType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(IntegerType.INSTANCE), ArrayType.of(IntegerType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(BigIntType.INSTANCE), ArrayType.of(BigIntType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(LargeIntType.INSTANCE), ArrayType.of(LargeIntType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(FloatType.INSTANCE), ArrayType.of(FloatType.INSTANCE)), + FunctionSignature.ret(DoubleType.INSTANCE) + .args(ArrayType.of(DoubleType.INSTANCE), ArrayType.of(DoubleType.INSTANCE)) + ); + + /** + * constructor with 1 argument. + */ + public L2Distance(Expression arg0, Expression arg1) { + super("l2_distance", arg0, arg1); + } + + /** + * withChildren. + */ + @Override + public L2Distance withChildren(List children) { + Preconditions.checkArgument(children.size() == 2); + return new L2Distance(children.get(0), children.get(1)); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitL2Distance(this, context); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java index 111116af0e..f532d3d586 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java @@ -108,6 +108,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Conv; import org.apache.doris.nereids.trees.expressions.functions.scalar.ConvertTo; import org.apache.doris.nereids.trees.expressions.functions.scalar.ConvertTz; import org.apache.doris.nereids.trees.expressions.functions.scalar.Cos; +import org.apache.doris.nereids.trees.expressions.functions.scalar.CosineDistance; import org.apache.doris.nereids.trees.expressions.functions.scalar.CountEqual; import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateMap; import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateNamedStruct; @@ -176,6 +177,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursDiff; import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursSub; import org.apache.doris.nereids.trees.expressions.functions.scalar.If; import org.apache.doris.nereids.trees.expressions.functions.scalar.Initcap; +import org.apache.doris.nereids.trees.expressions.functions.scalar.InnerProduct; import org.apache.doris.nereids.trees.expressions.functions.scalar.Instr; import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonArray; import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonContains; @@ -209,6 +211,8 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbParseNul import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbParseNullableErrorToValue; import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbType; import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbValid; +import org.apache.doris.nereids.trees.expressions.functions.scalar.L1Distance; +import org.apache.doris.nereids.trees.expressions.functions.scalar.L2Distance; import org.apache.doris.nereids.trees.expressions.functions.scalar.LastDay; import org.apache.doris.nereids.trees.expressions.functions.scalar.Least; import org.apache.doris.nereids.trees.expressions.functions.scalar.Left; @@ -716,6 +720,10 @@ public interface ScalarFunctionVisitor { return visitScalarFunction(concatWs, context); } + default R visitConnectionId(ConnectionId connectionId, C context) { + return visitScalarFunction(connectionId, context); + } + default R visitConv(Conv conv, C context) { return visitScalarFunction(conv, context); } @@ -732,10 +740,18 @@ public interface ScalarFunctionVisitor { return visitScalarFunction(cos, context); } + default R visitCosineDistance(CosineDistance cosineDistance, C context) { + return visitScalarFunction(cosineDistance, context); + } + default R visitCountEqual(CountEqual countequal, C context) { return visitScalarFunction(countequal, context); } + default R visitCurrentCatalog(CurrentCatalog currentCatalog, C context) { + return visitScalarFunction(currentCatalog, context); + } + default R visitCurrentDate(CurrentDate currentDate, C context) { return visitScalarFunction(currentDate, context); } @@ -744,28 +760,16 @@ public interface ScalarFunctionVisitor { return visitScalarFunction(currentTime, context); } - default R visitDate(Date date, C context) { - return visitScalarFunction(date, context); + default R visitCurrentUser(CurrentUser currentUser, C context) { + return visitScalarFunction(currentUser, context); } default R visitDatabase(Database database, C context) { return visitScalarFunction(database, context); } - default R visitCurrentUser(CurrentUser currentUser, C context) { - return visitScalarFunction(currentUser, context); - } - - default R visitCurrentCatalog(CurrentCatalog currentCatalog, C context) { - return visitScalarFunction(currentCatalog, context); - } - - default R visitUser(User user, C context) { - return visitScalarFunction(user, context); - } - - default R visitConnectionId(ConnectionId connectionId, C context) { - return visitScalarFunction(connectionId, context); + default R visitDate(Date date, C context) { + return visitScalarFunction(date, context); } default R visitDateDiff(DateDiff dateDiff, C context) { @@ -1048,6 +1052,10 @@ public interface ScalarFunctionVisitor { return visitScalarFunction(initcap, context); } + default R visitInnerProduct(InnerProduct innerProduct, C context) { + return visitScalarFunction(innerProduct, context); + } + default R visitInstr(Instr instr, C context) { return visitScalarFunction(instr, context); } @@ -1180,6 +1188,14 @@ public interface ScalarFunctionVisitor { return visitScalarFunction(jsonbValid, context); } + default R visitL1Distance(L1Distance l1Distance, C context) { + return visitScalarFunction(l1Distance, context); + } + + default R visitL2Distance(L2Distance l2Distance, C context) { + return visitScalarFunction(l2Distance, context); + } + default R visitLastDay(LastDay lastDay, C context) { return visitScalarFunction(lastDay, context); } @@ -1724,6 +1740,10 @@ public interface ScalarFunctionVisitor { return visitScalarFunction(upper, context); } + default R visitUser(User user, C context) { + return visitScalarFunction(user, context); + } + default R visitUtcTimestamp(UtcTimestamp utcTimestamp, C context) { return visitScalarFunction(utcTimestamp, context); } diff --git a/regression-test/data/nereids_function_p0/scalar_function/Array.out b/regression-test/data/nereids_function_p0/scalar_function/Array.out index df91f9f177..a209a9d5d8 100644 --- a/regression-test/data/nereids_function_p0/scalar_function/Array.out +++ b/regression-test/data/nereids_function_p0/scalar_function/Array.out @@ -3073,6 +3073,818 @@ char23,char33,varchar13,varchar23,varchar33,string3 2012-03-11,2012-03-11 2012-03-12,2012-03-12 +-- !sql_l1_distance_Double -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l1_distance_Double_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l1_distance_Float -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l1_distance_Float_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l1_distance_LargeInt -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l1_distance_LargeInt_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l1_distance_BigInt -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l1_distance_BigInt_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l1_distance_SmallInt -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l1_distance_SmallInt_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l1_distance_Integer -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l1_distance_Integer_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l1_distance_TinyInt -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l1_distance_TinyInt_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l2_distance_Double -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l2_distance_Double_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l2_distance_Float -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l2_distance_Float_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l2_distance_LargeInt -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l2_distance_LargeInt_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l2_distance_BigInt -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l2_distance_BigInt_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l2_distance_SmallInt -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l2_distance_SmallInt_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l2_distance_Integer -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l2_distance_Integer_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l2_distance_TinyInt -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_l2_distance_TinyInt_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_cosine_distance_Double -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_cosine_distance_Double_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_cosine_distance_Float -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_cosine_distance_Float_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_cosine_distance_LargeInt -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_cosine_distance_LargeInt_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_cosine_distance_BigInt -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_cosine_distance_BigInt_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_cosine_distance_SmallInt -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_cosine_distance_SmallInt_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_cosine_distance_Integer -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_cosine_distance_Integer_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_cosine_distance_TinyInt -- +\N +\N +\N +\N +\N +\N +\N +\N +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_cosine_distance_TinyInt_notnull -- +\N +\N +\N +\N +\N +\N +\N +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_inner_product_Double -- +\N +0.010000000000000002 +0.04000000000000001 +0.09 +0.16000000000000003 +0.25 +0.36 +0.48999999999999994 +0.6400000000000001 +0.81 +1.0 +1.2100000000000002 +1.44 + +-- !sql_inner_product_Double_notnull -- +0.010000000000000002 +0.04000000000000001 +0.09 +0.16000000000000003 +0.25 +0.36 +0.48999999999999994 +0.6400000000000001 +0.81 +1.0 +1.2100000000000002 +1.44 + +-- !sql_inner_product_Float -- +\N +1.0 +100.0 +121.0 +144.0 +16.0 +25.0 +36.0 +4.0 +49.0 +64.0 +81.0 +9.0 + +-- !sql_inner_product_Float_notnull -- +1.0 +100.0 +121.0 +144.0 +16.0 +25.0 +36.0 +4.0 +49.0 +64.0 +81.0 +9.0 + +-- !sql_inner_product_LargeInt -- +\N +1.0 +100.0 +121.0 +144.0 +16.0 +25.0 +36.0 +4.0 +49.0 +64.0 +81.0 +9.0 + +-- !sql_inner_product_LargeInt_notnull -- +1.0 +100.0 +121.0 +144.0 +16.0 +25.0 +36.0 +4.0 +49.0 +64.0 +81.0 +9.0 + +-- !sql_inner_product_BigInt -- +\N +1.0 +100.0 +121.0 +144.0 +16.0 +25.0 +36.0 +4.0 +49.0 +64.0 +81.0 +9.0 + +-- !sql_inner_product_BigInt_notnull -- +1.0 +100.0 +121.0 +144.0 +16.0 +25.0 +36.0 +4.0 +49.0 +64.0 +81.0 +9.0 + +-- !sql_inner_product_SmallInt -- +\N +1.0 +100.0 +121.0 +144.0 +16.0 +25.0 +36.0 +4.0 +49.0 +64.0 +81.0 +9.0 + +-- !sql_inner_product_SmallInt_notnull -- +1.0 +100.0 +121.0 +144.0 +16.0 +25.0 +36.0 +4.0 +49.0 +64.0 +81.0 +9.0 + +-- !sql_inner_product_Integer -- +\N +1.0 +100.0 +121.0 +144.0 +16.0 +25.0 +36.0 +4.0 +49.0 +64.0 +81.0 +9.0 + +-- !sql_inner_product_Integer_notnull -- +1.0 +100.0 +121.0 +144.0 +16.0 +25.0 +36.0 +4.0 +49.0 +64.0 +81.0 +9.0 + +-- !sql_inner_product_TinyInt -- +\N +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +1.0 +1.0 +1.0 +1.0 +1.0 + +-- !sql_inner_product_TinyInt_notnull -- +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +0.0 +1.0 +1.0 +1.0 +1.0 +1.0 + -- !sql_array_max_Double -- \N 0.1 diff --git a/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy b/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy index 8bd097bf56..aba2050730 100644 --- a/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy +++ b/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy @@ -268,6 +268,70 @@ suite("nereids_scalar_fn_Array") { order_qt_sql_array_join_two_params_DateV2 "select array_join(kadtv2, ',') from fn_test" order_qt_sql_array_join_two_params_DateV2_notnull "select array_join(kadtv2, ',') from fn_test_not_nullable" + // l1_distance + order_qt_sql_l1_distance_Double "select l1_distance(kadbl, kadbl) from fn_test" + order_qt_sql_l1_distance_Double_notnull "select l1_distance(kadbl, kadbl) from fn_test_not_nullable" + order_qt_sql_l1_distance_Float "select l1_distance(kafloat, kafloat) from fn_test" + order_qt_sql_l1_distance_Float_notnull "select l1_distance(kafloat, kafloat) from fn_test_not_nullable" + order_qt_sql_l1_distance_LargeInt "select l1_distance(kalint, kalint) from fn_test" + order_qt_sql_l1_distance_LargeInt_notnull "select l1_distance(kalint, kalint) from fn_test_not_nullable" + order_qt_sql_l1_distance_BigInt "select l1_distance(kabint, kabint) from fn_test" + order_qt_sql_l1_distance_BigInt_notnull "select l1_distance(kabint, kabint) from fn_test_not_nullable" + order_qt_sql_l1_distance_SmallInt "select l1_distance(kasint, kasint) from fn_test" + order_qt_sql_l1_distance_SmallInt_notnull "select l1_distance(kasint, kasint) from fn_test_not_nullable" + order_qt_sql_l1_distance_Integer "select l1_distance(kaint, kaint) from fn_test" + order_qt_sql_l1_distance_Integer_notnull "select l1_distance(kaint, kaint) from fn_test_not_nullable" + order_qt_sql_l1_distance_TinyInt "select l1_distance(katint, katint) from fn_test" + order_qt_sql_l1_distance_TinyInt_notnull "select l1_distance(katint, katint) from fn_test_not_nullable" + + // l2_distance + order_qt_sql_l2_distance_Double "select l2_distance(kadbl, kadbl) from fn_test" + order_qt_sql_l2_distance_Double_notnull "select l2_distance(kadbl, kadbl) from fn_test_not_nullable" + order_qt_sql_l2_distance_Float "select l2_distance(kafloat, kafloat) from fn_test" + order_qt_sql_l2_distance_Float_notnull "select l2_distance(kafloat, kafloat) from fn_test_not_nullable" + order_qt_sql_l2_distance_LargeInt "select l2_distance(kalint, kalint) from fn_test" + order_qt_sql_l2_distance_LargeInt_notnull "select l2_distance(kalint, kalint) from fn_test_not_nullable" + order_qt_sql_l2_distance_BigInt "select l2_distance(kabint, kabint) from fn_test" + order_qt_sql_l2_distance_BigInt_notnull "select l2_distance(kabint, kabint) from fn_test_not_nullable" + order_qt_sql_l2_distance_SmallInt "select l2_distance(kasint, kasint) from fn_test" + order_qt_sql_l2_distance_SmallInt_notnull "select l2_distance(kasint, kasint) from fn_test_not_nullable" + order_qt_sql_l2_distance_Integer "select l2_distance(kaint, kaint) from fn_test" + order_qt_sql_l2_distance_Integer_notnull "select l2_distance(kaint, kaint) from fn_test_not_nullable" + order_qt_sql_l2_distance_TinyInt "select l2_distance(katint, katint) from fn_test" + order_qt_sql_l2_distance_TinyInt_notnull "select l2_distance(katint, katint) from fn_test_not_nullable" + + // cosine_distance + order_qt_sql_cosine_distance_Double "select cosine_distance(kadbl, kadbl) from fn_test" + order_qt_sql_cosine_distance_Double_notnull "select cosine_distance(kadbl, kadbl) from fn_test_not_nullable" + order_qt_sql_cosine_distance_Float "select cosine_distance(kafloat, kafloat) from fn_test" + order_qt_sql_cosine_distance_Float_notnull "select cosine_distance(kafloat, kafloat) from fn_test_not_nullable" + order_qt_sql_cosine_distance_LargeInt "select cosine_distance(kalint, kalint) from fn_test" + order_qt_sql_cosine_distance_LargeInt_notnull "select cosine_distance(kalint, kalint) from fn_test_not_nullable" + order_qt_sql_cosine_distance_BigInt "select cosine_distance(kabint, kabint) from fn_test" + order_qt_sql_cosine_distance_BigInt_notnull "select cosine_distance(kabint, kabint) from fn_test_not_nullable" + order_qt_sql_cosine_distance_SmallInt "select cosine_distance(kasint, kasint) from fn_test" + order_qt_sql_cosine_distance_SmallInt_notnull "select cosine_distance(kasint, kasint) from fn_test_not_nullable" + order_qt_sql_cosine_distance_Integer "select cosine_distance(kaint, kaint) from fn_test" + order_qt_sql_cosine_distance_Integer_notnull "select cosine_distance(kaint, kaint) from fn_test_not_nullable" + order_qt_sql_cosine_distance_TinyInt "select cosine_distance(katint, katint) from fn_test" + order_qt_sql_cosine_distance_TinyInt_notnull "select cosine_distance(katint, katint) from fn_test_not_nullable" + + // inner_product + order_qt_sql_inner_product_Double "select inner_product(kadbl, kadbl) from fn_test" + order_qt_sql_inner_product_Double_notnull "select inner_product(kadbl, kadbl) from fn_test_not_nullable" + order_qt_sql_inner_product_Float "select inner_product(kafloat, kafloat) from fn_test" + order_qt_sql_inner_product_Float_notnull "select inner_product(kafloat, kafloat) from fn_test_not_nullable" + order_qt_sql_inner_product_LargeInt "select inner_product(kalint, kalint) from fn_test" + order_qt_sql_inner_product_LargeInt_notnull "select inner_product(kalint, kalint) from fn_test_not_nullable" + order_qt_sql_inner_product_BigInt "select inner_product(kabint, kabint) from fn_test" + order_qt_sql_inner_product_BigInt_notnull "select inner_product(kabint, kabint) from fn_test_not_nullable" + order_qt_sql_inner_product_SmallInt "select inner_product(kasint, kasint) from fn_test" + order_qt_sql_inner_product_SmallInt_notnull "select inner_product(kasint, kasint) from fn_test_not_nullable" + order_qt_sql_inner_product_Integer "select inner_product(kaint, kaint) from fn_test" + order_qt_sql_inner_product_Integer_notnull "select inner_product(kaint, kaint) from fn_test_not_nullable" + order_qt_sql_inner_product_TinyInt "select inner_product(katint, katint) from fn_test" + order_qt_sql_inner_product_TinyInt_notnull "select inner_product(katint, katint) from fn_test_not_nullable" + // array_max order_qt_sql_array_max_Double "select array_max(kadbl) from fn_test" order_qt_sql_array_max_Double_notnull "select array_max(kadbl) from fn_test_not_nullable"