!16 增加了新的占位符

Merge pull request !16 from Luan-233/master
This commit is contained in:
opengauss_bot
2023-08-16 09:24:24 +00:00
committed by Gitee
4 changed files with 524 additions and 348 deletions

View File

@ -140,15 +140,15 @@ class DictCursor(DictCursorBase):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._prefetch = True self._prefetch = True
def execute(self, query, vars=None): def execute(self, query, vars=None, place_holder = '%'):
self.index = OrderedDict() self.index = OrderedDict()
self._query_executed = True self._query_executed = True
return super().execute(query, vars) return super().execute(query, vars, place_holder)
def callproc(self, procname, vars=None): def callproc(self, procname, vars=None, place_holder = '%'):
self.index = OrderedDict() self.index = OrderedDict()
self._query_executed = True self._query_executed = True
return super().callproc(procname, vars) return super().callproc(procname, vars, place_holder)
def _build_index(self): def _build_index(self):
if self._query_executed and self.description: if self._query_executed and self.description:
@ -230,15 +230,15 @@ class RealDictCursor(DictCursorBase):
kwargs['row_factory'] = RealDictRow kwargs['row_factory'] = RealDictRow
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def execute(self, query, vars=None): def execute(self, query, vars=None, place_holder = '%'):
self.column_mapping = [] self.column_mapping = []
self._query_executed = True self._query_executed = True
return super().execute(query, vars) return super().execute(query, vars, place_holder)
def callproc(self, procname, vars=None): def callproc(self, procname, vars=None, place_holder = '%'):
self.column_mapping = [] self.column_mapping = []
self._query_executed = True self._query_executed = True
return super().callproc(procname, vars) return super().callproc(procname, vars, place_holder)
def _build_index(self): def _build_index(self):
if self._query_executed and self.description: if self._query_executed and self.description:
@ -307,17 +307,17 @@ class NamedTupleCursor(_cursor):
Record = None Record = None
MAX_CACHE = 1024 MAX_CACHE = 1024
def execute(self, query, vars=None): def execute(self, query, vars=None, place_holder = '%'):
self.Record = None self.Record = None
return super().execute(query, vars) return super().execute(query, vars, place_holder)
def executemany(self, query, vars): def executemany(self, query, vars, place_holder = '%'):
self.Record = None self.Record = None
return super().executemany(query, vars) return super().executemany(query, vars, place_holder)
def callproc(self, procname, vars=None): def callproc(self, procname, vars=None, place_holder = '%'):
self.Record = None self.Record = None
return super().callproc(procname, vars) return super().callproc(procname, vars, place_holder)
def fetchone(self): def fetchone(self):
t = super().fetchone() t = super().fetchone()
@ -444,15 +444,15 @@ class LoggingConnection(_connection):
class LoggingCursor(_cursor): class LoggingCursor(_cursor):
"""A cursor that logs queries using its connection logging facilities.""" """A cursor that logs queries using its connection logging facilities."""
def execute(self, query, vars=None): def execute(self, query, vars=None, place_holder = '%'):
try: try:
return super().execute(query, vars) return super().execute(query, vars, place_holder)
finally: finally:
self.connection.log(self.query, self) self.connection.log(self.query, self)
def callproc(self, procname, vars=None): def callproc(self, procname, vars=None, place_holder = '%'):
try: try:
return super().callproc(procname, vars) return super().callproc(procname, vars, place_holder)
finally: finally:
self.connection.log(self.query, self) self.connection.log(self.query, self)
@ -488,13 +488,13 @@ class MinTimeLoggingConnection(LoggingConnection):
class MinTimeLoggingCursor(LoggingCursor): class MinTimeLoggingCursor(LoggingCursor):
"""The cursor sub-class companion to `MinTimeLoggingConnection`.""" """The cursor sub-class companion to `MinTimeLoggingConnection`."""
def execute(self, query, vars=None): def execute(self, query, vars=None, place_holder = '%'):
self.timestamp = _time.time() self.timestamp = _time.time()
return LoggingCursor.execute(self, query, vars) return LoggingCursor.execute(self, query, vars, place_holder)
def callproc(self, procname, vars=None): def callproc(self, procname, vars=None, place_holder = '%'):
self.timestamp = _time.time() self.timestamp = _time.time()
return LoggingCursor.callproc(self, procname, vars) return LoggingCursor.callproc(self, procname, vars, place_holder)
class LogicalReplicationConnection(_replicationConnection): class LogicalReplicationConnection(_replicationConnection):
@ -1099,8 +1099,7 @@ ORDER BY attnum;
recs = curs.fetchall() recs = curs.fetchall()
# revert the status of the connection as before the command # revert the status of the connection as before the command
if (conn_status != _ext.STATUS_IN_TRANSACTION if conn_status != _ext.STATUS_IN_TRANSACTION and not conn.autocommit:
and not conn.autocommit):
conn.rollback() conn.rollback()
if not recs: if not recs:
@ -1161,8 +1160,7 @@ def _paginate(seq, page_size):
yield page yield page
return return
def execute_batch(cur, sql, argslist, page_size=100, place_holder = '%'):
def execute_batch(cur, sql, argslist, page_size=100):
r"""Execute groups of statements in fewer server roundtrips. r"""Execute groups of statements in fewer server roundtrips.
Execute *sql* several times, against all parameters set (sequences or Execute *sql* several times, against all parameters set (sequences or
@ -1183,11 +1181,11 @@ def execute_batch(cur, sql, argslist, page_size=100):
""" """
for page in _paginate(argslist, page_size=page_size): for page in _paginate(argslist, page_size=page_size):
sqls = [cur.mogrify(sql, args) for args in page] sqls = [cur.mogrify(sql, args, place_holder) for args in page]
cur.execute(b";".join(sqls)) cur.execute(b";".join(sqls))
def execute_values(cur, sql, argslist, template=None, page_size=100, fetch=False): def execute_values(cur, sql, argslist, template=None, page_size=100, fetch=False, place_holder = '%'):
'''Execute a statement using :sql:`VALUES` with a sequence of parameters. '''Execute a statement using :sql:`VALUES` with a sequence of parameters.
:param cur: the cursor to use to execute the query. :param cur: the cursor to use to execute the query.
@ -1264,7 +1262,7 @@ def execute_values(cur, sql, argslist, template=None, page_size=100, fetch=False
template = b'(' + b','.join([b'%s'] * len(page[0])) + b')' template = b'(' + b','.join([b'%s'] * len(page[0])) + b')'
parts = pre[:] parts = pre[:]
for args in page: for args in page:
parts.append(cur.mogrify(template, args)) parts.append(cur.mogrify(template, args, place_holder))
parts.append(b',') parts.append(b',')
parts[-1:] = post parts[-1:] = post
cur.execute(b''.join(parts)) cur.execute(b''.join(parts))

View File

@ -85,225 +85,314 @@
/* Helpers for formatstring */ /* Helpers for formatstring */
BORROWED Py_LOCAL_INLINE(PyObject *) BORROWED Py_LOCAL_INLINE(PyObject *)
getnextarg(PyObject *args, Py_ssize_t arglen, Py_ssize_t *p_argidx) getnextarg(PyObject *args, Py_ssize_t arglen, Py_ssize_t *p_argidx) {
{
Py_ssize_t argidx = *p_argidx; Py_ssize_t argidx = *p_argidx;
if (argidx < arglen) { if (argidx < arglen) {
(*p_argidx)++; (*p_argidx)++;
if (arglen < 0) if (arglen < 0) return args;
return args; else return PyTuple_GetItem(args, argidx);
else
return PyTuple_GetItem(args, argidx);
} }
PyErr_SetString(PyExc_TypeError,
"not enough arguments for format string");
return NULL; return NULL;
} }
/*
for function getnextarg:
I delete the line including 'raise error', for making this func a iterator
just used to fill in the arguments array
*/
/* wrapper around _Bytes_Resize offering normal Python call semantics */ /* wrapper around _Bytes_Resize offering normal Python call semantics */
STEALS(1) STEALS(1)
Py_LOCAL_INLINE(PyObject *) Py_LOCAL_INLINE(PyObject *)
resize_bytes(PyObject *b, Py_ssize_t newsize) { resize_bytes(PyObject *b, Py_ssize_t newsize) {
if (0 == _Bytes_Resize(&b, newsize)) { if (0 == _Bytes_Resize(&b, newsize)) return b;
return b; else return NULL;
}
else {
return NULL;
}
} }
/* fmt%(v1,v2,...) is roughly equivalent to sprintf(fmt, v1, v2, ...) */ PyObject *Bytes_Format(PyObject *format, PyObject *args, char place_holder) {
char *fmt, *res; //array pointer of format, and array pointer of result
Py_ssize_t arglen, argidx; //length of arguments array, and index of arguments(when processing args_list)
Py_ssize_t reslen, rescnt, fmtcnt; //rescnt: blank space in result; reslen: the total length of result; fmtcnt: length of format
int args_owned = 0; //args is valid or invalid(or maybe refcnt), 0 for invalid,1 otherwise
PyObject *result; //function's return value
PyObject *dict = NULL; //dictionary
PyObject *args_value = NULL; //every argument store in it after parse
char **args_list = NULL; //arguments list as char **
char *args_buffer = NULL; //Bytes_AS_STRING(args_value)
Py_ssize_t * args_len = NULL; //every argument's length in args_list
int args_id = 0; //index of arguments(when generating result)
int index_type = 0; //if exists $number, it will be 1, otherwise 0
PyObject * if (format == NULL || !Bytes_Check(format) || args == NULL) { //check if arguments are valid
Bytes_Format(PyObject *format, PyObject *args)
{
char *fmt, *res;
Py_ssize_t arglen, argidx;
Py_ssize_t reslen, rescnt, fmtcnt;
int args_owned = 0;
PyObject *result;
PyObject *dict = NULL;
if (format == NULL || !Bytes_Check(format) || args == NULL) {
PyErr_SetString(PyExc_SystemError, "bad argument to internal function"); PyErr_SetString(PyExc_SystemError, "bad argument to internal function");
return NULL; return NULL;
} }
fmt = Bytes_AS_STRING(format); fmt = Bytes_AS_STRING(format); //get pointer of format
fmtcnt = Bytes_GET_SIZE(format); fmtcnt = Bytes_GET_SIZE(format); //get length of format
reslen = rescnt = fmtcnt + 100; reslen = rescnt = 1;
while (reslen <= fmtcnt) { //when space is not enough, double it's size
reslen *= 2;
rescnt *= 2;
}
result = Bytes_FromStringAndSize((char *)NULL, reslen); result = Bytes_FromStringAndSize((char *)NULL, reslen);
if (result == NULL) if (result == NULL) return NULL;
return NULL;
res = Bytes_AS_STRING(result); res = Bytes_AS_STRING(result);
if (PyTuple_Check(args)) { if (PyTuple_Check(args)) { //check if arguments are sequences
arglen = PyTuple_GET_SIZE(args); arglen = PyTuple_GET_SIZE(args);
argidx = 0; argidx = 0;
} }
else { else { //if no, then this two are of no importance
arglen = -1; arglen = -1;
argidx = -2; argidx = -2;
} }
if (Py_TYPE(args)->tp_as_mapping && !PyTuple_Check(args) && if (Py_TYPE(args)->tp_as_mapping && !PyTuple_Check(args) && !PyObject_TypeCheck(args, &Bytes_Type)) { //check if args is dict
!PyObject_TypeCheck(args, &Bytes_Type))
dict = args; dict = args;
while (--fmtcnt >= 0) { //Py_INCREF(dict);
if (*fmt != '%') { }
while (--fmtcnt >= 0) { //scan the format
if (*fmt != '%') { //if not %, pass it(for the special format '%(name)s')
if (--rescnt < 0) { if (--rescnt < 0) {
rescnt = fmtcnt + 100; rescnt = reslen; //double the space
reslen += rescnt; reslen *= 2;
if (!(result = resize_bytes(result, reslen))) { if (!(result = resize_bytes(result, reslen))) {
return NULL; return NULL;
} }
res = Bytes_AS_STRING(result) + reslen - rescnt; res = Bytes_AS_STRING(result) + reslen - rescnt;//calculate offset
--rescnt; --rescnt;
} }
*res++ = *fmt++; *res++ = *fmt++; //copy
} }
else { else {
/* Got a format specifier */ /* Got a format specifier */
Py_ssize_t width = -1;
int c = '\0';
PyObject *v = NULL;
PyObject *temp = NULL;
char *pbuf;
Py_ssize_t len;
fmt++; fmt++;
if (*fmt == '(') { if (*fmt == '(') {
char *keystart; char *keystart; //begin pos of left bracket
Py_ssize_t keylen; Py_ssize_t keylen; //length of content in bracket
PyObject *key; PyObject *key;
int pcount = 1; int pcount = 1; //counter of left bracket
Py_ssize_t length = 0;
if (dict == NULL) { if (dict == NULL) {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError, "format requires a mapping");
"format requires a mapping");
goto error; goto error;
} }
++fmt; ++fmt;
--fmtcnt; --fmtcnt;
keystart = fmt; keystart = fmt;
/* Skip over balanced parentheses */ /* Skip over balanced parentheses */
while (pcount > 0 && --fmtcnt >= 0) { while (pcount > 0 && --fmtcnt >= 0) { //find the matching right bracket
if (*fmt == ')') if (*fmt == ')') --pcount;
--pcount; else if (*fmt == '(') ++pcount;
else if (*fmt == '(')
++pcount;
fmt++; fmt++;
} }
keylen = fmt - keystart - 1; keylen = fmt - keystart - 1;
if (fmtcnt < 0 || pcount > 0) { if (fmtcnt < 0 || pcount > 0 || *(fmt++) != 's') { //not found, raise an error
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError, "incomplete format key");
"incomplete format key");
goto error; goto error;
} }
key = Text_FromUTF8AndSize(keystart, keylen); --fmtcnt;
if (key == NULL) key = Text_FromUTF8AndSize(keystart, keylen);//get key
goto error; if (key == NULL) goto error;
if (args_owned) { if (args_owned) { //if refcnt > 0, then release
Py_DECREF(args); Py_DECREF(args);
args_owned = 0; args_owned = 0;
} }
args = PyObject_GetItem(dict, key); args = PyObject_GetItem(dict, key); //get value with key
Py_DECREF(key); Py_DECREF(key);
if (args == NULL) { if (args == NULL) goto error;
if (!Bytes_CheckExact(args)) {
PyErr_Format(PyExc_ValueError, "only bytes values expected, got %s", Py_TYPE(args)->tp_name); //raise error, but may have bug
goto error; goto error;
} }
args_buffer = Bytes_AS_STRING(args); //temporary buffer
length = Bytes_GET_SIZE(args);
if (rescnt < length) {
while (rescnt < length) {
rescnt += reslen;
reslen *= 2;
}
if ((result = resize_bytes(result, reslen)) == NULL) goto error;
}
res = Bytes_AS_STRING(result) + reslen - rescnt;
Py_MEMCPY(res, args_buffer, length);
rescnt -= length;
res += length;
args_owned = 1; args_owned = 1;
arglen = -1; arglen = -1; //exists place holder as "%(name)s", set these arguments to invalid
argidx = -2; argidx = -2;
} }
while (--fmtcnt >= 0) { } /* '%' */
c = *fmt++; } /* until end */
break;
} if (dict) { //if args' type is dict, the func ends
if (fmtcnt < 0) { if (args_owned) Py_DECREF(args);
PyErr_SetString(PyExc_ValueError, if (!(result = resize_bytes(result, reslen - rescnt))) return NULL; //resize and return
"incomplete format"); if (place_holder != '%') {
PyErr_SetString(PyExc_TypeError, "place holder only expect %% when using dict");
goto error; goto error;
} }
switch (c) { return result;
case '%': }
pbuf = "%";
len = 1; args_list = (char **)malloc(sizeof(char *) * arglen); //buffer
break; args_len = (Py_ssize_t *)malloc(sizeof(Py_ssize_t *) * arglen); //length of every argument
case 's': while ((args_value = getnextarg(args, arglen, &argidx)) != NULL) { //stop when receive NULL
/* only bytes! */ Py_ssize_t length = 0;
if (!(v = getnextarg(args, arglen, &argidx))) if (!Bytes_CheckExact(args_value)) {
goto error; PyErr_Format(PyExc_ValueError, "only bytes values expected, got %s", Py_TYPE(args_value)->tp_name); //may have bug
if (!Bytes_CheckExact(v)) {
PyErr_Format(PyExc_ValueError,
"only bytes values expected, got %s",
Py_TYPE(v)->tp_name);
goto error; goto error;
} }
temp = v; Py_INCREF(args_value); //increase refcnt
Py_INCREF(v); args_buffer = Bytes_AS_STRING(args_value);
pbuf = Bytes_AS_STRING(temp); length = Bytes_GET_SIZE(args_value);
len = Bytes_GET_SIZE(temp); //printf("type: %s, len: %d, value: %s\n", Py_TYPE(args_value)->tp_name, length, args_buffer);
break; args_len[argidx - 1] = length;
default: args_list[argidx - 1] = (char *)malloc(sizeof(char *) * (length + 1));
PyErr_Format(PyExc_ValueError, Py_MEMCPY(args_list[argidx - 1], args_buffer, length);
"unsupported format character '%c' (0x%x) " args_list[argidx - 1][length] = '\0';
Py_XDECREF(args_value);
}
fmt = Bytes_AS_STRING(format); //get pointer of format
fmtcnt = Bytes_GET_SIZE(format); //get length of format
reslen = rescnt = 1;
while (reslen <= fmtcnt) {
reslen *= 2;
rescnt *= 2;
}
if ((result = resize_bytes(result, reslen)) == NULL) goto error;
res = Bytes_AS_STRING(result);
memset(res, 0, sizeof(char) * reslen);
while (*fmt != '\0') {
if (*fmt != place_holder) { //not place holder, pass it
if (!rescnt) {
rescnt += reslen;
reslen *= 2;
if ((result = resize_bytes(result, reslen)) == NULL) goto error;
res = Bytes_AS_STRING(result) + reslen - rescnt;
}
*(res++) = *(fmt++);
--rescnt;
continue;
}
if (*fmt == '%') {
char c = *(++fmt);
if (c == '\0') { //if there is nothing after '%', raise an error
PyErr_SetString(PyExc_ValueError, "incomplete format");
goto error;
}
else if (c == '%') { //'%%' will be transfered to '%'
if (!rescnt) {
rescnt += reslen;
reslen *= 2;
if ((result = resize_bytes(result, reslen)) == NULL) goto error;
res = Bytes_AS_STRING(result) + reslen - rescnt;
}
*res = c;
--rescnt;
++res;
++fmt;
}
else if (c == 's') { //'%s', replace it with corresponding string
if (args_id >= arglen) { //index is out of bound
PyErr_SetString(PyExc_TypeError, "arguments not enough during string formatting");
goto error;
}
if (rescnt < args_len[args_id]) {
while (rescnt < args_len[args_id]) {
rescnt += reslen;
reslen *= 2;
}
if ((result = resize_bytes(result, reslen)) == NULL) goto error;
res = Bytes_AS_STRING(result) + reslen - rescnt;
}
Py_MEMCPY(res, args_list[args_id], args_len[args_id]);
rescnt -= args_len[args_id];
res += args_len[args_id];
++args_id;
++fmt;
}
else { //not support the character currently
PyErr_Format(PyExc_ValueError, "unsupported format character '%c' (0x%x) "
"at index " FORMAT_CODE_PY_SSIZE_T, "at index " FORMAT_CODE_PY_SSIZE_T,
c, c, c, c,
(Py_ssize_t)(fmt - 1 - Bytes_AS_STRING(format))); (Py_ssize_t)(fmt - 1 - Bytes_AS_STRING(format)));
goto error; goto error;
} }
if (width < len) continue;
width = len;
if (rescnt < width) {
reslen -= rescnt;
rescnt = width + fmtcnt + 100;
reslen += rescnt;
if (reslen < 0) {
Py_DECREF(result);
Py_XDECREF(temp);
if (args_owned)
Py_DECREF(args);
return PyErr_NoMemory();
} }
if (!(result = resize_bytes(result, reslen))) { if (*fmt == '$') {
Py_XDECREF(temp); char c = *(++fmt);
if (args_owned) if (c == '\0') { //if there is nothing after '$', raise an error
Py_DECREF(args); PyErr_SetString(PyExc_ValueError, "incomplete format");
return NULL; goto error;
} }
res = Bytes_AS_STRING(result) else if (c == '$') { //'$$' will be transfered to'$'
+ reslen - rescnt; if (!rescnt) { //resize buffer
rescnt += reslen;
reslen *= 2;
if ((result = resize_bytes(result, reslen)) == NULL) goto error;
res = Bytes_AS_STRING(result) + reslen - rescnt;
} }
Py_MEMCPY(res, pbuf, len); *res = c;
res += len;
rescnt -= len;
while (--width >= len) {
--rescnt; --rescnt;
*res++ = ' '; ++res;
++fmt;
} }
if (dict && (argidx < arglen) && c != '%') { else if (isdigit(c)) { //represents '$number'
PyErr_SetString(PyExc_TypeError, int index = 0;
"not all arguments converted during string formatting"); index_type = 1;
Py_XDECREF(temp); while (isdigit(*fmt)) {
index = index * 10 + (*fmt) -'0';
++fmt;
}
if ((index > arglen) || (index <= 0)) { //invalid index
PyErr_SetString(PyExc_ValueError, "invalid index");
goto error; goto error;
} }
Py_XDECREF(temp); --index;
} /* '%' */ if (rescnt < args_len[index]) {
} /* until end */ while (rescnt < args_len[index]) {
if (argidx < arglen && !dict) { rescnt += reslen;
PyErr_SetString(PyExc_TypeError, reslen *= 2;
"not all arguments converted during string formatting"); }
if ((result = resize_bytes(result, reslen)) == NULL) goto error;
res = Bytes_AS_STRING(result) + reslen - rescnt;
}
Py_MEMCPY(res, args_list[index], args_len[index]);
rescnt -= args_len[index];
res += args_len[index];
}
else { //invalid place holder
PyErr_Format(PyExc_ValueError, "unsupported format character '%c' (0x%x) "
"at index " FORMAT_CODE_PY_SSIZE_T,
c, c,
(Py_ssize_t)(fmt - 1 - Bytes_AS_STRING(format)));
goto error; goto error;
} }
if (args_owned) {
Py_DECREF(args);
} }
if (!(result = resize_bytes(result, reslen - rescnt))) {
return NULL;
} }
if ((args_id < arglen) && (!dict) && (!index_type)) { //not all arguments are used
PyErr_SetString(PyExc_TypeError, "not all arguments converted during string formatting");
goto error;
}
if (args_list != NULL) {
while (--argidx >= 0) free(args_list[argidx]);
free(args_list);
free(args_len);
}
if (args_owned) Py_DECREF(args);
if (!(result = resize_bytes(result, reslen - rescnt))) return NULL; //resize
return result; return result;
error: error:
Py_DECREF(result); if (args_list != NULL) { //release all the refcnt
if (args_owned) { while (--argidx >= 0) free(args_list[argidx]);
Py_DECREF(args); free(args_list);
free(args_len);
} }
Py_DECREF(result);
if (args_owned) Py_DECREF(args);
return NULL; return NULL;
} }

View File

@ -127,12 +127,13 @@ exit:
/* mogrify a query string and build argument array or dict */ /* mogrify a query string and build argument array or dict */
RAISES_NEG static int RAISES_NEG static int
_mogrify(PyObject *var, PyObject *fmt, cursorObject *curs, PyObject **new) _mogrify(PyObject *var, PyObject *fmt, cursorObject *curs, PyObject **new, char place_holder)
{ {
PyObject *key, *value, *n; PyObject *key, *value, *n;
const char *d, *c; const char *d, *c;
Py_ssize_t index = 0; Py_ssize_t index = 0;
int force = 0, kind = 0; int force = 0, kind = 0;
int max_index = 0;
/* from now on we'll use n and replace its value in *new only at the end, /* from now on we'll use n and replace its value in *new only at the end,
just before returning. we also init *new to NULL to exit with an error just before returning. we also init *new to NULL to exit with an error
@ -141,28 +142,12 @@ _mogrify(PyObject *var, PyObject *fmt, cursorObject *curs, PyObject **new)
c = Bytes_AsString(fmt); c = Bytes_AsString(fmt);
while(*c) { while(*c) {
if (*c++ != '%') { while ((*c != '\0') && (*c != place_holder)) ++c;
/* a regular character */ if (*c == '%') {
continue;
}
switch (*c) {
/* handle plain percent symbol in format string */
case '%':
++c; ++c;
force = 1; force = 1;
break; if (*c == '(') {
if ((kind == 2) || (kind == 3)) {
/* if we find '%(' then this is a dictionary, we:
1/ find the matching ')' and extract the key name
2/ locate the value in the dictionary (or return an error)
3/ mogrify the value into something useful (quoting)...
4/ ...and add it to the new dictionary to be used as argument
*/
case '(':
/* check if some crazy guy mixed formats */
if (kind == 2) {
Py_XDECREF(n); Py_XDECREF(n);
psyco_set_error(ProgrammingError, curs, psyco_set_error(ProgrammingError, curs,
"argument formats can't be mixed"); "argument formats can't be mixed");
@ -240,16 +225,15 @@ _mogrify(PyObject *var, PyObject *fmt, cursorObject *curs, PyObject **new)
"incomplete placeholder: '%(' without ')'"); "incomplete placeholder: '%(' without ')'");
return -1; return -1;
} }
c = d + 1; /* after the ) */ c = d + 1;
break; }
else if (*c == 's') {
default:
/* this is a format that expects a tuple; it is much easier, /* this is a format that expects a tuple; it is much easier,
because we don't need to check the old/new dictionary for because we don't need to check the old/new dictionary for
keys */ keys */
/* check if some crazy guy mixed formats */ /* check if some crazy guy mixed formats */
if (kind == 1) { if ((kind == 1) || (kind == 3)) {
Py_XDECREF(n); Py_XDECREF(n);
psyco_set_error(ProgrammingError, curs, psyco_set_error(ProgrammingError, curs,
"argument formats can't be mixed"); "argument formats can't be mixed");
@ -297,8 +281,60 @@ _mogrify(PyObject *var, PyObject *fmt, cursorObject *curs, PyObject **new)
} }
} }
if (force && n == NULL) else if (*c == '$') { //new place holder $
n = PyTuple_New(0); int tmp_index = 0;
if ((kind == 1) || (kind == 2)) {
Py_XDECREF(n);
psyco_set_error(ProgrammingError, curs,
"argument formats can't be mixed");
return -1;
}
kind = 3; //kind = 3 means using
++c;
while (isdigit(*c)) { //calculate index
tmp_index = tmp_index * 10 + (*c) -'0';
++c;
}
--tmp_index;
for (; max_index <= tmp_index; ++max_index) {
//to avoid index not cover all arguments, which may cause double free in bytes_format
int id = max_index;
value = PySequence_GetItem(var, id);
if (value == NULL) {
Py_XDECREF(n);
return -1;
}
if (n == NULL) {
if (!(n = PyTuple_New(PyObject_Length(var)))) {
Py_DECREF(value);
return -1;
}
}
if (value == Py_None) {
Py_INCREF(psyco_null);
PyTuple_SET_ITEM(n, id, psyco_null);
Py_DECREF(value);
}
else {
PyObject *t = microprotocol_getquoted(value, curs->conn);
if (t != NULL) {
PyTuple_SET_ITEM(n, id, t);
Py_DECREF(value);
}
else {
Py_DECREF(n);
Py_DECREF(value);
return -1;
}
}
}
}
}
if (force && n == NULL) n = PyTuple_New(0);
*new = n; *new = n;
return 0; return 0;
@ -314,7 +350,7 @@ _mogrify(PyObject *var, PyObject *fmt, cursorObject *curs, PyObject **new)
*/ */
static PyObject * static PyObject *
_psyco_curs_merge_query_args(cursorObject *self, _psyco_curs_merge_query_args(cursorObject *self,
PyObject *query, PyObject *args) PyObject *query, PyObject *args, char place_holder)
{ {
PyObject *fquery; PyObject *fquery;
@ -329,7 +365,7 @@ _psyco_curs_merge_query_args(cursorObject *self,
the current exception (we will later restore it if the type or the the current exception (we will later restore it if the type or the
strings do not match.) */ strings do not match.) */
if (!(fquery = Bytes_Format(query, args))) { if (!(fquery = Bytes_Format(query, args, place_holder))) {
PyObject *err, *arg, *trace; PyObject *err, *arg, *trace;
int pe = 0; int pe = 0;
@ -376,7 +412,7 @@ _psyco_curs_merge_query_args(cursorObject *self,
RAISES_NEG static int RAISES_NEG static int
_psyco_curs_execute(cursorObject *self, _psyco_curs_execute(cursorObject *self,
PyObject *query, PyObject *vars, PyObject *query, PyObject *vars,
long int async, int no_result) char place_holder, long int async, int no_result)
{ {
int res = -1; int res = -1;
int tmp; int tmp;
@ -396,12 +432,12 @@ _psyco_curs_execute(cursorObject *self,
the right thing (i.e., what the user expects) */ the right thing (i.e., what the user expects) */
if (vars && vars != Py_None) if (vars && vars != Py_None)
{ {
if (0 > _mogrify(vars, query, self, &cvt)) { goto exit; } if (0 > _mogrify(vars, query, self, &cvt, place_holder)) { goto exit; }
} }
/* Merge the query to the arguments if needed */ /* Merge the query to the arguments if needed */
if (cvt) { if (cvt) {
if (!(fquery = _psyco_curs_merge_query_args(self, query, cvt))) { if (!(fquery = _psyco_curs_merge_query_args(self, query, cvt, place_holder))) {
goto exit; goto exit;
} }
} }
@ -461,15 +497,28 @@ exit:
static PyObject * static PyObject *
curs_execute(cursorObject *self, PyObject *args, PyObject *kwargs) curs_execute(cursorObject *self, PyObject *args, PyObject *kwargs)
{ {
PyObject *vars = NULL, *operation = NULL; PyObject *vars = NULL, *operation = NULL, *Place_holder = NULL;
char place_holder = '%'; //default value: '%'
static char *kwlist[] = {"query", "vars", NULL}; static char *kwlist[] = {"query", "vars", "place_holder", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O", kwlist, if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OO", kwlist,
&operation, &vars)) { &operation, &vars, &Place_holder)) {
return NULL; return NULL;
} }
if (Place_holder != NULL) { //if exists place holder argument, it will be checked and parse
if (!(Place_holder = curs_validate_sql_basic(self, Place_holder))) {
psyco_set_error(ProgrammingError, self, "can't parse place holder");
return NULL;
}
if (Bytes_GET_SIZE(Place_holder) != 1) {
psyco_set_error(ProgrammingError, self, "place holder must be a character");
return NULL;
}
place_holder = Bytes_AS_STRING(Place_holder)[0];
}
if (self->name != NULL) { if (self->name != NULL) {
if (self->query) { if (self->query) {
psyco_set_error(ProgrammingError, self, psyco_set_error(ProgrammingError, self,
@ -488,7 +537,7 @@ curs_execute(cursorObject *self, PyObject *args, PyObject *kwargs)
EXC_IF_ASYNC_IN_PROGRESS(self, execute); EXC_IF_ASYNC_IN_PROGRESS(self, execute);
EXC_IF_TPC_PREPARED(self->conn, execute); EXC_IF_TPC_PREPARED(self->conn, execute);
if (0 > _psyco_curs_execute(self, operation, vars, self->conn->async, 0)) { if (0 > _psyco_curs_execute(self, operation, vars, place_holder, self->conn->async, 0)) {
return NULL; return NULL;
} }
@ -502,20 +551,33 @@ curs_execute(cursorObject *self, PyObject *args, PyObject *kwargs)
static PyObject * static PyObject *
curs_executemany(cursorObject *self, PyObject *args, PyObject *kwargs) curs_executemany(cursorObject *self, PyObject *args, PyObject *kwargs)
{ {
PyObject *operation = NULL, *vars = NULL; PyObject *operation = NULL, *vars = NULL, *Place_holder = NULL;
PyObject *v, *iter = NULL; PyObject *v, *iter = NULL;
char place_holder = '%';
long rowcount = 0; long rowcount = 0;
static char *kwlist[] = {"query", "vars_list", NULL}; static char *kwlist[] = {"query", "vars_list", "plae_holder", NULL};
/* reset rowcount to -1 to avoid setting it when an exception is raised */ /* reset rowcount to -1 to avoid setting it when an exception is raised */
self->rowcount = -1; self->rowcount = -1;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO", kwlist, if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OO", kwlist,
&operation, &vars)) { &operation, &vars, &Place_holder)) {
return NULL; return NULL;
} }
if (Place_holder != NULL) {
if (!(Place_holder = curs_validate_sql_basic(self, Place_holder))) {
psyco_set_error(ProgrammingError, self, "can't parse place holder");
return NULL;
}
if (Bytes_GET_SIZE(Place_holder) != 1) {
psyco_set_error(ProgrammingError, self, "place holder must be a character");
return NULL;
}
place_holder = Bytes_AS_STRING(Place_holder)[0];
}
EXC_IF_CURS_CLOSED(self); EXC_IF_CURS_CLOSED(self);
EXC_IF_CURS_ASYNC(self, executemany); EXC_IF_CURS_ASYNC(self, executemany);
EXC_IF_TPC_PREPARED(self->conn, executemany); EXC_IF_TPC_PREPARED(self->conn, executemany);
@ -532,7 +594,7 @@ curs_executemany(cursorObject *self, PyObject *args, PyObject *kwargs)
} }
while ((v = PyIter_Next(vars)) != NULL) { while ((v = PyIter_Next(vars)) != NULL) {
if (0 > _psyco_curs_execute(self, operation, v, 0, 1)) { if (0 > _psyco_curs_execute(self, operation, v, place_holder, 0, 1)) {
Py_DECREF(v); Py_DECREF(v);
Py_XDECREF(iter); Py_XDECREF(iter);
return NULL; return NULL;
@ -562,7 +624,7 @@ curs_executemany(cursorObject *self, PyObject *args, PyObject *kwargs)
static PyObject * static PyObject *
_psyco_curs_mogrify(cursorObject *self, _psyco_curs_mogrify(cursorObject *self,
PyObject *operation, PyObject *vars) PyObject *operation, PyObject *vars, char place_holder)
{ {
PyObject *fquery = NULL, *cvt = NULL; PyObject *fquery = NULL, *cvt = NULL;
@ -577,13 +639,13 @@ _psyco_curs_mogrify(cursorObject *self,
if (vars && vars != Py_None) if (vars && vars != Py_None)
{ {
if (0 > _mogrify(vars, operation, self, &cvt)) { if (0 > _mogrify(vars, operation, self, &cvt, place_holder)) {
goto cleanup; goto cleanup;
} }
} }
if (vars && cvt) { if (vars && cvt) {
if (!(fquery = _psyco_curs_merge_query_args(self, operation, cvt))) { if (!(fquery = _psyco_curs_merge_query_args(self, operation, cvt, place_holder))) {
goto cleanup; goto cleanup;
} }
@ -606,16 +668,29 @@ cleanup:
static PyObject * static PyObject *
curs_mogrify(cursorObject *self, PyObject *args, PyObject *kwargs) curs_mogrify(cursorObject *self, PyObject *args, PyObject *kwargs)
{ {
PyObject *vars = NULL, *operation = NULL; PyObject *vars = NULL, *operation = NULL, *Place_holder = NULL;
char place_holder = '%';
static char *kwlist[] = {"query", "vars", NULL}; static char *kwlist[] = {"query", "vars", "place_holder", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O", kwlist, if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OO", kwlist,
&operation, &vars)) { &operation, &vars, &Place_holder)) {
return NULL; return NULL;
} }
return _psyco_curs_mogrify(self, operation, vars); if (Place_holder != NULL) {
if (!(Place_holder = curs_validate_sql_basic(self, Place_holder))) {
psyco_set_error(ProgrammingError, self, "can't parse place holder");
return NULL;
}
if (Bytes_GET_SIZE(Place_holder) != 1) {
psyco_set_error(ProgrammingError, self, "place holder must be a character");
return NULL;
}
place_holder = Bytes_AS_STRING(Place_holder)[0];
}
return _psyco_curs_mogrify(self, operation, vars, place_holder);
} }
@ -1016,12 +1091,26 @@ curs_callproc(cursorObject *self, PyObject *args)
PyObject *pvals = NULL; PyObject *pvals = NULL;
char *cpname = NULL; char *cpname = NULL;
char **scpnames = NULL; char **scpnames = NULL;
PyObject *Place_holder = NULL;
char place_holder = '%';
if (!PyArg_ParseTuple(args, "s#|O", &procname, &procname_len, if (!PyArg_ParseTuple(args, "s#|OO", &procname, &procname_len,
&parameters)) { &parameters, &Place_holder)) {
goto exit; goto exit;
} }
if (Place_holder != NULL) {
if (!(Place_holder = curs_validate_sql_basic(self, Place_holder))) {
psyco_set_error(ProgrammingError, self, "can't parse place holder");
return NULL;
}
if (Bytes_GET_SIZE(Place_holder) != 1) {
psyco_set_error(ProgrammingError, self, "place holder must be a character");
return NULL;
}
place_holder = Bytes_AS_STRING(Place_holder)[0];
}
EXC_IF_CURS_CLOSED(self); EXC_IF_CURS_CLOSED(self);
EXC_IF_ASYNC_IN_PROGRESS(self, callproc); EXC_IF_ASYNC_IN_PROGRESS(self, callproc);
EXC_IF_TPC_PREPARED(self->conn, callproc); EXC_IF_TPC_PREPARED(self->conn, callproc);
@ -1114,7 +1203,7 @@ curs_callproc(cursorObject *self, PyObject *args)
} }
if (0 <= _psyco_curs_execute( if (0 <= _psyco_curs_execute(
self, operation, pvals, self->conn->async, 0)) { self, operation, pvals, place_holder, self->conn->async, 0)) {
/* The dict case is outside DBAPI scope anyway, so simply return None */ /* The dict case is outside DBAPI scope anyway, so simply return None */
if (using_dict) { if (using_dict) {
res = Py_None; res = Py_None;

View File

@ -59,7 +59,7 @@ HIDDEN RAISES BORROWED PyObject *psyco_set_error(
HIDDEN PyObject *psyco_get_decimal_type(void); HIDDEN PyObject *psyco_get_decimal_type(void);
HIDDEN PyObject *Bytes_Format(PyObject *format, PyObject *args); HIDDEN PyObject *Bytes_Format(PyObject *format, PyObject *args, char place_holder);
#endif /* !defined(UTILS_H) */ #endif /* !defined(UTILS_H) */