From 177e82bdabafc4566eef5c410295a4b88f2f1516 Mon Sep 17 00:00:00 2001 From: xy720 <22125576+xy720@users.noreply.github.com> Date: Mon, 24 Oct 2022 11:51:47 +0800 Subject: [PATCH] [Enhancement](array-type) Add type derivation for array functions (#13534) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Until now, type derivation has not been supported for the arguments of array functions, so the cases below return wrong values or even cause a BE core dump. mysql> select array_union([1],[10000000]); +----------------------------------------+ | array_union(ARRAY(1), ARRAY(10000000)) | +----------------------------------------+ | [1, -128] | +----------------------------------------+ 1 row in set (0.03 sec) mysql> select array_union([NULL],[1]); ERROR 1105 (HY000): RpcException, msg: io.grpc.StatusRuntimeException: UNAVAILABLE: Network closed for unknown reason mysql> select array_union([],[1]); ERROR 1105 (HY000): RpcException, msg: io.grpc.StatusRuntimeException: UNAVAILABLE: Network closed for unknown reason This commit makes a small fix to derive the argument types of array functions: 1. For null-type arguments, cast the null type to boolean, because the null type must not be seen by the BE. 2. For arguments of different types, cast all arguments to their compatible type. 
--- .../array/function_array_aggregation.cpp | 3 +- .../doris/analysis/FunctionCallExpr.java | 33 ++++++++- gensrc/script/doris_builtins_functions.py | 5 ++ .../test_array_functions_by_literal.out | 69 +++++++++++++++++++ .../test_array_functions_by_literal.groovy | 23 +++++++ 5 files changed, 131 insertions(+), 2 deletions(-) diff --git a/be/src/vec/functions/array/function_array_aggregation.cpp b/be/src/vec/functions/array/function_array_aggregation.cpp index aed53bc25b..d1b8623151 100644 --- a/be/src/vec/functions/array/function_array_aggregation.cpp +++ b/be/src/vec/functions/array/function_array_aggregation.cpp @@ -159,7 +159,8 @@ struct ArrayAggregateImpl { const IColumn* data = array.get_data_ptr().get(); const auto& offsets = array.get_offsets(); - if (execute_type(res, type, data, offsets) || + if (execute_type(res, type, data, offsets) || + execute_type(res, type, data, offsets) || execute_type(res, type, data, offsets) || execute_type(res, type, data, offsets) || execute_type(res, type, data, offsets) || diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index 6ff3404129..8ee6c81d64 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -812,6 +812,35 @@ public class FunctionCallExpr extends Expr { } + private void analyzeArrayFunction(Analyzer analyzer) throws AnalysisException { + if (fnName.getFunction().equalsIgnoreCase("array_distinct") + || fnName.getFunction().equalsIgnoreCase("array_max") + || fnName.getFunction().equalsIgnoreCase("array_min") + || fnName.getFunction().equalsIgnoreCase("array_sum") + || fnName.getFunction().equalsIgnoreCase("array_avg") + || fnName.getFunction().equalsIgnoreCase("array_product") + || fnName.getFunction().equalsIgnoreCase("array_union") + || fnName.getFunction().equalsIgnoreCase("array_except") 
+ || fnName.getFunction().equalsIgnoreCase("array_intersect") + || fnName.getFunction().equalsIgnoreCase("arrays_overlap")) { + Type[] childTypes = collectChildReturnTypes(); + Type compatibleType = childTypes[0]; + for (int i = 1; i < childTypes.length; ++i) { + compatibleType = Type.getAssignmentCompatibleType(compatibleType, childTypes[i], true); + if (compatibleType == Type.INVALID) { + throw new AnalysisException(getFunctionNotFoundError(collectChildReturnTypes())); + } + } + // Make sure BE doesn't see any TYPE_NULL exprs + if (compatibleType.isNull()) { + compatibleType = Type.BOOLEAN; + } + for (int i = 0; i < childTypes.length; i++) { + uncheckedCastChild(compatibleType, i); + } + } + } + // Provide better error message for some aggregate builtins. These can be // a bit more user friendly than a generic function not found. // TODO: should we bother to do this? We could also improve the general @@ -903,6 +932,8 @@ public class FunctionCallExpr extends Expr { analyzeBuiltinAggFunction(analyzer); + analyzeArrayFunction(analyzer); + if (fnName.getFunction().equalsIgnoreCase("sum")) { if (this.children.isEmpty()) { throw new AnalysisException("The " + fnName + " function must has one input param"); @@ -1250,7 +1281,7 @@ public class FunctionCallExpr extends Expr { if (this.type instanceof ArrayType) { ArrayType arrayType = (ArrayType) type; - // Now Array type do not support ARRAY, set it too true temporarily + // Now Array type do not support ARRAY, set it to true temporarily boolean containsNull = true; for (Expr child : children) { Type childType = child.getType(); diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py index 68308b64e0..0976d0b284 100755 --- a/gensrc/script/doris_builtins_functions.py +++ b/gensrc/script/doris_builtins_functions.py @@ -273,6 +273,7 @@ visible_functions = [ [['array_join'], 'STRING', ['ARRAY_VARCHAR','VARCHAR', 'VARCHAR'], '', '', '', 'vec', ''], [['array_join'], 'STRING', 
['ARRAY_STRING','VARCHAR', 'VARCHAR'], '', '', '', 'vec', ''], + [['array_min'], 'BOOLEAN', ['ARRAY_BOOLEAN'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_min'], 'TINYINT', ['ARRAY_TINYINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_min'], 'SMALLINT', ['ARRAY_SMALLINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_min'], 'INT', ['ARRAY_INT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], @@ -285,6 +286,7 @@ visible_functions = [ [['array_min'], 'DATETIME', ['ARRAY_DATETIME'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_min'], 'DATEV2', ['ARRAY_DATEV2'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_min'], 'DATETIMEV2', ['ARRAY_DATETIMEV2'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], + [['array_max'], 'BOOLEAN', ['ARRAY_BOOLEAN'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_max'], 'TINYINT', ['ARRAY_TINYINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_max'], 'SMALLINT', ['ARRAY_SMALLINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_max'], 'INT', ['ARRAY_INT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], @@ -297,6 +299,7 @@ visible_functions = [ [['array_max'], 'DATETIME', ['ARRAY_DATETIME'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_max'], 'DATEV2', ['ARRAY_DATEV2'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_max'], 'DATETIMEV2', ['ARRAY_DATETIMEV2'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], + [['array_sum'], 'BIGINT', ['ARRAY_BOOLEAN'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_sum'], 'BIGINT', ['ARRAY_TINYINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_sum'], 'BIGINT', ['ARRAY_SMALLINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_sum'], 'BIGINT', ['ARRAY_INT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], @@ -305,6 +308,7 @@ visible_functions = [ [['array_sum'], 'DOUBLE', ['ARRAY_FLOAT'], '', '', '','vec', 'ALWAYS_NULLABLE'], [['array_sum'], 'DOUBLE', ['ARRAY_DOUBLE'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_sum'], 'DECIMALV2',['ARRAY_DECIMALV2'],'', '', '', 'vec', 'ALWAYS_NULLABLE'], + [['array_avg'], 'DOUBLE', 
['ARRAY_BOOLEAN'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_avg'], 'DOUBLE', ['ARRAY_TINYINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_avg'], 'DOUBLE', ['ARRAY_SMALLINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_avg'], 'DOUBLE', ['ARRAY_INT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], @@ -313,6 +317,7 @@ visible_functions = [ [['array_avg'], 'DOUBLE', ['ARRAY_FLOAT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_avg'], 'DOUBLE', ['ARRAY_DOUBLE'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_avg'], 'DECIMALV2',['ARRAY_DECIMALV2'],'', '', '', 'vec', 'ALWAYS_NULLABLE'], + [['array_product'], 'DOUBLE', ['ARRAY_BOOLEAN'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_product'], 'DOUBLE', ['ARRAY_TINYINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_product'], 'DOUBLE', ['ARRAY_SMALLINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], [['array_product'], 'DOUBLE', ['ARRAY_INT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'], diff --git a/regression-test/data/query_p0/sql_functions/array_functions/test_array_functions_by_literal.out b/regression-test/data/query_p0/sql_functions/array_functions/test_array_functions_by_literal.out index 64b7587d61..c50009fb8c 100644 --- a/regression-test/data/query_p0/sql_functions/array_functions/test_array_functions_by_literal.out +++ b/regression-test/data/query_p0/sql_functions/array_functions/test_array_functions_by_literal.out @@ -134,6 +134,30 @@ false -- !sql -- 3 +-- !sql -- +\N + +-- !sql -- +\N + +-- !sql -- +\N + +-- !sql -- +\N + +-- !sql -- +\N + +-- !sql -- +\N + +-- !sql -- +\N + +-- !sql -- +\N + -- !sql -- [1, 2, 3] @@ -149,6 +173,12 @@ false -- !sql -- [1, 0, NULL] +-- !sql -- +[] + +-- !sql -- +[NULL, NULL] + -- !sql -- [2, 3] @@ -200,6 +230,9 @@ true -- !sql -- false +-- !sql -- +false + -- !sql -- [1, 2, 3, 4] @@ -227,6 +260,42 @@ false -- !sql -- [0] +-- !sql -- +[] + +-- !sql -- +[] + +-- !sql -- +[] + +-- !sql -- +[1, 2, 3] + +-- !sql -- +[] + +-- !sql -- +[] + +-- !sql -- +[NULL, 1, 2, 3] + +-- 
!sql -- +[NULL] + +-- !sql -- +[] + +-- !sql -- +[1, 100000000] + +-- !sql -- +[1] + +-- !sql -- +[] + -- !sql -- [1] diff --git a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions_by_literal.groovy b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions_by_literal.groovy index 5867d218e1..ab5cd580aa 100644 --- a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions_by_literal.groovy +++ b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions_by_literal.groovy @@ -72,6 +72,14 @@ suite("test_array_functions_by_literal") { qt_sql "select array_sum([1,2,3,null])" qt_sql "select array_min([1,2,3,null])" qt_sql "select array_max([1,2,3,null])" + qt_sql "select array_avg([])" + qt_sql "select array_sum([])" + qt_sql "select array_min([])" + qt_sql "select array_max([])" + qt_sql "select array_avg([null])" + qt_sql "select array_sum([null])" + qt_sql "select array_min([null])" + qt_sql "select array_max([null])" // array_distinct function qt_sql "select array_distinct([1,1,2,2,3,3])" @@ -79,6 +87,8 @@ suite("test_array_functions_by_literal") { qt_sql "select array_distinct(['a','a','a'])" qt_sql "select array_distinct(['a','a','a',null])" qt_sql "select array_distinct([true, false, null, false])" + qt_sql "select array_distinct([])" + qt_sql "select array_distinct([null,null])" // array_remove function @@ -103,6 +113,7 @@ suite("test_array_functions_by_literal") { qt_sql "select arrays_overlap([1,2,3], [3,4,5])" qt_sql "select arrays_overlap([1,2,3,null], [3,4,5])" qt_sql "select arrays_overlap([true], [false])" + qt_sql "select arrays_overlap([], [])" // array_binary function qt_sql "select array_union([1,2,3], [2,3,4])" @@ -114,6 +125,18 @@ suite("test_array_functions_by_literal") { qt_sql "select array_union([true], [false])" qt_sql "select array_except([true, false], [true])" qt_sql "select array_intersect([false, true], [false])" + qt_sql "select 
array_union([], [])" + qt_sql "select array_except([], [])" + qt_sql "select array_intersect([], [])" + qt_sql "select array_union([], [1,2,3])" + qt_sql "select array_except([], [1,2,3])" + qt_sql "select array_intersect([], [1,2,3])" + qt_sql "select array_union([null], [1,2,3])" + qt_sql "select array_except([null], [1,2,3])" + qt_sql "select array_intersect([null], [1,2,3])" + qt_sql "select array_union([1], [100000000])" + qt_sql "select array_except([1], [100000000])" + qt_sql "select array_intersect([1], [100000000])" // arrat_slice function qt_sql "select [1,2,3][1:1]"