From ebc2cb5420eedbd923b363941b134ce20dd3fc8b Mon Sep 17 00:00:00 2001 From: Alexey Kopytov Date: Sat, 14 Jan 2017 20:12:27 +0300 Subject: [PATCH] Export prepared statements API to Lua. Fix function name in an error message. Fix prepared statements for string parameters. --- sysbench/drivers/mysql/drv_mysql.c | 7 +- sysbench/lua/internal/sysbench.sql.lua | 134 ++++++++++++++++++++++--- tests/t/api_sql.t | 30 ++++-- 3 files changed, 147 insertions(+), 24 deletions(-) diff --git a/sysbench/drivers/mysql/drv_mysql.c b/sysbench/drivers/mysql/drv_mysql.c index b8c3842..42ff085 100644 --- a/sysbench/drivers/mysql/drv_mysql.c +++ b/sysbench/drivers/mysql/drv_mysql.c @@ -281,7 +281,10 @@ int mysql_drv_thread_init(int thread_id) { (void) thread_id; /* unused */ - return mysql_thread_init() != 0; + const my_bool rc = mysql_thread_init(); + DEBUG("mysql_thread_init() = %d", (int) rc); + + return rc != 0; } /* Thread-local driver deinitialization */ @@ -290,6 +293,7 @@ int mysql_drv_thread_done(int thread_id) { (void) thread_id; /* unused */ + DEBUG("mysql_thread_end(%s)", ""); mysql_thread_end(); return 0; @@ -805,6 +809,7 @@ db_error_t mysql_drv_execute(db_stmt_t *stmt, db_result_t *rs) &rs->stat_type); } + rs->stat_type = stmt->stat_type; rs->nrows = (uint32_t) mysql_stmt_num_rows(stmt->ptr); DEBUG("mysql_stmt_num_rows(%p) = %u", rs->statement->ptr, (unsigned) (rs->nrows)); diff --git a/sysbench/lua/internal/sysbench.sql.lua b/sysbench/lua/internal/sysbench.sql.lua index 945c498..650ff3f 100644 --- a/sysbench/lua/internal/sysbench.sql.lua +++ b/sysbench/lua/internal/sysbench.sql.lua @@ -78,14 +78,17 @@ sql_result *db_query(sql_connection *con, const char *query, size_t len); sql_statement *db_prepare(sql_connection *con, const char *query, size_t len); int db_bind_param(sql_statement *stmt, sql_bind *params, size_t len); int db_bind_result(sql_statement *stmt, sql_bind *results, size_t len); -sql_result db_execute(sql_statement *stmt); +sql_result *db_execute(sql_statement *stmt); int db_close(sql_statement *stmt); + +int db_free_results(sql_result *); ]] local sql_driver = ffi.typeof('sql_driver *') local sql_connection = ffi.typeof('sql_connection *') local sql_statement = ffi.typeof('sql_statement *') local sql_bind = ffi.typeof('sql_bind'); +local sql_result = ffi.typeof('sql_result'); sysbench.sql.type = { @@ -104,10 +107,10 @@ sysbench.sql.type = VARCHAR = ffi.C.SQL_TYPE_VARCHAR } -local function check_type(type, var, func) - if var == nil or not ffi.istype(type, var) then +local function check_type(vtype, var, func) + if var == nil or not ffi.istype(vtype, var) then error(string.format("bad argument '%s' to %s() where a '%s' was expected", - var, func, type)) + var, func, vtype)) end end @@ -158,24 +161,111 @@ end function sysbench.sql.prepare(con, query) check_type(sql_connection, con, 'sysbench.sql.prepare') - return ffi.C.db_prepare(con, query, #query) + local stmt = ffi.C.db_prepare(con, query, #query) + return stmt end -function sysbench.sql.bind_param(stmt, params) - local len = #params - local sql_params = {} +function sysbench.sql.bind_create(stmt, btype, maxlen) + local sql_type = sysbench.sql.type + local buf, buflen, datalen, isnull + + check_type(sql_statement, stmt, 'sysbench.sql.bind_create') + + if btype == sql_type.TINYINT or + btype == sql_type.SMALLINT or + btype == sql_type.INT or + btype == sql_type.BIGINT + then + btype = sql_type.BIGINT + buf = ffi.new('int64_t[1]') + buflen = 8 + elseif btype == sql_type.FLOAT or + btype == sql_type.DOUBLE + then + btype = sql_type.DOUBLE + buf = ffi.new('double[1]') + buflen = 8 + elseif btype == sql_type.CHAR or + btype == sql_type.VARCHAR + then + btype = sql_type.VARCHAR + buf = ffi.new('char[?]', maxlen) + buflen = maxlen + else + error("Unsupported argument type: " .. btype) + end + + datalen = ffi.new('unsigned long[1]') + isnull = ffi.new('char[1]') + + return ffi.new(sql_bind, btype, buf, datalen, buflen, isnull) +end + +function sysbench.sql.bind_set(bind, value) + local sql_type = sysbench.sql.type + local btype = bind.type + + check_type(sql_bind, bind, 'sysbench.sql.bind_set') + + if (value == nil) then + bind.is_null[0] = true + return + end + + bind.is_null[0] = false + + if btype == sql_type.TINYINT or + btype == sql_type.SMALLINT or + btype == sql_type.INT or + btype == sql_type.BIGINT + then + ffi.copy(bind.buffer, ffi.new('int64_t[1]', value), 8) + elseif btype == sql_type.FLOAT or + btype == sql_type.DOUBLE + then + ffi.copy(bind.buffer, ffi.new('double[1]', value), 8) + elseif btype == sql_type.CHAR or + btype == sql_type.VARCHAR + then + local len = #value + len = bind.max_len < len and bind.max_len or len + ffi.copy(bind.buffer, value, len) + bind.data_len[0] = len + else + error("Unsupported argument type: " .. btype) + end +end + +function sysbench.sql.bind_destroy(bind) + check_type(sql_bind, bind, 'sysbench.sql.bind_destroy') +end + +function sysbench.sql.bind_param(stmt, ...) + local len = #{...} local i check_type(sql_statement, stmt, 'sysbench.sql.bind_param') - error("NYI") - return ffi.C.db_bind_param(stmt, - ffi.new("sql_bind[" .. len .."]", - sql_params), + ffi.new("sql_bind[?]", len, {...}), len) end +function sysbench.sql.execute(stmt) + check_type(sql_statement, stmt, 'sysbench.sql.execute') + return ffi.C.db_execute(stmt) +end + +function sysbench.sql.close(stmt) + check_type(sql_statement, stmt, 'sysbench.sql.close') + return ffi.C.db_close(stmt) +end + +function sysbench.sql.free_results(result) + check_type(sql_result, result, 'sysbench.sql.free_results') + return ffi.C.db_free_results(result) +end + -- sql_driver metatable local driver_mt = { __index = { @@ -205,18 +295,36 @@ ffi.metatype("struct db_conn", connection_mt) local statement_mt = { __index = { bind_param = sysbench.sql.bind_param, + bind_create = sysbench.sql.bind_create, + execute = sysbench.sql.execute, + close = sysbench.sql.close }, __tostring = function() return '' end, - __gc = ffi.C.db_close, + __gc = sysbench.sql.close, } ffi.metatype("struct db_stmt", statement_mt) -- sql_bind metatable local bind_mt = { + __index = { + set = sysbench.sql.bind_set + }, __tostring = function() return '' end, + __gc = sysbench.sql.bind_destroy } ffi.metatype("sql_bind", bind_mt) +-- sql_results metatable +local result_mt = { + __index = { + free = sysbench.sql.free_results + }, + __tostring = function() return '' end, + __gc = sysbench.sql.free_results +} +ffi.metatype("sql_result", result_mt) + + -- error codes sysbench.sql.ERROR_NONE = ffi.C.DB_ERROR_NONE sysbench.sql.ERROR_IGNORABLE = ffi.C.DB_ERROR_IGNORABLE diff --git a/tests/t/api_sql.t b/tests/t/api_sql.t index 27698d5..4bd5857 100644 --- a/tests/t/api_sql.t +++ b/tests/t/api_sql.t @@ -36,13 +36,20 @@ SQL Lua API tests > con:query("SELECT 1") > > con = drv:connect() - > con:query[[UPDATE t SET a = a + 100]] > - > -- local stmt = con:prepare("UPDATE t SET a = a + ?") - > -- stmt:bind_param({{sysbench.sql.type.INT}}) - > -- rs = stmt:execute() - > -- rs:free_results() - > -- stmt:close() + > con:query("ALTER TABLE t ADD COLUMN b CHAR(10)") + > + > local stmt = con:prepare("UPDATE t SET a = a + ?, b = ?") + > local a = stmt:bind_create(sysbench.sql.type.INT) + > local b = stmt:bind_create(sysbench.sql.type.CHAR, 10) + > stmt:bind_param(a, b) + > a:set(100) + > rs = stmt:execute() + > a:set(200) + > b:set("01234567890") + > rs = stmt:execute() + > rs:free() + > stmt:close() > end > EOF @@ -69,15 +76,18 @@ SQL Lua API tests *************************** 1. row *************************** t CREATE TABLE `t` ( - `a` int(11) DEFAULT NULL + `a` int(11) DEFAULT NULL, + `b` char(10)* DEFAULT NULL (glob) ) * (glob) $ mysql -uroot sbtest -Nse "SELECT COUNT(DISTINCT a) FROM t" 100 - $ mysql -uroot sbtest -Nse "SELECT MIN(a), MAX(a) FROM t\G" + $ mysql -uroot sbtest -Nse "SELECT MIN(a), MAX(a), MIN(b), MAX(b) FROM t\G" *************************** 1. row *************************** - 101 - 200 + 301 + 400 + 0123456789 + 0123456789 $ mysql -uroot sbtest -Nse "DROP TABLE t"