[branch-2.1](functions) fix be crash for function random_bytes and mark_first/last_n (#36003)

pick #35884
This commit is contained in:
zclllyybb
2024-06-07 10:30:41 +08:00
committed by GitHub
parent c794ea18c8
commit f751ca4e04
5 changed files with 52 additions and 16 deletions

View File

@ -792,10 +792,7 @@ public:
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override {
DCHECK_GE(arguments.size(), 1);
DCHECK_LE(arguments.size(), 2);
int n = -1;
int n = -1; // means unassigned
auto res = ColumnString::create();
auto col = block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
@ -803,17 +800,20 @@ public:
if (arguments.size() == 2) {
const auto& col = *block.get_by_position(arguments[1]).column;
// the 2nd arg is const. checked in fe.
if (col.get_int(0) < 0) [[unlikely]] {
return Status::InvalidArgument(
"function {} only accept non-negative input for 2nd argument but got {}",
name, col.get_int(0));
}
n = col.get_int(0);
} else if (arguments.size() > 2) {
return Status::InvalidArgument(
fmt::format("too many arguments for function {}", get_name()));
}
if (n == -1) {
if (n == -1) { // no 2nd arg, just mask all
FunctionMask::vector_mask(source_column, *res, FunctionMask::DEFAULT_UPPER_MASK,
FunctionMask::DEFAULT_LOWER_MASK,
FunctionMask::DEFAULT_NUMBER_MASK);
} else if (n >= 0) {
} else { // n >= 0
vector(source_column, n, *res);
}
@ -2901,19 +2901,18 @@ public:
ColumnPtr argument_column =
block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
const auto* length_col = check_and_get_column<ColumnInt32>(argument_column.get());
if (!length_col) {
return Status::InternalError("Not supported input argument type");
}
const auto* length_col = assert_cast<const ColumnInt32*>(argument_column.get());
std::vector<uint8_t> random_bytes;
std::random_device rd;
std::mt19937 gen(rd());
for (size_t i = 0; i < input_rows_count; ++i) {
UInt64 length = length_col->get64(i);
random_bytes.resize(length);
if (length_col->get_element(i) < 0) [[unlikely]] {
return Status::InvalidArgument("argument {} of function {} at row {} was invalid.",
length_col->get_element(i), name, i);
}
random_bytes.resize(length_col->get_element(i));
std::uniform_int_distribution<uint8_t> distribution(0, 255);
for (auto& byte : random_bytes) {

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
@ -65,6 +66,13 @@ public class MaskFirstN extends ScalarFunction implements ExplicitlyCastableSign
return new MaskFirstN(children.get(0), children.get(1));
}
@Override
public void checkLegalityAfterRewrite() {
if (arity() == 2 && !child(1).isLiteral()) {
throw new AnalysisException("mask_first_n must accept literal for 2nd argument");
}
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
@ -65,6 +66,13 @@ public class MaskLastN extends ScalarFunction implements ExplicitlyCastableSigna
return new MaskLastN(children.get(0), children.get(1));
}
@Override
public void checkLegalityAfterRewrite() {
if (arity() == 2 && !child(1).isLiteral()) {
throw new AnalysisException("mask_last_n must accept literal for 2nd argument");
}
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;

View File

@ -75,4 +75,21 @@ suite("test_mask_function") {
qt_select_digital_masking """
select digital_masking(13812345678);
"""
test {
sql """ select mask_last_n("12345", -100); """
exception "function mask_last_n only accept non-negative input for 2nd argument but got -100"
}
test {
sql """ select mask_first_n("12345", -100); """
exception "function mask_first_n only accept non-negative input for 2nd argument but got -100"
}
test {
sql """ select mask_last_n("12345", id) from table_mask_test; """
exception "mask_last_n must accept literal for 2nd argument"
}
test {
sql """ select mask_first_n("12345", id) from table_mask_test; """
exception "mask_first_n must accept literal for 2nd argument"
}
}

View File

@ -101,4 +101,8 @@ suite("nereids_scalar_fn_R") {
qt_sql_rtrim_String_String_notnull "select rtrim(kstr, '1') from fn_test_not_nullable order by kstr"
sql "SELECT random_bytes(7);"
qt_sql_random_bytes "SELECT random_bytes(null);"
test {
sql " select random_bytes(-1); "
exception "argument -1 of function random_bytes at row 0 was invalid"
}
}