Export prepared statements API to Lua.

Fix function name in an error message.

Fix prepared statements for string parameters.
This commit is contained in:
Alexey Kopytov
2017-01-14 20:12:27 +03:00
parent 1d434809ec
commit ebc2cb5420
3 changed files with 147 additions and 24 deletions

View File

@ -281,7 +281,10 @@ int mysql_drv_thread_init(int thread_id)
{
(void) thread_id; /* unused */
return mysql_thread_init() != 0;
const my_bool rc = mysql_thread_init();
DEBUG("mysql_thread_init() = %d", (int) rc);
return rc != 0;
}
/* Thread-local driver deinitialization */
@ -290,6 +293,7 @@ int mysql_drv_thread_done(int thread_id)
{
(void) thread_id; /* unused */
DEBUG("mysql_thread_end(%s)", "");
mysql_thread_end();
return 0;
@ -805,6 +809,7 @@ db_error_t mysql_drv_execute(db_stmt_t *stmt, db_result_t *rs)
&rs->stat_type);
}
rs->stat_type = stmt->stat_type;
rs->nrows = (uint32_t) mysql_stmt_num_rows(stmt->ptr);
DEBUG("mysql_stmt_num_rows(%p) = %u", rs->statement->ptr,
(unsigned) (rs->nrows));

View File

@ -78,14 +78,17 @@ sql_result *db_query(sql_connection *con, const char *query, size_t len);
sql_statement *db_prepare(sql_connection *con, const char *query, size_t len);
int db_bind_param(sql_statement *stmt, sql_bind *params, size_t len);
int db_bind_result(sql_statement *stmt, sql_bind *results, size_t len);
sql_result db_execute(sql_statement *stmt);
sql_result *db_execute(sql_statement *stmt);
int db_close(sql_statement *stmt);
int db_free_results(sql_result *);
]]
local sql_driver = ffi.typeof('sql_driver *')
local sql_connection = ffi.typeof('sql_connection *')
local sql_statement = ffi.typeof('sql_statement *')
local sql_bind = ffi.typeof('sql_bind');
local sql_result = ffi.typeof('sql_result');
sysbench.sql.type =
{
@ -104,10 +107,10 @@ sysbench.sql.type =
VARCHAR = ffi.C.SQL_TYPE_VARCHAR
}
local function check_type(type, var, func)
if var == nil or not ffi.istype(type, var) then
local function check_type(vtype, var, func)
if var == nil or not ffi.istype(vtype, var) then
error(string.format("bad argument '%s' to %s() where a '%s' was expected",
var, func, type))
var, func, vtype))
end
end
@ -158,24 +161,111 @@ end
function sysbench.sql.prepare(con, query)
check_type(sql_connection, con, 'sysbench.sql.prepare')
return ffi.C.db_prepare(con, query, #query)
local stmt = ffi.C.db_prepare(con, query, #query)
return stmt
end
function sysbench.sql.bind_param(stmt, params)
local len = #params
local sql_params = {}
function sysbench.sql.bind_create(stmt, btype, maxlen)
local sql_type = sysbench.sql.type
local buf, buflen, datalen, isnull
check_type(sql_statement, stmt, 'sysbench.sql.bind_create')
if btype == sql_type.TINYINT or
btype == sql_type.SMALLINT or
btype == sql_type.INT or
btype == sql_type.BIGINT
then
btype = sql_type.BIGINT
buf = ffi.new('int64_t[1]')
buflen = 8
elseif btype == sql_type.FLOAT or
btype == sql_type.DOUBLE
then
btype = sql_type.DOUBLE
buf = ffi.new('double[1]')
buflen = 8
elseif btype == sql_type.CHAR or
btype == sql_type.VARCHAR
then
btype = sql_type.VARCHAR
buf = ffi.new('char[?]', maxlen)
buflen = maxlen
else
error("Unsupported argument type: " .. btype)
end
datalen = ffi.new('unsigned long[1]')
isnull = ffi.new('char[1]')
return ffi.new(sql_bind, btype, buf, datalen, buflen, isnull)
end
function sysbench.sql.bind_set(bind, value)
local sql_type = sysbench.sql.type
local btype = bind.type
check_type(sql_bind, bind, 'sysbench.sql.bind_set')
if (value == nil) then
bind.is_null[0] = true
return
end
bind.is_null[0] = false
if btype == sql_type.TINYINT or
btype == sql_type.SMALLINT or
btype == sql_type.INT or
btype == sql_type.BIGINT
then
ffi.copy(bind.buffer, ffi.new('int64_t[1]', value), 8)
elseif btype == sql_type.FLOAT or
btype == sql_type.DOUBLE
then
ffi.copy(bind.buffer, ffi.new('double[1]', value), 8)
elseif btype == sql_type.CHAR or
btype == sql_type.VARCHAR
then
local len = #value
len = bind.max_len < len and bind.max_len or len
ffi.copy(bind.buffer, value, len)
bind.data_len[0] = len
else
error("Unsupported argument type: " .. btype)
end
end
function sysbench.sql.bind_destroy(bind)
check_type(sql_bind, bind, 'sysbench.sql.bind_destroy')
end
function sysbench.sql.bind_param(stmt, ...)
local len = #{...}
local i
check_type(sql_statement, stmt, 'sysbench.sql.bind_param')
error("NYI")
return ffi.C.db_bind_param(stmt,
ffi.new("sql_bind[" .. len .."]",
sql_params),
ffi.new("sql_bind[?]", len, {...}),
len)
end
function sysbench.sql.execute(stmt)
check_type(sql_statement, stmt, 'sysbench.sql.execute')
return ffi.C.db_execute(stmt)
end
function sysbench.sql.close(stmt)
check_type(sql_statement, stmt, 'sysbench.sql.close')
return ffi.C.db_close(stmt)
end
function sysbench.sql.free_results(result)
check_type(sql_result, result, 'sysbench.sql.free_results')
return ffi.C.db_free_results(result)
end
-- sql_driver metatable
local driver_mt = {
__index = {
@ -205,18 +295,36 @@ ffi.metatype("struct db_conn", connection_mt)
local statement_mt = {
__index = {
bind_param = sysbench.sql.bind_param,
bind_create = sysbench.sql.bind_create,
execute = sysbench.sql.execute,
close = sysbench.sql.close
},
__tostring = function() return '<sql_statement>' end,
__gc = ffi.C.db_close,
__gc = sysbench.sql.close,
}
ffi.metatype("struct db_stmt", statement_mt)
-- sql_bind metatable
local bind_mt = {
__index = {
set = sysbench.sql.bind_set
},
__tostring = function() return '<sql_bind>' end,
__gc = sysbench.sql.bind_destroy
}
ffi.metatype("sql_bind", bind_mt)
-- sql_results metatable
local result_mt = {
__index = {
free = sysbench.sql.free_results
},
__tostring = function() return '<sql_result>' end,
__gc = sysbench.sql.free_results
}
ffi.metatype("sql_result", result_mt)
-- error codes
sysbench.sql.ERROR_NONE = ffi.C.DB_ERROR_NONE
sysbench.sql.ERROR_IGNORABLE = ffi.C.DB_ERROR_IGNORABLE

View File

@ -36,13 +36,20 @@ SQL Lua API tests
> con:query("SELECT 1")
>
> con = drv:connect()
> con:query[[UPDATE t SET a = a + 100]]
>
> -- local stmt = con:prepare("UPDATE t SET a = a + ?")
> -- stmt:bind_param({{sysbench.sql.type.INT}})
> -- rs = stmt:execute()
> -- rs:free_results()
> -- stmt:close()
> con:query("ALTER TABLE t ADD COLUMN b CHAR(10)")
>
> local stmt = con:prepare("UPDATE t SET a = a + ?, b = ?")
> local a = stmt:bind_create(sysbench.sql.type.INT)
> local b = stmt:bind_create(sysbench.sql.type.CHAR, 10)
> stmt:bind_param(a, b)
> a:set(100)
> rs = stmt:execute()
> a:set(200)
> b:set("01234567890")
> rs = stmt:execute()
> rs:free()
> stmt:close()
> end
> EOF
@ -69,15 +76,18 @@ SQL Lua API tests
*************************** 1. row ***************************
t
CREATE TABLE `t` (
`a` int(11) DEFAULT NULL
`a` int(11) DEFAULT NULL,
`b` char(10)* DEFAULT NULL (glob)
) * (glob)
$ mysql -uroot sbtest -Nse "SELECT COUNT(DISTINCT a) FROM t"
100
$ mysql -uroot sbtest -Nse "SELECT MIN(a), MAX(a) FROM t\G"
$ mysql -uroot sbtest -Nse "SELECT MIN(a), MAX(a), MIN(b), MAX(b) FROM t\G"
*************************** 1. row ***************************
101
200
301
400
0123456789
0123456789
$ mysql -uroot sbtest -Nse "DROP TABLE t"