[branch-2.1](functions) fix be crash for function random_bytes and mark_first/last_n (#36003)
pick #35884
This commit is contained in:
@ -792,10 +792,7 @@ public:
|
||||
|
||||
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
|
||||
size_t result, size_t input_rows_count) const override {
|
||||
DCHECK_GE(arguments.size(), 1);
|
||||
DCHECK_LE(arguments.size(), 2);
|
||||
|
||||
int n = -1;
|
||||
int n = -1; // means unassigned
|
||||
|
||||
auto res = ColumnString::create();
|
||||
auto col = block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
|
||||
@ -803,17 +800,20 @@ public:
|
||||
|
||||
if (arguments.size() == 2) {
|
||||
const auto& col = *block.get_by_position(arguments[1]).column;
|
||||
// the 2nd arg is const. checked in fe.
|
||||
if (col.get_int(0) < 0) [[unlikely]] {
|
||||
return Status::InvalidArgument(
|
||||
"function {} only accept non-negative input for 2nd argument but got {}",
|
||||
name, col.get_int(0));
|
||||
}
|
||||
n = col.get_int(0);
|
||||
} else if (arguments.size() > 2) {
|
||||
return Status::InvalidArgument(
|
||||
fmt::format("too many arguments for function {}", get_name()));
|
||||
}
|
||||
|
||||
if (n == -1) {
|
||||
if (n == -1) { // no 2nd arg, just mask all
|
||||
FunctionMask::vector_mask(source_column, *res, FunctionMask::DEFAULT_UPPER_MASK,
|
||||
FunctionMask::DEFAULT_LOWER_MASK,
|
||||
FunctionMask::DEFAULT_NUMBER_MASK);
|
||||
} else if (n >= 0) {
|
||||
} else { // n >= 0
|
||||
vector(source_column, n, *res);
|
||||
}
|
||||
|
||||
@ -2901,19 +2901,18 @@ public:
|
||||
|
||||
ColumnPtr argument_column =
|
||||
block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
|
||||
const auto* length_col = check_and_get_column<ColumnInt32>(argument_column.get());
|
||||
|
||||
if (!length_col) {
|
||||
return Status::InternalError("Not supported input argument type");
|
||||
}
|
||||
const auto* length_col = assert_cast<const ColumnInt32*>(argument_column.get());
|
||||
|
||||
std::vector<uint8_t> random_bytes;
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
|
||||
for (size_t i = 0; i < input_rows_count; ++i) {
|
||||
UInt64 length = length_col->get64(i);
|
||||
random_bytes.resize(length);
|
||||
if (length_col->get_element(i) < 0) [[unlikely]] {
|
||||
return Status::InvalidArgument("argument {} of function {} at row {} was invalid.",
|
||||
length_col->get_element(i), name, i);
|
||||
}
|
||||
random_bytes.resize(length_col->get_element(i));
|
||||
|
||||
std::uniform_int_distribution<uint8_t> distribution(0, 255);
|
||||
for (auto& byte : random_bytes) {
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.trees.expressions.functions.scalar;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
@ -65,6 +66,13 @@ public class MaskFirstN extends ScalarFunction implements ExplicitlyCastableSign
|
||||
return new MaskFirstN(children.get(0), children.get(1));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkLegalityAfterRewrite() {
|
||||
if (arity() == 2 && !child(1).isLiteral()) {
|
||||
throw new AnalysisException("mask_first_n must accept literal for 2nd argument");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionSignature> getSignatures() {
|
||||
return SIGNATURES;
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.trees.expressions.functions.scalar;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
@ -65,6 +66,13 @@ public class MaskLastN extends ScalarFunction implements ExplicitlyCastableSigna
|
||||
return new MaskLastN(children.get(0), children.get(1));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkLegalityAfterRewrite() {
|
||||
if (arity() == 2 && !child(1).isLiteral()) {
|
||||
throw new AnalysisException("mask_last_n must accept literal for 2nd argument");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionSignature> getSignatures() {
|
||||
return SIGNATURES;
|
||||
|
||||
@ -75,4 +75,21 @@ suite("test_mask_function") {
|
||||
qt_select_digital_masking """
|
||||
select digital_masking(13812345678);
|
||||
"""
|
||||
|
||||
test {
|
||||
sql """ select mask_last_n("12345", -100); """
|
||||
exception "function mask_last_n only accept non-negative input for 2nd argument but got -100"
|
||||
}
|
||||
test {
|
||||
sql """ select mask_first_n("12345", -100); """
|
||||
exception "function mask_first_n only accept non-negative input for 2nd argument but got -100"
|
||||
}
|
||||
test {
|
||||
sql """ select mask_last_n("12345", id) from table_mask_test; """
|
||||
exception "mask_last_n must accept literal for 2nd argument"
|
||||
}
|
||||
test {
|
||||
sql """ select mask_first_n("12345", id) from table_mask_test; """
|
||||
exception "mask_first_n must accept literal for 2nd argument"
|
||||
}
|
||||
}
|
||||
|
||||
@ -101,4 +101,8 @@ suite("nereids_scalar_fn_R") {
|
||||
qt_sql_rtrim_String_String_notnull "select rtrim(kstr, '1') from fn_test_not_nullable order by kstr"
|
||||
sql "SELECT random_bytes(7);"
|
||||
qt_sql_random_bytes "SELECT random_bytes(null);"
|
||||
test {
|
||||
sql " select random_bytes(-1); "
|
||||
exception "argument -1 of function random_bytes at row 0 was invalid"
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user