add defensive check for ObExprWeightString

This commit is contained in:
AntiTopQuark
2024-04-10 10:15:57 +00:00
committed by ob-robot
parent 6a13cb1f82
commit ce6ee9cf55

View File

@ -43,9 +43,10 @@ int ObExprWeightString::calc_result_typeN(ObExprResType &type,
ObExprTypeCtx &type_ctx) const ObExprTypeCtx &type_ctx) const
{ {
int ret = OB_SUCCESS; int ret = OB_SUCCESS;
UNUSED(param_num); CK (5 == param_num);
CK (OB_NOT_NULL(type_ctx.get_session())); CK (OB_NOT_NULL(type_ctx.get_session()));
if (NOT_ROW_DIMENSION != row_dimension_ || ObMaxType == types_stack[0].get_type()) { if (OB_FAIL(ret)) {
} else if (NOT_ROW_DIMENSION != row_dimension_ || ObMaxType == types_stack[0].get_type()) {
ret = OB_ERR_INVALID_TYPE_FOR_OP; ret = OB_ERR_INVALID_TYPE_FOR_OP;
} else { } else {
uint64_t max_length = OB_MAX_VARBINARY_LENGTH; // The maximum length of the result of WEIGHT_STRING() uint64_t max_length = OB_MAX_VARBINARY_LENGTH; // The maximum length of the result of WEIGHT_STRING()
@ -70,25 +71,28 @@ int ObExprWeightString::calc_result_typeN(ObExprResType &type,
} }
ObCollationType collation_type = types_stack[0].get_collation_type(); ObCollationType collation_type = types_stack[0].get_collation_type();
const ObCharsetInfo *cs = ObCharset::get_charset(collation_type); const ObCharsetInfo *cs = ObCharset::get_charset(collation_type);
if (types_stack[0].get_type() == ObDateTimeType || CK (OB_NOT_NULL(cs));
types_stack[0].get_type() == ObTimestampType || if (OB_SUCC(ret)) {
types_stack[0].get_type() == ObDateType || if (types_stack[0].get_type() == ObDateTimeType ||
types_stack[0].get_type() == ObTimeType ) { types_stack[0].get_type() == ObTimestampType ||
// For types such as date, time, etc., the max_lenght is the length of the type entered types_stack[0].get_type() == ObDateType ||
max_length = types_stack[0].get_length(); types_stack[0].get_type() == ObTimeType ) {
} else if (result_length > 0) { // For types such as date, time, etc., the max_lenght is the length of the type entered
max_length = result_length; max_length = types_stack[0].get_length();
} else if (as_binary) { } else if (result_length > 0) {
// In the case of as_binary, the max_length with nweight as the output result max_length = result_length;
max_length = nweight; } else if (as_binary) {
} else { // In the case of as_binary, the max_length with nweight as the output result
// If the input is others, use cs->mbmaxlen to calculate the max_length max_length = nweight;
max_length = cs->mbmaxlen * MAX(nweight, types_stack[0].get_length()*cs->mbmaxlen); } else {
// If the input is others, use cs->mbmaxlen to calculate the max_length
max_length = cs->mbmaxlen * MAX(nweight, types_stack[0].get_length()*cs->mbmaxlen);
}
type.set_varchar();
type.set_collation_type(CS_TYPE_BINARY);
type.set_collation_level(coll_level);
type.set_length(max_length);
} }
type.set_varchar();
type.set_collation_type(CS_TYPE_BINARY);
type.set_collation_level(coll_level);
type.set_length(max_length);
} }
return ret; return ret;
} }
@ -101,7 +105,11 @@ int ObExprWeightString::eval_weight_string(const ObExpr &expr, ObEvalCtx &ctx, O
ObDatum *nweights_arg = NULL; ObDatum *nweights_arg = NULL;
ObDatum *flags_arg = NULL; ObDatum *flags_arg = NULL;
ObDatum *as_binary_arg = NULL; ObDatum *as_binary_arg = NULL;
if (OB_FAIL(expr.args_[0]->eval(ctx, arg)) || if (OB_ISNULL(expr.args_[0]) || OB_ISNULL(expr.args_[1]) || OB_ISNULL(expr.args_[2]) ||
OB_ISNULL(expr.args_[3]) || OB_ISNULL(expr.args_[4])) {
ret = OB_INVALID_ARGUMENT;
LOG_WARN("invalid argument", K(ret));
} else if (OB_FAIL(expr.args_[0]->eval(ctx, arg)) ||
OB_FAIL(expr.args_[1]->eval(ctx, result_length_arg)) || OB_FAIL(expr.args_[1]->eval(ctx, result_length_arg)) ||
OB_FAIL(expr.args_[2]->eval(ctx, nweights_arg)) || OB_FAIL(expr.args_[2]->eval(ctx, nweights_arg)) ||
OB_FAIL(expr.args_[3]->eval(ctx, flags_arg)) || OB_FAIL(expr.args_[3]->eval(ctx, flags_arg)) ||
@ -137,64 +145,70 @@ int ObExprWeightString::eval_weight_string(const ObExpr &expr, ObEvalCtx &ctx, O
} else { } else {
LOG_WARN("Failed to get max allow packet size", K(ret)); LOG_WARN("Failed to get max allow packet size", K(ret));
} }
}
// Get the character set and collation information of the input string
ObCollationType collation_type = CS_TYPE_INVALID;
if (as_binary) {
collation_type = CS_TYPE_BINARY;
} else { } else {
collation_type = expr.args_[0]->datum_meta_.cs_type_;
}
const ObCharsetInfo *cs = ObCharset::get_charset(collation_type); // Get the character set and collation information of the input string
flags = ob_strxfrm_flag_normalize(flags, cs->levels_for_order); ObCollationType collation_type = CS_TYPE_INVALID;
// calc the length of result if (as_binary) {
size_t frm_length = 0; collation_type = CS_TYPE_BINARY;
size_t tmp_length = 0;
if (result_length > 0) {
tmp_length = result_length;
} else {
tmp_length = cs->coll->strnxfrmlen(cs, cs->mbmaxlen*MAX(str.length() , nweights));
}
if (tmp_length >= max_allowed_packet) {
// The return result exceeds the maximum limit and returns NULL.
res_datum.set_null();
} else {
int used_nweights = nweights;
size_t input_length = str.length();
if (used_nweights) {
//truncate input string
input_length = std::min(input_length, cs->cset->charpos(cs, str.ptr(), str.ptr() + str.length(), nweights));
} else { } else {
//calc char length collation_type = expr.args_[0]->datum_meta_.cs_type_;
used_nweights = cs->cset->numchars(cs, str.ptr(), str.ptr() + str.length());
} }
bool is_valid_unicode_tmp = 1; const ObCharsetInfo *cs = ObCharset::get_charset(collation_type);
char *out_buf = expr.get_str_res_mem(ctx, tmp_length); CK (OB_NOT_NULL(cs));
// For the case where the input is an empty string but the nweight is not 0, if (OB_SUCC(ret)) {
// the weight_string function will call strnxfrm() to padding the result flags = ob_strxfrm_flag_normalize(flags, cs->levels_for_order);
// eg: // calc the length of result
// mysql> select HEX(WEIGHT_STRING('' as char(3))); size_t frm_length = 0;
// +-----------------------------------+ size_t tmp_length = 0;
// | HEX(WEIGHT_STRING('' as char(3))) | if (result_length > 0) {
// +-----------------------------------+ tmp_length = result_length;
// | 002000200020 | } else {
// +-----------------------------------+ tmp_length = cs->coll->strnxfrmlen(cs, cs->mbmaxlen*MAX(str.length() , nweights));
// However, the strnxfrm requires that the input cannot be a null ptr, }
// so an empty string is set as the input. if (tmp_length >= max_allowed_packet) {
const char* tmp_empty_str = ""; // The return result exceeds the maximum limit and returns NULL.
if (OB_ISNULL(out_buf)) { res_datum.set_null();
ret = OB_ALLOCATE_MEMORY_FAILED; } else {
LOG_WARN("failed to alloc output buf",K(ret)); int used_nweights = nweights;
} else { size_t input_length = str.length();
frm_length = cs->coll->strnxfrm(cs, if (used_nweights) {
reinterpret_cast<uchar *>(out_buf), //truncate input string
tmp_length, input_length = std::min(input_length, cs->cset->charpos(cs, str.ptr(), str.ptr() + str.length(), nweights));
used_nweights, } else {
str.ptr() != NULL? reinterpret_cast<const uchar *>(str.ptr()) : reinterpret_cast<const uchar *>(tmp_empty_str), //calc char length
input_length, used_nweights = cs->cset->numchars(cs, str.ptr(), str.ptr() + str.length());
flags, }
&is_valid_unicode_tmp); bool is_valid_unicode_tmp = 1;
res_datum.set_string(out_buf,frm_length); char *out_buf = expr.get_str_res_mem(ctx, tmp_length);
// For the case where the input is an empty string but the nweight is not 0,
// the weight_string function will call strnxfrm() to padding the result
// eg:
// mysql> select HEX(WEIGHT_STRING('' as char(3)));
// +-----------------------------------+
// | HEX(WEIGHT_STRING('' as char(3))) |
// +-----------------------------------+
// | 002000200020 |
// +-----------------------------------+
// However, the strnxfrm requires that the input cannot be a null ptr,
// so an empty string is set as the input.
const char* tmp_empty_str = "";
if (OB_ISNULL(out_buf)) {
ret = OB_ALLOCATE_MEMORY_FAILED;
LOG_WARN("failed to alloc output buf",K(ret));
} else {
frm_length = cs->coll->strnxfrm(cs,
reinterpret_cast<uchar *>(out_buf),
tmp_length,
used_nweights,
str.ptr() != NULL? reinterpret_cast<const uchar *>(str.ptr()) : reinterpret_cast<const uchar *>(tmp_empty_str),
input_length,
flags,
&is_valid_unicode_tmp);
res_datum.set_string(out_buf,frm_length);
}
}
} }
} }
} }