[CP] [CP][to #44746355]fix size overflow when reconstruct sql

This commit is contained in:
seuwebber 2022-12-29 10:45:38 +00:00 committed by ob-robot
parent 46b0f9b176
commit a2610ab94c
2 changed files with 125 additions and 74 deletions

View File

@ -1968,83 +1968,127 @@ int ObSQLUtils::reconstruct_sql(ObIAllocator &allocator, const ObStmt *stmt, ObS
ret = OB_ERR_UNEXPECTED;
LOG_WARN("Unexpected stmt type", K(stmt->get_stmt_type()), K(stmt->get_query_ctx()->get_sql_stmt()), K(ret));
} else {
//First try 64K buf on the stack, if it fails, then try 128K.
//If it still fails, allocate 256K from the heap. If it continues to fail, expand twice each time.
SMART_VAR(char[OB_MAX_SQL_LENGTH], buf) {
int64_t buf_len = OB_MAX_SQL_LENGTH;
int64_t pos = 0;
switch (stmt->get_stmt_type()) {
case stmt::T_SELECT: {
bool is_set_subquery = false;
ObSelectStmtPrinter printer(buf, buf_len, &pos, static_cast<const ObSelectStmt*>(stmt),
print_params, NULL, is_set_subquery);
if (OB_FAIL(printer.do_print())) {
LOG_WARN("fail to print select stmt", K(ret));
} else if (OB_FAIL(ob_write_string(allocator, ObString(pos, buf), sql))) {
LOG_WARN("fail to deep copy select stmt string", K(ret));
} else { /*do nothing*/ }
}
break;
case stmt::T_INSERT_ALL: {
ObInsertAllStmtPrinter printer(buf, buf_len, &pos,
static_cast<const ObInsertAllStmt*>(stmt), print_params);
if (OB_FAIL(printer.do_print())) {
LOG_WARN("fail to print insert stmt", K(ret));
} else if (OB_FAIL(ob_write_string(allocator, ObString(pos, buf), sql))) {
LOG_WARN("fail to deep copy insert stmt string", K(ret));
} else { /*do nothing*/ }
}
break;
case stmt::T_INSERT: {
ObInsertStmtPrinter printer(buf, buf_len, &pos,
static_cast<const ObInsertStmt*>(stmt), print_params);
if (OB_FAIL(printer.do_print())) {
LOG_WARN("fail to print insert stmt", K(ret));
} else if (OB_FAIL(ob_write_string(allocator, ObString(pos, buf), sql))) {
LOG_WARN("fail to deep copy insert stmt string", K(ret));
} else { /*do nothing*/ }
}
break;
case stmt::T_REPLACE: {
}
break;
case stmt::T_DELETE: {
ObDeleteStmtPrinter printer(buf, buf_len, &pos,
static_cast<const ObDeleteStmt*>(stmt), print_params);
if (OB_FAIL(printer.do_print())) {
LOG_WARN("fail to print delete stmt", K(ret));
} else if (OB_FAIL(ob_write_string(allocator, ObString(pos, buf), sql))) {
LOG_WARN("fail to deep copy delete stmt string", K(ret));
} else { /*do nothing*/ }
}
break;
case stmt::T_UPDATE: {
ObUpdateStmtPrinter printer(buf, buf_len, &pos,
static_cast<const ObUpdateStmt*>(stmt), print_params);
if (OB_FAIL(printer.do_print())) {
LOG_WARN("fail to print update stmt", K(ret));
} else if (OB_FAIL(ob_write_string(allocator, ObString(pos, buf), sql))) {
LOG_WARN("fail to deep copy update stmt string", K(ret));
} else { /*do nothing*/ }
}
break;
case stmt::T_MERGE: {
ObMergeStmtPrinter printer(buf, buf_len, &pos,
static_cast<const ObMergeStmt*>(stmt), print_params);
if (OB_FAIL(printer.do_print())) {
LOG_WARN("failed to print merge stmt", K(ret));
} else if (OB_FAIL(ob_write_string(allocator, ObString(pos, buf), sql))) {
LOG_WARN("failed to deep copy merge stmt string", K(ret));
}
}
break;
default: {
ret = OB_NOT_SUPPORTED;
LOG_WARN("Invalid stmt type", K(stmt->get_stmt_type()), K(stmt->get_query_ctx()->get_sql_stmt()), K(ret));
LOG_USER_ERROR(OB_NOT_SUPPORTED, "stmt type");
}
break;
if (OB_FAIL(print_sql(allocator, buf, sizeof(buf), stmt, sql, print_params))) {
LOG_WARN("failed to print sql", K(sizeof(buf)), K(ret));
}
}
if (OB_SIZE_OVERFLOW == ret) {
ret = OB_SUCCESS;
SMART_VAR(char[OB_MAX_SQL_LENGTH * 2], buf) {
if (OB_FAIL(print_sql(allocator, buf, sizeof(buf), stmt, sql, print_params))) {
LOG_WARN("failed to print sql", K(sizeof(buf)), K(ret));
}
}
}
if (OB_SIZE_OVERFLOW == ret) {
bool is_succ = false;
ret = OB_SUCCESS;
for (int64_t i = 4; OB_SUCC(ret) && !is_succ && i <= 1024; i = i * 2) {
ObArenaAllocator alloc;
const int64_t length = OB_MAX_SQL_LENGTH * i;
char *buf = NULL;
if (OB_ISNULL(buf = static_cast<char*>(alloc.alloc(length)))) {
ret = OB_ALLOCATE_MEMORY_FAILED;
LOG_WARN("failed to alloc memory for set sql", K(ret), K(length));
} else if (OB_FAIL(print_sql(allocator, buf, length, stmt, sql, print_params))) {
LOG_WARN("failed to print sql", K(length), K(i), K(ret));
}
if (OB_SUCC(ret)) {
is_succ = true;
} else if (OB_SIZE_OVERFLOW == ret) {
ret = OB_SUCCESS;
}
}
}
}
return ret;
}
int ObSQLUtils::print_sql(ObIAllocator &allocator,
char *buf,
int64_t buf_len,
const ObStmt *stmt,
ObString &sql,
ObObjPrintParams print_params)
{
int ret = OB_SUCCESS;
MEMSET(buf, 0, buf_len);
int64_t pos = 0;
switch (stmt->get_stmt_type()) {
case stmt::T_SELECT: {
bool is_set_subquery = false;
ObSelectStmtPrinter printer(buf, buf_len, &pos, static_cast<const ObSelectStmt*>(stmt),
print_params, NULL, is_set_subquery);
if (OB_FAIL(printer.do_print())) {
LOG_WARN("fail to print select stmt", K(ret));
} else if (OB_FAIL(ob_write_string(allocator, ObString(pos, buf), sql))) {
LOG_WARN("fail to deep copy select stmt string", K(ret));
} else { /*do nothing*/ }
}
break;
case stmt::T_INSERT_ALL: {
ObInsertAllStmtPrinter printer(buf, buf_len, &pos,
static_cast<const ObInsertAllStmt*>(stmt), print_params);
if (OB_FAIL(printer.do_print())) {
LOG_WARN("fail to print insert stmt", K(ret));
} else if (OB_FAIL(ob_write_string(allocator, ObString(pos, buf), sql))) {
LOG_WARN("fail to deep copy insert stmt string", K(ret));
} else { /*do nothing*/ }
}
break;
case stmt::T_INSERT: {
ObInsertStmtPrinter printer(buf, buf_len, &pos,
static_cast<const ObInsertStmt*>(stmt), print_params);
if (OB_FAIL(printer.do_print())) {
LOG_WARN("fail to print insert stmt", K(ret));
} else if (OB_FAIL(ob_write_string(allocator, ObString(pos, buf), sql))) {
LOG_WARN("fail to deep copy insert stmt string", K(ret));
} else { /*do nothing*/ }
}
break;
case stmt::T_REPLACE: {
}
break;
case stmt::T_DELETE: {
ObDeleteStmtPrinter printer(buf, buf_len, &pos,
static_cast<const ObDeleteStmt*>(stmt), print_params);
if (OB_FAIL(printer.do_print())) {
LOG_WARN("fail to print delete stmt", K(ret));
} else if (OB_FAIL(ob_write_string(allocator, ObString(pos, buf), sql))) {
LOG_WARN("fail to deep copy delete stmt string", K(ret));
} else { /*do nothing*/ }
}
break;
case stmt::T_UPDATE: {
ObUpdateStmtPrinter printer(buf, buf_len, &pos,
static_cast<const ObUpdateStmt*>(stmt), print_params);
if (OB_FAIL(printer.do_print())) {
LOG_WARN("fail to print update stmt", K(ret));
} else if (OB_FAIL(ob_write_string(allocator, ObString(pos, buf), sql))) {
LOG_WARN("fail to deep copy update stmt string", K(ret));
} else { /*do nothing*/ }
}
break;
case stmt::T_MERGE: {
ObMergeStmtPrinter printer(buf, buf_len, &pos,
static_cast<const ObMergeStmt*>(stmt), print_params);
if (OB_FAIL(printer.do_print())) {
LOG_WARN("failed to print merge stmt", K(ret));
} else if (OB_FAIL(ob_write_string(allocator, ObString(pos, buf), sql))) {
LOG_WARN("failed to deep copy merge stmt string", K(ret));
}
}
break;
default: {
ret = OB_NOT_SUPPORTED;
LOG_WARN("Invalid stmt type", K(stmt->get_stmt_type()), K(stmt->get_query_ctx()->get_sql_stmt()), K(ret));
LOG_USER_ERROR(OB_NOT_SUPPORTED, "stmt type");
}
break;
}
return ret;
}

View File

@ -364,6 +364,13 @@ public:
static int reconstruct_sql(ObIAllocator &allocator, const ObStmt *stmt, ObString &sql,
ObObjPrintParams print_params = ObObjPrintParams());
static int print_sql(ObIAllocator &allocator,
char *buf,
int64_t buf_len,
const ObStmt *stmt,
ObString &sql,
ObObjPrintParams print_params);
static int wrap_expr_ctx(const stmt::StmtType &stmt_type,
ObExecContext &exec_ctx,