diff --git a/be/src/pipeline/exec/union_source_operator.cpp b/be/src/pipeline/exec/union_source_operator.cpp index 591e9b04f9..c9deae0d03 100644 --- a/be/src/pipeline/exec/union_source_operator.cpp +++ b/be/src/pipeline/exec/union_source_operator.cpp @@ -169,14 +169,14 @@ Status UnionSourceOperatorX::get_block(RuntimeState* state, vectorized::Block* b SCOPED_TIMER(local_state.exec_time_counter()); if (local_state._need_read_for_const_expr) { if (has_more_const(state)) { - static_cast(get_next_const(state, block)); + RETURN_IF_ERROR(get_next_const(state, block)); } local_state._need_read_for_const_expr = has_more_const(state); } else { std::unique_ptr output_block = vectorized::Block::create_unique(); int child_idx = 0; - static_cast(local_state._shared_state->data_queue.get_block_from_queue(&output_block, - &child_idx)); + RETURN_IF_ERROR(local_state._shared_state->data_queue.get_block_from_queue(&output_block, + &child_idx)); if (!output_block) { return Status::OK(); } diff --git a/be/src/vec/functions/random.cpp b/be/src/vec/functions/random.cpp index 14580dbf82..564a51d932 100644 --- a/be/src/vec/functions/random.cpp +++ b/be/src/vec/functions/random.cpp @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+#include #include #include @@ -55,6 +56,9 @@ public: bool is_variadic() const override { return true; } DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + if (arguments.size() == 2) { + return std::make_shared(); + } return std::make_shared(); } @@ -74,6 +78,7 @@ public: } generator->seed(seed); } else { + // 0 or 2 args generator->seed(std::random_device()()); } } @@ -83,12 +88,63 @@ public: Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) const override { + if (arguments.size() == 2) { + return _execute_int_range(context, block, arguments, result, input_rows_count); + } + return _execute_float(context, block, arguments, result, input_rows_count); + } + + Status close(FunctionContext* context, FunctionContext::FunctionStateScope scope) override { + return Status::OK(); + } + +private: + static Status _execute_int_range(FunctionContext* context, Block& block, + const ColumnNumbers& arguments, size_t result, + size_t input_rows_count) { + auto res_column = ColumnInt64::create(input_rows_count); + auto& res_data = static_cast(*res_column).get_data(); + + auto* generator = reinterpret_cast( + context->get_function_state(FunctionContext::THREAD_LOCAL)); + DCHECK(generator != nullptr); + + Int64 min = assert_cast( + assert_cast( + block.get_by_position(arguments[0]).column.get()) + ->get_data_column_ptr() + .get()) + ->get_element(0); + Int64 max = assert_cast( + assert_cast( + block.get_by_position(arguments[1]).column.get()) + ->get_data_column_ptr() + .get()) + ->get_element(0); + if (min >= max) { + return Status::InvalidArgument(fmt::format( + "random's lower bound should less than upper bound, but got [{}, {}]", min, + max)); + } + + std::uniform_int_distribution distribution(min, max); /* closed range: max itself can be drawn, so the message above prints [min, max] */ + for (size_t i = 0; i < input_rows_count; ++i) { + res_data[i] = distribution(*generator); + } + + block.replace_by_position(result, std::move(res_column)); + return 
Status::OK(); + } + + static Status _execute_float(FunctionContext* context, Block& block, + const ColumnNumbers& arguments, size_t result, + size_t input_rows_count) { static const double min = 0.0; static const double max = 1.0; auto res_column = ColumnFloat64::create(input_rows_count); - auto& res_data = assert_cast(*res_column).get_data(); + auto& res_data = static_cast(*res_column).get_data(); - std::mt19937_64* generator = reinterpret_cast( + auto* generator = reinterpret_cast( context->get_function_state(FunctionContext::THREAD_LOCAL)); DCHECK(generator != nullptr); @@ -100,10 +156,6 @@ public: block.replace_by_position(result, std::move(res_column)); return Status::OK(); } - - Status close(FunctionContext* context, FunctionContext::FunctionStateScope scope) override { - return Status::OK(); - } }; void register_function_random(SimpleFunctionFactory& factory) { diff --git a/be/src/vec/functions/simple_function_factory.h b/be/src/vec/functions/simple_function_factory.h index 66cd5c67ba..0992239e99 100644 --- a/be/src/vec/functions/simple_function_factory.h +++ b/be/src/vec/functions/simple_function_factory.h @@ -157,7 +157,7 @@ public: int be_version = BeExecVersionManager::get_newest_version()) { std::string key_str = name; - if (function_alias.count(name)) { + if (function_alias.contains(name)) { key_str = function_alias[name]; } @@ -165,7 +165,7 @@ public: // if function is variadic, added types_str as key if (function_variadic_set.count(key_str)) { - for (auto& arg : arguments) { + for (const auto& arg : arguments) { key_str.append(arg.type->is_nullable() ? 
reinterpret_cast(arg.type.get()) ->get_nested_type() diff --git a/docs/en/docs/sql-manual/sql-functions/numeric-functions/random.md b/docs/en/docs/sql-manual/sql-functions/numeric-functions/random.md index 79f266b27e..53e3ba4158 100644 --- a/docs/en/docs/sql-manual/sql-functions/numeric-functions/random.md +++ b/docs/en/docs/sql-manual/sql-functions/numeric-functions/random.md @@ -28,18 +28,50 @@ under the License. #### Syntax `DOUBLE random()` -Returns a random number between 0-1. +Returns a random number between 0 and 1. + +`DOUBLE random(BIGINT seed)` +Returns a random number between 0 and 1, seeded with `seed`. + +`BIGINT random(BIGINT a, BIGINT b)` +Returns a random integer between a and b, both inclusive. a must be less than b. + +Alias: `rand`. ### example -``` +```sql mysql> select random(); +---------------------+ | random() | +---------------------+ | 0.35446706030596947 | +---------------------+ + +mysql> select rand(1.2); ++---------------------+ +| rand(1) | ++---------------------+ +| 0.13387664401253274 | ++---------------------+ +1 row in set (0.13 sec) + +mysql> select rand(1.2); ++---------------------+ +| rand(1) | ++---------------------+ +| 0.13387664401253274 | ++---------------------+ +1 row in set (0.11 sec) + +mysql> select rand(-20, -10); ++------------------+ +| random(-20, -10) | ++------------------+ +| -13 | ++------------------+ +1 row in set (0.10 sec) ``` ### keywords - RANDOM + RANDOM, RAND diff --git a/docs/zh-CN/docs/sql-manual/sql-functions/numeric-functions/random.md b/docs/zh-CN/docs/sql-manual/sql-functions/numeric-functions/random.md index 2ba988f5c7..36442afffd 100644 --- a/docs/zh-CN/docs/sql-manual/sql-functions/numeric-functions/random.md +++ b/docs/zh-CN/docs/sql-manual/sql-functions/numeric-functions/random.md @@ -28,18 +28,50 @@ under the License. 
#### Syntax `DOUBLE random()` -返回0-1的随机数。 +返回0-1之间的随机数。 + +`DOUBLE random(BIGINT seed)` +返回0-1之间的随机数,以`seed`作为种子。 + +`BIGINT random(BIGINT a, BIGINT b)` +返回a-b之间的随机整数(包含a和b),a必须小于b。 + +别名:`rand` ### example -``` +```sql mysql> select random(); +---------------------+ | random() | +---------------------+ | 0.35446706030596947 | +---------------------+ + +mysql> select rand(1.2); ++---------------------+ +| rand(1) | ++---------------------+ +| 0.13387664401253274 | ++---------------------+ +1 row in set (0.13 sec) + +mysql> select rand(1.2); ++---------------------+ +| rand(1) | ++---------------------+ +| 0.13387664401253274 | ++---------------------+ +1 row in set (0.11 sec) + +mysql> select rand(-20, -10); ++------------------+ +| random(-20, -10) | ++------------------+ +| -13 | ++------------------+ +1 row in set (0.10 sec) ``` ### keywords - RANDOM + RANDOM, RAND diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Random.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Random.java index 29530adfa0..a7f3a360a6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Random.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Random.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.Nondeterministic; @@ -39,7 +40,8 @@ public class Random extends ScalarFunction public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(DoubleType.INSTANCE).args(), - FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE) + 
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE), + FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE) ); /** @@ -58,6 +60,17 @@ public class Random extends ScalarFunction Preconditions.checkState(arg instanceof Literal, "The param of rand function must be literal"); } + /** + * constructor with 2 arguments; both must be literals. + */ + public Random(Expression lchild, Expression rchild) { + super("random", lchild, rchild); + // align with original planner behavior, refer to: + // org/apache/doris/analysis/Expr.getBuiltinFunction() + Preconditions.checkState(lchild instanceof Literal && rchild instanceof Literal, + "The param of rand function must be literal"); + } + /** * custom compute nullable. */ @@ -80,13 +93,14 @@ public class Random extends ScalarFunction */ @Override public Random withChildren(List children) { - Preconditions.checkArgument(children.size() == 0 - || children.size() == 1); - if (children.isEmpty() && arity() == 0) { - return this; - } else { + if (children.isEmpty()) { + return new Random(); + } else if (children.size() == 1) { return new Random(children.get(0)); + } else if (children.size() == 2) { + return new Random(children.get(0), children.get(1)); } + throw new AnalysisException("random function only accept 0-2 arguments"); } @Override diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py index 0bcb5e080a..0e4bd56921 100644 --- a/gensrc/script/doris_builtins_functions.py +++ b/gensrc/script/doris_builtins_functions.py @@ -1373,6 +1373,7 @@ visible_functions = { [['radians'], 'DOUBLE', ['DOUBLE'], ''], [['rand', 'random'], 'DOUBLE', [], 'ALWAYS_NOT_NULLABLE'], [['rand', 'random'], 'DOUBLE', ['BIGINT'], ''], + [['rand', 'random'], 'BIGINT', ['BIGINT', 'BIGINT'], ''], [['round', 'dround'], 'DOUBLE', ['DOUBLE'], ''], [['round', 'dround'], 'DOUBLE', ['DOUBLE', 'INT'], ''], [['round', 'dround'], 'DECIMAL32', ['DECIMAL32'], ''], diff --git 
a/regression-test/suites/query_p0/system/test_query_sys.groovy b/regression-test/suites/query_p0/system/test_query_sys.groovy index a87bad8094..0be52a301d 100644 --- a/regression-test/suites/query_p0/system/test_query_sys.groovy +++ b/regression-test/suites/query_p0/system/test_query_sys.groovy @@ -27,6 +27,12 @@ suite("test_query_sys", "query,p0") { sql "select rand(20);" sql "select random();" sql "select random(20);" + sql "select rand(1, 10);" + sql "select random(-5, -3);" + test{ + sql "select rand(10,1);" + exception "random's lower bound should less than upper bound" + } sql "SELECT CONNECTION_ID();" sql "SELECT CURRENT_USER();" sql "SELECT CURRENT_CATALOG();"