first commit based on psycopg2 2.9 version
This commit is contained in:
104
tests/__init__.py
Executable file
104
tests/__init__.py
Executable file
@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# psycopg2 test suite
|
||||
#
|
||||
# Copyright (C) 2007-2019 Federico Di Gregorio <fog@debian.org>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# In addition, as a special exception, the copyright holders give
|
||||
# permission to link this program with the OpenSSL library (or with
|
||||
# modified versions of OpenSSL that use the same license as OpenSSL),
|
||||
# and distribute linked combinations including the two.
|
||||
#
|
||||
# You must obey the GNU Lesser General Public License in all respects for
|
||||
# all of the code used other than OpenSSL.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
# Convert warnings into errors here. We can't do it with -W because on
|
||||
# Travis importing site raises a warning.
|
||||
import warnings
|
||||
warnings.simplefilter('error') # noqa
|
||||
|
||||
import sys
|
||||
from .testconfig import dsn
|
||||
import unittest
|
||||
|
||||
from . import test_async
|
||||
from . import test_bugX000
|
||||
from . import test_bug_gc
|
||||
from . import test_cancel
|
||||
from . import test_connection
|
||||
from . import test_copy
|
||||
from . import test_cursor
|
||||
from . import test_dates
|
||||
from . import test_errcodes
|
||||
from . import test_errors
|
||||
from . import test_extras_dictcursor
|
||||
from . import test_fast_executemany
|
||||
from . import test_green
|
||||
from . import test_ipaddress
|
||||
from . import test_lobject
|
||||
from . import test_module
|
||||
from . import test_notify
|
||||
from . import test_psycopg2_dbapi20
|
||||
from . import test_quote
|
||||
from . import test_replication
|
||||
from . import test_sql
|
||||
from . import test_transaction
|
||||
from . import test_types_basic
|
||||
from . import test_types_extras
|
||||
from . import test_with
|
||||
|
||||
|
||||
def test_suite():
|
||||
# If connection to test db fails, bail out early.
|
||||
import psycopg2
|
||||
try:
|
||||
cnn = psycopg2.connect(dsn)
|
||||
except Exception as e:
|
||||
print("Failed connection to test db:", e.__class__.__name__, e)
|
||||
print("Please set env vars 'PSYCOPG2_TESTDB*' to valid values.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
cnn.close()
|
||||
|
||||
suite = unittest.TestSuite()
|
||||
suite.addTest(test_async.test_suite())
|
||||
suite.addTest(test_bugX000.test_suite())
|
||||
suite.addTest(test_bug_gc.test_suite())
|
||||
suite.addTest(test_cancel.test_suite())
|
||||
suite.addTest(test_connection.test_suite())
|
||||
suite.addTest(test_copy.test_suite())
|
||||
suite.addTest(test_cursor.test_suite())
|
||||
suite.addTest(test_dates.test_suite())
|
||||
suite.addTest(test_errcodes.test_suite())
|
||||
suite.addTest(test_errors.test_suite())
|
||||
suite.addTest(test_extras_dictcursor.test_suite())
|
||||
suite.addTest(test_fast_executemany.test_suite())
|
||||
suite.addTest(test_green.test_suite())
|
||||
suite.addTest(test_ipaddress.test_suite())
|
||||
suite.addTest(test_lobject.test_suite())
|
||||
suite.addTest(test_module.test_suite())
|
||||
suite.addTest(test_notify.test_suite())
|
||||
suite.addTest(test_psycopg2_dbapi20.test_suite())
|
||||
suite.addTest(test_quote.test_suite())
|
||||
suite.addTest(test_replication.test_suite())
|
||||
suite.addTest(test_sql.test_suite())
|
||||
suite.addTest(test_transaction.test_suite())
|
||||
suite.addTest(test_types_basic.test_suite())
|
||||
suite.addTest(test_types_extras.test_suite())
|
||||
suite.addTest(test_with.test_suite())
|
||||
return suite
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(defaultTest='test_suite')
|
||||
862
tests/dbapi20.py
Normal file
862
tests/dbapi20.py
Normal file
@ -0,0 +1,862 @@
|
||||
#!/usr/bin/env python
|
||||
''' Python DB API 2.0 driver compliance unit test suite.
|
||||
|
||||
This software is Public Domain and may be used without restrictions.
|
||||
|
||||
"Now we have booze and barflies entering the discussion, plus rumours of
|
||||
DBAs on drugs... and I won't tell you what flashes through my mind each
|
||||
time I read the subject line with 'Anal Compliance' in it. All around
|
||||
this is turning out to be a thoroughly unwholesome unit test."
|
||||
|
||||
-- Ian Bicking
|
||||
'''
|
||||
|
||||
__rcs_id__ = '$Id: dbapi20.py,v 1.11 2005/01/02 02:41:01 zenzen Exp $'
|
||||
__version__ = '$Revision: 1.12 $'[11:-2]
|
||||
__author__ = 'Stuart Bishop <stuart@stuartbishop.net>'
|
||||
|
||||
import unittest
|
||||
import time
|
||||
import sys
|
||||
|
||||
|
||||
# Revision 1.12 2009/02/06 03:35:11 kf7xm
|
||||
# Tested okay with Python 3.0, includes last minute patches from Mark H.
|
||||
#
|
||||
# Revision 1.1.1.1.2.1 2008/09/20 19:54:59 rupole
|
||||
# Include latest changes from main branch
|
||||
# Updates for py3k
|
||||
#
|
||||
# Revision 1.11 2005/01/02 02:41:01 zenzen
|
||||
# Update author email address
|
||||
#
|
||||
# Revision 1.10 2003/10/09 03:14:14 zenzen
|
||||
# Add test for DB API 2.0 optional extension, where database exceptions
|
||||
# are exposed as attributes on the Connection object.
|
||||
#
|
||||
# Revision 1.9 2003/08/13 01:16:36 zenzen
|
||||
# Minor tweak from Stefan Fleiter
|
||||
#
|
||||
# Revision 1.8 2003/04/10 00:13:25 zenzen
|
||||
# Changes, as per suggestions by M.-A. Lemburg
|
||||
# - Add a table prefix, to ensure namespace collisions can always be avoided
|
||||
#
|
||||
# Revision 1.7 2003/02/26 23:33:37 zenzen
|
||||
# Break out DDL into helper functions, as per request by David Rushby
|
||||
#
|
||||
# Revision 1.6 2003/02/21 03:04:33 zenzen
|
||||
# Stuff from Henrik Ekelund:
|
||||
# added test_None
|
||||
# added test_nextset & hooks
|
||||
#
|
||||
# Revision 1.5 2003/02/17 22:08:43 zenzen
|
||||
# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize
|
||||
# defaults to 1 & generic cursor.callproc test added
|
||||
#
|
||||
# Revision 1.4 2003/02/15 00:16:33 zenzen
|
||||
# Changes, as per suggestions and bug reports by M.-A. Lemburg,
|
||||
# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar
|
||||
# - Class renamed
|
||||
# - Now a subclass of TestCase, to avoid requiring the driver stub
|
||||
# to use multiple inheritance
|
||||
# - Reversed the polarity of buggy test in test_description
|
||||
# - Test exception hierarchy correctly
|
||||
# - self.populate is now self._populate(), so if a driver stub
|
||||
# overrides self.ddl1 this change propagates
|
||||
# - VARCHAR columns now have a width, which will hopefully make the
|
||||
# DDL even more portable (this will be reversed if it causes more problems)
|
||||
# - cursor.rowcount being checked after various execute and fetchXXX methods
|
||||
# - Check for fetchall and fetchmany returning empty lists after results
|
||||
# are exhausted (already checking for empty lists if select retrieved
|
||||
# nothing
|
||||
# - Fix bugs in test_setoutputsize_basic and test_setinputsizes
|
||||
#
|
||||
|
||||
class DatabaseAPI20Test(unittest.TestCase):
|
||||
''' Test a database self.driver for DB API 2.0 compatibility.
|
||||
This implementation tests Gadfly, but the TestCase
|
||||
is structured so that other self.drivers can subclass this
|
||||
test case to ensure compliance with the DB-API. It is
|
||||
expected that this TestCase may be expanded in the future
|
||||
if ambiguities or edge conditions are discovered.
|
||||
|
||||
The 'Optional Extensions' are not yet being tested.
|
||||
|
||||
self.drivers should subclass this test, overriding setUp, tearDown,
|
||||
self.driver, connect_args and connect_kw_args. Class specification
|
||||
should be as follows:
|
||||
|
||||
from . import dbapi20
|
||||
class mytest(dbapi20.DatabaseAPI20Test):
|
||||
[...]
|
||||
|
||||
Don't 'from .dbapi20 import DatabaseAPI20Test', or you will
|
||||
confuse the unit tester - just 'from . import dbapi20'.
|
||||
'''
|
||||
|
||||
# The self.driver module. This should be the module where the 'connect'
|
||||
# method is to be found
|
||||
driver = None
|
||||
connect_args = () # List of arguments to pass to connect
|
||||
connect_kw_args = {} # Keyword arguments for connect
|
||||
table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables
|
||||
|
||||
ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix
|
||||
ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix
|
||||
xddl1 = 'drop table %sbooze' % table_prefix
|
||||
xddl2 = 'drop table %sbarflys' % table_prefix
|
||||
|
||||
lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase
|
||||
|
||||
# Some drivers may need to override these helpers, for example adding
|
||||
# a 'commit' after the execute.
|
||||
def executeDDL1(self,cursor):
|
||||
cursor.execute(self.ddl1)
|
||||
|
||||
def executeDDL2(self,cursor):
|
||||
cursor.execute(self.ddl2)
|
||||
|
||||
def setUp(self):
|
||||
''' self.drivers should override this method to perform required setup
|
||||
if any is necessary, such as creating the database.
|
||||
'''
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
''' self.drivers should override this method to perform required cleanup
|
||||
if any is necessary, such as deleting the test database.
|
||||
The default drops the tables that may be created.
|
||||
'''
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
for ddl in (self.xddl1,self.xddl2):
|
||||
try:
|
||||
cur.execute(ddl)
|
||||
con.commit()
|
||||
except self.driver.Error:
|
||||
# Assume table didn't exist. Other tests will check if
|
||||
# execute is busted.
|
||||
pass
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def _connect(self):
|
||||
try:
|
||||
return self.driver.connect(
|
||||
*self.connect_args,**self.connect_kw_args
|
||||
)
|
||||
except AttributeError:
|
||||
self.fail("No connect method found in self.driver module")
|
||||
|
||||
def test_connect(self):
|
||||
con = self._connect()
|
||||
con.close()
|
||||
|
||||
def test_apilevel(self):
|
||||
try:
|
||||
# Must exist
|
||||
apilevel = self.driver.apilevel
|
||||
# Must equal 2.0
|
||||
self.assertEqual(apilevel,'2.0')
|
||||
except AttributeError:
|
||||
self.fail("Driver doesn't define apilevel")
|
||||
|
||||
def test_threadsafety(self):
|
||||
try:
|
||||
# Must exist
|
||||
threadsafety = self.driver.threadsafety
|
||||
# Must be a valid value
|
||||
self.failUnless(threadsafety in (0,1,2,3))
|
||||
except AttributeError:
|
||||
self.fail("Driver doesn't define threadsafety")
|
||||
|
||||
def test_paramstyle(self):
|
||||
try:
|
||||
# Must exist
|
||||
paramstyle = self.driver.paramstyle
|
||||
# Must be a valid value
|
||||
self.failUnless(paramstyle in (
|
||||
'qmark','numeric','named','format','pyformat'
|
||||
))
|
||||
except AttributeError:
|
||||
self.fail("Driver doesn't define paramstyle")
|
||||
|
||||
def test_Exceptions(self):
|
||||
# Make sure required exceptions exist, and are in the
|
||||
# defined hierarchy.
|
||||
self.failUnless(issubclass(self.driver.Warning,Exception))
|
||||
self.failUnless(issubclass(self.driver.Error,Exception))
|
||||
self.failUnless(
|
||||
issubclass(self.driver.InterfaceError,self.driver.Error)
|
||||
)
|
||||
self.failUnless(
|
||||
issubclass(self.driver.DatabaseError,self.driver.Error)
|
||||
)
|
||||
self.failUnless(
|
||||
issubclass(self.driver.OperationalError,self.driver.Error)
|
||||
)
|
||||
self.failUnless(
|
||||
issubclass(self.driver.IntegrityError,self.driver.Error)
|
||||
)
|
||||
self.failUnless(
|
||||
issubclass(self.driver.InternalError,self.driver.Error)
|
||||
)
|
||||
self.failUnless(
|
||||
issubclass(self.driver.ProgrammingError,self.driver.Error)
|
||||
)
|
||||
self.failUnless(
|
||||
issubclass(self.driver.NotSupportedError,self.driver.Error)
|
||||
)
|
||||
|
||||
def test_ExceptionsAsConnectionAttributes(self):
|
||||
# OPTIONAL EXTENSION
|
||||
# Test for the optional DB API 2.0 extension, where the exceptions
|
||||
# are exposed as attributes on the Connection object
|
||||
# I figure this optional extension will be implemented by any
|
||||
# driver author who is using this test suite, so it is enabled
|
||||
# by default.
|
||||
con = self._connect()
|
||||
drv = self.driver
|
||||
self.failUnless(con.Warning is drv.Warning)
|
||||
self.failUnless(con.Error is drv.Error)
|
||||
self.failUnless(con.InterfaceError is drv.InterfaceError)
|
||||
self.failUnless(con.DatabaseError is drv.DatabaseError)
|
||||
self.failUnless(con.OperationalError is drv.OperationalError)
|
||||
self.failUnless(con.IntegrityError is drv.IntegrityError)
|
||||
self.failUnless(con.InternalError is drv.InternalError)
|
||||
self.failUnless(con.ProgrammingError is drv.ProgrammingError)
|
||||
self.failUnless(con.NotSupportedError is drv.NotSupportedError)
|
||||
|
||||
|
||||
def test_commit(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
# Commit must work, even if it doesn't do anything
|
||||
con.commit()
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_rollback(self):
|
||||
con = self._connect()
|
||||
# If rollback is defined, it should either work or throw
|
||||
# the documented exception
|
||||
if hasattr(con,'rollback'):
|
||||
try:
|
||||
con.rollback()
|
||||
except self.driver.NotSupportedError:
|
||||
pass
|
||||
|
||||
def test_cursor(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_cursor_isolation(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
# Make sure cursors created from the same connection have
|
||||
# the documented transaction isolation level
|
||||
cur1 = con.cursor()
|
||||
cur2 = con.cursor()
|
||||
self.executeDDL1(cur1)
|
||||
cur1.execute("insert into %sbooze values ('Victoria Bitter')" % (
|
||||
self.table_prefix
|
||||
))
|
||||
cur2.execute("select name from %sbooze" % self.table_prefix)
|
||||
booze = cur2.fetchall()
|
||||
self.assertEqual(len(booze),1)
|
||||
self.assertEqual(len(booze[0]),1)
|
||||
self.assertEqual(booze[0][0],'Victoria Bitter')
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_description(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
self.executeDDL1(cur)
|
||||
self.assertEqual(cur.description,None,
|
||||
'cursor.description should be none after executing a '
|
||||
'statement that can return no rows (such as DDL)'
|
||||
)
|
||||
cur.execute('select name from %sbooze' % self.table_prefix)
|
||||
self.assertEqual(len(cur.description),1,
|
||||
'cursor.description describes too many columns'
|
||||
)
|
||||
self.assertEqual(len(cur.description[0]),7,
|
||||
'cursor.description[x] tuples must have 7 elements'
|
||||
)
|
||||
self.assertEqual(cur.description[0][0].lower(),'name',
|
||||
'cursor.description[x][0] must return column name'
|
||||
)
|
||||
self.assertEqual(cur.description[0][1],self.driver.STRING,
|
||||
'cursor.description[x][1] must return column type. Got %r'
|
||||
% cur.description[0][1]
|
||||
)
|
||||
|
||||
# Make sure self.description gets reset
|
||||
self.executeDDL2(cur)
|
||||
self.assertEqual(cur.description,None,
|
||||
'cursor.description not being set to None when executing '
|
||||
'no-result statements (eg. DDL)'
|
||||
)
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_rowcount(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
self.executeDDL1(cur)
|
||||
self.assertEqual(cur.rowcount,-1,
|
||||
'cursor.rowcount should be -1 after executing no-result '
|
||||
'statements'
|
||||
)
|
||||
cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
|
||||
self.table_prefix
|
||||
))
|
||||
self.failUnless(cur.rowcount in (-1,1),
|
||||
'cursor.rowcount should == number or rows inserted, or '
|
||||
'set to -1 after executing an insert statement'
|
||||
)
|
||||
cur.execute("select name from %sbooze" % self.table_prefix)
|
||||
self.failUnless(cur.rowcount in (-1,1),
|
||||
'cursor.rowcount should == number of rows returned, or '
|
||||
'set to -1 after executing a select statement'
|
||||
)
|
||||
self.executeDDL2(cur)
|
||||
self.assertEqual(cur.rowcount,-1,
|
||||
'cursor.rowcount not being reset to -1 after executing '
|
||||
'no-result statements'
|
||||
)
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
lower_func = 'lower'
|
||||
def test_callproc(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
if self.lower_func and hasattr(cur,'callproc'):
|
||||
r = cur.callproc(self.lower_func,('FOO',))
|
||||
self.assertEqual(len(r),1)
|
||||
self.assertEqual(r[0],'FOO')
|
||||
r = cur.fetchall()
|
||||
self.assertEqual(len(r),1,'callproc produced no result set')
|
||||
self.assertEqual(len(r[0]),1,
|
||||
'callproc produced invalid result set'
|
||||
)
|
||||
self.assertEqual(r[0][0],'foo',
|
||||
'callproc produced invalid results'
|
||||
)
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_close(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
# cursor.execute should raise an Error if called after connection
|
||||
# closed
|
||||
self.assertRaises(self.driver.Error,self.executeDDL1,cur)
|
||||
|
||||
# connection.commit should raise an Error if called after connection'
|
||||
# closed.'
|
||||
self.assertRaises(self.driver.Error,con.commit)
|
||||
|
||||
# connection.close should raise an Error if called more than once
|
||||
# Issue discussed on DB-SIG: consensus seem that close() should not
|
||||
# raised if called on closed objects. Issue reported back to Stuart.
|
||||
# self.assertRaises(self.driver.Error,con.close)
|
||||
|
||||
def test_execute(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
self._paraminsert(cur)
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def _paraminsert(self,cur):
|
||||
self.executeDDL1(cur)
|
||||
cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
|
||||
self.table_prefix
|
||||
))
|
||||
self.failUnless(cur.rowcount in (-1,1))
|
||||
|
||||
if self.driver.paramstyle == 'qmark':
|
||||
cur.execute(
|
||||
'insert into %sbooze values (?)' % self.table_prefix,
|
||||
("Cooper's",)
|
||||
)
|
||||
elif self.driver.paramstyle == 'numeric':
|
||||
cur.execute(
|
||||
'insert into %sbooze values (:1)' % self.table_prefix,
|
||||
("Cooper's",)
|
||||
)
|
||||
elif self.driver.paramstyle == 'named':
|
||||
cur.execute(
|
||||
'insert into %sbooze values (:beer)' % self.table_prefix,
|
||||
{'beer':"Cooper's"}
|
||||
)
|
||||
elif self.driver.paramstyle == 'format':
|
||||
cur.execute(
|
||||
'insert into %sbooze values (%%s)' % self.table_prefix,
|
||||
("Cooper's",)
|
||||
)
|
||||
elif self.driver.paramstyle == 'pyformat':
|
||||
cur.execute(
|
||||
'insert into %sbooze values (%%(beer)s)' % self.table_prefix,
|
||||
{'beer':"Cooper's"}
|
||||
)
|
||||
else:
|
||||
self.fail('Invalid paramstyle')
|
||||
self.failUnless(cur.rowcount in (-1,1))
|
||||
|
||||
cur.execute('select name from %sbooze' % self.table_prefix)
|
||||
res = cur.fetchall()
|
||||
self.assertEqual(len(res),2,'cursor.fetchall returned too few rows')
|
||||
beers = [res[0][0],res[1][0]]
|
||||
beers.sort()
|
||||
self.assertEqual(beers[0],"Cooper's",
|
||||
'cursor.fetchall retrieved incorrect data, or data inserted '
|
||||
'incorrectly'
|
||||
)
|
||||
self.assertEqual(beers[1],"Victoria Bitter",
|
||||
'cursor.fetchall retrieved incorrect data, or data inserted '
|
||||
'incorrectly'
|
||||
)
|
||||
|
||||
def test_executemany(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
self.executeDDL1(cur)
|
||||
largs = [ ("Cooper's",) , ("Boag's",) ]
|
||||
margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ]
|
||||
if self.driver.paramstyle == 'qmark':
|
||||
cur.executemany(
|
||||
'insert into %sbooze values (?)' % self.table_prefix,
|
||||
largs
|
||||
)
|
||||
elif self.driver.paramstyle == 'numeric':
|
||||
cur.executemany(
|
||||
'insert into %sbooze values (:1)' % self.table_prefix,
|
||||
largs
|
||||
)
|
||||
elif self.driver.paramstyle == 'named':
|
||||
cur.executemany(
|
||||
'insert into %sbooze values (:beer)' % self.table_prefix,
|
||||
margs
|
||||
)
|
||||
elif self.driver.paramstyle == 'format':
|
||||
cur.executemany(
|
||||
'insert into %sbooze values (%%s)' % self.table_prefix,
|
||||
largs
|
||||
)
|
||||
elif self.driver.paramstyle == 'pyformat':
|
||||
cur.executemany(
|
||||
'insert into %sbooze values (%%(beer)s)' % (
|
||||
self.table_prefix
|
||||
),
|
||||
margs
|
||||
)
|
||||
else:
|
||||
self.fail('Unknown paramstyle')
|
||||
self.failUnless(cur.rowcount in (-1,2),
|
||||
'insert using cursor.executemany set cursor.rowcount to '
|
||||
'incorrect value %r' % cur.rowcount
|
||||
)
|
||||
cur.execute('select name from %sbooze' % self.table_prefix)
|
||||
res = cur.fetchall()
|
||||
self.assertEqual(len(res),2,
|
||||
'cursor.fetchall retrieved incorrect number of rows'
|
||||
)
|
||||
beers = [res[0][0],res[1][0]]
|
||||
beers.sort()
|
||||
self.assertEqual(beers[0],"Boag's",'incorrect data retrieved')
|
||||
self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved')
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_fetchone(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
|
||||
# cursor.fetchone should raise an Error if called before
|
||||
# executing a select-type query
|
||||
self.assertRaises(self.driver.Error,cur.fetchone)
|
||||
|
||||
# cursor.fetchone should raise an Error if called after
|
||||
# executing a query that cannot return rows
|
||||
self.executeDDL1(cur)
|
||||
self.assertRaises(self.driver.Error,cur.fetchone)
|
||||
|
||||
cur.execute('select name from %sbooze' % self.table_prefix)
|
||||
self.assertEqual(cur.fetchone(),None,
|
||||
'cursor.fetchone should return None if a query retrieves '
|
||||
'no rows'
|
||||
)
|
||||
self.failUnless(cur.rowcount in (-1,0))
|
||||
|
||||
# cursor.fetchone should raise an Error if called after
|
||||
# executing a query that cannot return rows
|
||||
cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
|
||||
self.table_prefix
|
||||
))
|
||||
self.assertRaises(self.driver.Error,cur.fetchone)
|
||||
|
||||
cur.execute('select name from %sbooze' % self.table_prefix)
|
||||
r = cur.fetchone()
|
||||
self.assertEqual(len(r),1,
|
||||
'cursor.fetchone should have retrieved a single row'
|
||||
)
|
||||
self.assertEqual(r[0],'Victoria Bitter',
|
||||
'cursor.fetchone retrieved incorrect data'
|
||||
)
|
||||
self.assertEqual(cur.fetchone(),None,
|
||||
'cursor.fetchone should return None if no more rows available'
|
||||
)
|
||||
self.failUnless(cur.rowcount in (-1,1))
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
samples = [
|
||||
'Carlton Cold',
|
||||
'Carlton Draft',
|
||||
'Mountain Goat',
|
||||
'Redback',
|
||||
'Victoria Bitter',
|
||||
'XXXX'
|
||||
]
|
||||
|
||||
def _populate(self):
|
||||
''' Return a list of sql commands to setup the DB for the fetch
|
||||
tests.
|
||||
'''
|
||||
populate = [
|
||||
f"insert into {self.table_prefix}booze values ('{s}')"
|
||||
for s in self.samples
|
||||
]
|
||||
return populate
|
||||
|
||||
def test_fetchmany(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
|
||||
# cursor.fetchmany should raise an Error if called without
|
||||
#issuing a query
|
||||
self.assertRaises(self.driver.Error,cur.fetchmany,4)
|
||||
|
||||
self.executeDDL1(cur)
|
||||
for sql in self._populate():
|
||||
cur.execute(sql)
|
||||
|
||||
cur.execute('select name from %sbooze' % self.table_prefix)
|
||||
r = cur.fetchmany()
|
||||
self.assertEqual(len(r),1,
|
||||
'cursor.fetchmany retrieved incorrect number of rows, '
|
||||
'default of arraysize is one.'
|
||||
)
|
||||
cur.arraysize=10
|
||||
r = cur.fetchmany(3) # Should get 3 rows
|
||||
self.assertEqual(len(r),3,
|
||||
'cursor.fetchmany retrieved incorrect number of rows'
|
||||
)
|
||||
r = cur.fetchmany(4) # Should get 2 more
|
||||
self.assertEqual(len(r),2,
|
||||
'cursor.fetchmany retrieved incorrect number of rows'
|
||||
)
|
||||
r = cur.fetchmany(4) # Should be an empty sequence
|
||||
self.assertEqual(len(r),0,
|
||||
'cursor.fetchmany should return an empty sequence after '
|
||||
'results are exhausted'
|
||||
)
|
||||
self.failUnless(cur.rowcount in (-1,6))
|
||||
|
||||
# Same as above, using cursor.arraysize
|
||||
cur.arraysize=4
|
||||
cur.execute('select name from %sbooze' % self.table_prefix)
|
||||
r = cur.fetchmany() # Should get 4 rows
|
||||
self.assertEqual(len(r),4,
|
||||
'cursor.arraysize not being honoured by fetchmany'
|
||||
)
|
||||
r = cur.fetchmany() # Should get 2 more
|
||||
self.assertEqual(len(r),2)
|
||||
r = cur.fetchmany() # Should be an empty sequence
|
||||
self.assertEqual(len(r),0)
|
||||
self.failUnless(cur.rowcount in (-1,6))
|
||||
|
||||
cur.arraysize=6
|
||||
cur.execute('select name from %sbooze' % self.table_prefix)
|
||||
rows = cur.fetchmany() # Should get all rows
|
||||
self.failUnless(cur.rowcount in (-1,6))
|
||||
self.assertEqual(len(rows),6)
|
||||
self.assertEqual(len(rows),6)
|
||||
rows = [r[0] for r in rows]
|
||||
rows.sort()
|
||||
|
||||
# Make sure we get the right data back out
|
||||
for i in range(0,6):
|
||||
self.assertEqual(rows[i],self.samples[i],
|
||||
'incorrect data retrieved by cursor.fetchmany'
|
||||
)
|
||||
|
||||
rows = cur.fetchmany() # Should return an empty list
|
||||
self.assertEqual(len(rows),0,
|
||||
'cursor.fetchmany should return an empty sequence if '
|
||||
'called after the whole result set has been fetched'
|
||||
)
|
||||
self.failUnless(cur.rowcount in (-1,6))
|
||||
|
||||
self.executeDDL2(cur)
|
||||
cur.execute('select name from %sbarflys' % self.table_prefix)
|
||||
r = cur.fetchmany() # Should get empty sequence
|
||||
self.assertEqual(len(r),0,
|
||||
'cursor.fetchmany should return an empty sequence if '
|
||||
'query retrieved no rows'
|
||||
)
|
||||
self.failUnless(cur.rowcount in (-1,0))
|
||||
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_fetchall(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
# cursor.fetchall should raise an Error if called
|
||||
# without executing a query that may return rows (such
|
||||
# as a select)
|
||||
self.assertRaises(self.driver.Error, cur.fetchall)
|
||||
|
||||
self.executeDDL1(cur)
|
||||
for sql in self._populate():
|
||||
cur.execute(sql)
|
||||
|
||||
# cursor.fetchall should raise an Error if called
|
||||
# after executing a a statement that cannot return rows
|
||||
self.assertRaises(self.driver.Error,cur.fetchall)
|
||||
|
||||
cur.execute('select name from %sbooze' % self.table_prefix)
|
||||
rows = cur.fetchall()
|
||||
self.failUnless(cur.rowcount in (-1,len(self.samples)))
|
||||
self.assertEqual(len(rows),len(self.samples),
|
||||
'cursor.fetchall did not retrieve all rows'
|
||||
)
|
||||
rows = [r[0] for r in rows]
|
||||
rows.sort()
|
||||
for i in range(0,len(self.samples)):
|
||||
self.assertEqual(rows[i],self.samples[i],
|
||||
'cursor.fetchall retrieved incorrect rows'
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
self.assertEqual(
|
||||
len(rows),0,
|
||||
'cursor.fetchall should return an empty list if called '
|
||||
'after the whole result set has been fetched'
|
||||
)
|
||||
self.failUnless(cur.rowcount in (-1,len(self.samples)))
|
||||
|
||||
self.executeDDL2(cur)
|
||||
cur.execute('select name from %sbarflys' % self.table_prefix)
|
||||
rows = cur.fetchall()
|
||||
self.failUnless(cur.rowcount in (-1,0))
|
||||
self.assertEqual(len(rows),0,
|
||||
'cursor.fetchall should return an empty list if '
|
||||
'a select query returns no rows'
|
||||
)
|
||||
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_mixedfetch(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
self.executeDDL1(cur)
|
||||
for sql in self._populate():
|
||||
cur.execute(sql)
|
||||
|
||||
cur.execute('select name from %sbooze' % self.table_prefix)
|
||||
rows1 = cur.fetchone()
|
||||
rows23 = cur.fetchmany(2)
|
||||
rows4 = cur.fetchone()
|
||||
rows56 = cur.fetchall()
|
||||
self.failUnless(cur.rowcount in (-1,6))
|
||||
self.assertEqual(len(rows23),2,
|
||||
'fetchmany returned incorrect number of rows'
|
||||
)
|
||||
self.assertEqual(len(rows56),2,
|
||||
'fetchall returned incorrect number of rows'
|
||||
)
|
||||
|
||||
rows = [rows1[0]]
|
||||
rows.extend([rows23[0][0],rows23[1][0]])
|
||||
rows.append(rows4[0])
|
||||
rows.extend([rows56[0][0],rows56[1][0]])
|
||||
rows.sort()
|
||||
for i in range(0,len(self.samples)):
|
||||
self.assertEqual(rows[i],self.samples[i],
|
||||
'incorrect data retrieved or inserted'
|
||||
)
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def help_nextset_setUp(self,cur):
|
||||
''' Should create a procedure called deleteme
|
||||
that returns two result sets, first the
|
||||
number of rows in booze then "name from booze"
|
||||
'''
|
||||
raise NotImplementedError('Helper not implemented')
|
||||
#sql="""
|
||||
# create procedure deleteme as
|
||||
# begin
|
||||
# select count(*) from booze
|
||||
# select name from booze
|
||||
# end
|
||||
#"""
|
||||
#cur.execute(sql)
|
||||
|
||||
def help_nextset_tearDown(self,cur):
|
||||
'If cleaning up is needed after nextSetTest'
|
||||
raise NotImplementedError('Helper not implemented')
|
||||
#cur.execute("drop procedure deleteme")
|
||||
|
||||
def test_nextset(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
if not hasattr(cur,'nextset'):
|
||||
return
|
||||
|
||||
try:
|
||||
self.executeDDL1(cur)
|
||||
sql=self._populate()
|
||||
for sql in self._populate():
|
||||
cur.execute(sql)
|
||||
|
||||
self.help_nextset_setUp(cur)
|
||||
|
||||
cur.callproc('deleteme')
|
||||
numberofrows=cur.fetchone()
|
||||
assert numberofrows[0]== len(self.samples)
|
||||
assert cur.nextset()
|
||||
names=cur.fetchall()
|
||||
assert len(names) == len(self.samples)
|
||||
s=cur.nextset()
|
||||
assert s is None, 'No more return sets, should return None'
|
||||
finally:
|
||||
self.help_nextset_tearDown(cur)
|
||||
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_nextset(self):
|
||||
raise NotImplementedError('Drivers need to override this test')
|
||||
|
||||
def test_arraysize(self):
|
||||
# Not much here - rest of the tests for this are in test_fetchmany
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
self.failUnless(hasattr(cur,'arraysize'),
|
||||
'cursor.arraysize must be defined'
|
||||
)
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_setinputsizes(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
cur.setinputsizes( (25,) )
|
||||
self._paraminsert(cur) # Make sure cursor still works
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_setoutputsize_basic(self):
|
||||
# Basic test is to make sure setoutputsize doesn't blow up
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
cur.setoutputsize(1000)
|
||||
cur.setoutputsize(2000,0)
|
||||
self._paraminsert(cur) # Make sure the cursor still works
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_setoutputsize(self):
|
||||
# Real test for setoutputsize is driver dependent
|
||||
raise NotImplementedError('Driver needed to override this test')
|
||||
|
||||
def test_None(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
self.executeDDL1(cur)
|
||||
cur.execute('insert into %sbooze values (NULL)' % self.table_prefix)
|
||||
cur.execute('select name from %sbooze' % self.table_prefix)
|
||||
r = cur.fetchall()
|
||||
self.assertEqual(len(r),1)
|
||||
self.assertEqual(len(r[0]),1)
|
||||
self.assertEqual(r[0][0],None,'NULL value not returned as None')
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_Date(self):
|
||||
d1 = self.driver.Date(2002,12,25)
|
||||
d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0)))
|
||||
# Can we assume this? API doesn't specify, but it seems implied
|
||||
# self.assertEqual(str(d1),str(d2))
|
||||
|
||||
def test_Time(self):
|
||||
t1 = self.driver.Time(13,45,30)
|
||||
t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0)))
|
||||
# Can we assume this? API doesn't specify, but it seems implied
|
||||
# self.assertEqual(str(t1),str(t2))
|
||||
|
||||
def test_Timestamp(self):
|
||||
t1 = self.driver.Timestamp(2002,12,25,13,45,30)
|
||||
t2 = self.driver.TimestampFromTicks(
|
||||
time.mktime((2002,12,25,13,45,30,0,0,0))
|
||||
)
|
||||
# Can we assume this? API doesn't specify, but it seems implied
|
||||
# self.assertEqual(str(t1),str(t2))
|
||||
|
||||
def test_Binary(self):
|
||||
b = self.driver.Binary(b'Something')
|
||||
b = self.driver.Binary(b'')
|
||||
|
||||
def test_STRING(self):
|
||||
self.failUnless(hasattr(self.driver,'STRING'),
|
||||
'module.STRING must be defined'
|
||||
)
|
||||
|
||||
def test_BINARY(self):
|
||||
self.failUnless(hasattr(self.driver,'BINARY'),
|
||||
'module.BINARY must be defined.'
|
||||
)
|
||||
|
||||
def test_NUMBER(self):
|
||||
self.failUnless(hasattr(self.driver,'NUMBER'),
|
||||
'module.NUMBER must be defined.'
|
||||
)
|
||||
|
||||
def test_DATETIME(self):
|
||||
self.failUnless(hasattr(self.driver,'DATETIME'),
|
||||
'module.DATETIME must be defined.'
|
||||
)
|
||||
|
||||
def test_ROWID(self):
|
||||
self.failUnless(hasattr(self.driver,'ROWID'),
|
||||
'module.ROWID must be defined.'
|
||||
)
|
||||
144
tests/dbapi20_tpc.py
Normal file
144
tests/dbapi20_tpc.py
Normal file
@ -0,0 +1,144 @@
|
||||
""" Python DB API 2.0 driver Two Phase Commit compliance test suite.
|
||||
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class TwoPhaseCommitTests(unittest.TestCase):
|
||||
|
||||
driver = None
|
||||
|
||||
def connect(self):
|
||||
"""Make a database connection."""
|
||||
raise NotImplementedError
|
||||
|
||||
_last_id = 0
|
||||
_global_id_prefix = "dbapi20_tpc:"
|
||||
|
||||
def make_xid(self, con):
|
||||
id = TwoPhaseCommitTests._last_id
|
||||
TwoPhaseCommitTests._last_id += 1
|
||||
return con.xid(42, f"{self._global_id_prefix}{id}", "qualifier")
|
||||
|
||||
def test_xid(self):
|
||||
con = self.connect()
|
||||
try:
|
||||
xid = con.xid(42, "global", "bqual")
|
||||
except self.driver.NotSupportedError:
|
||||
self.fail("Driver does not support transaction IDs.")
|
||||
|
||||
self.assertEquals(xid[0], 42)
|
||||
self.assertEquals(xid[1], "global")
|
||||
self.assertEquals(xid[2], "bqual")
|
||||
|
||||
# Try some extremes for the transaction ID:
|
||||
xid = con.xid(0, "", "")
|
||||
self.assertEquals(tuple(xid), (0, "", ""))
|
||||
xid = con.xid(0x7fffffff, "a" * 64, "b" * 64)
|
||||
self.assertEquals(tuple(xid), (0x7fffffff, "a" * 64, "b" * 64))
|
||||
|
||||
def test_tpc_begin(self):
|
||||
con = self.connect()
|
||||
try:
|
||||
xid = self.make_xid(con)
|
||||
try:
|
||||
con.tpc_begin(xid)
|
||||
except self.driver.NotSupportedError:
|
||||
self.fail("Driver does not support tpc_begin()")
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_tpc_commit_without_prepare(self):
|
||||
con = self.connect()
|
||||
try:
|
||||
xid = self.make_xid(con)
|
||||
con.tpc_begin(xid)
|
||||
cursor = con.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
con.tpc_commit()
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_tpc_rollback_without_prepare(self):
|
||||
con = self.connect()
|
||||
try:
|
||||
xid = self.make_xid(con)
|
||||
con.tpc_begin(xid)
|
||||
cursor = con.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
con.tpc_rollback()
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_tpc_commit_with_prepare(self):
|
||||
con = self.connect()
|
||||
try:
|
||||
xid = self.make_xid(con)
|
||||
con.tpc_begin(xid)
|
||||
cursor = con.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
con.tpc_prepare()
|
||||
con.tpc_commit()
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_tpc_rollback_with_prepare(self):
|
||||
con = self.connect()
|
||||
try:
|
||||
xid = self.make_xid(con)
|
||||
con.tpc_begin(xid)
|
||||
cursor = con.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
con.tpc_prepare()
|
||||
con.tpc_rollback()
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_tpc_begin_in_transaction_fails(self):
|
||||
con = self.connect()
|
||||
try:
|
||||
xid = self.make_xid(con)
|
||||
|
||||
cursor = con.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
self.assertRaises(self.driver.ProgrammingError,
|
||||
con.tpc_begin, xid)
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_tpc_begin_in_tpc_transaction_fails(self):
|
||||
con = self.connect()
|
||||
try:
|
||||
xid = self.make_xid(con)
|
||||
|
||||
cursor = con.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
self.assertRaises(self.driver.ProgrammingError,
|
||||
con.tpc_begin, xid)
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_commit_in_tpc_fails(self):
|
||||
# calling commit() within a TPC transaction fails with
|
||||
# ProgrammingError.
|
||||
con = self.connect()
|
||||
try:
|
||||
xid = self.make_xid(con)
|
||||
con.tpc_begin(xid)
|
||||
|
||||
self.assertRaises(self.driver.ProgrammingError, con.commit)
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_rollback_in_tpc_fails(self):
|
||||
# calling rollback() within a TPC transaction fails with
|
||||
# ProgrammingError.
|
||||
con = self.connect()
|
||||
try:
|
||||
xid = self.make_xid(con)
|
||||
con.tpc_begin(xid)
|
||||
|
||||
self.assertRaises(self.driver.ProgrammingError, con.rollback)
|
||||
finally:
|
||||
con.close()
|
||||
546
tests/test_async.py
Executable file
546
tests/test_async.py
Executable file
@ -0,0 +1,546 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# test_async.py - unit test for asynchronous API
|
||||
#
|
||||
# Copyright (C) 2010-2019 Jan Urbański <wulczer@wulczer.org>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# In addition, as a special exception, the copyright holders give
|
||||
# permission to link this program with the OpenSSL library (or with
|
||||
# modified versions of OpenSSL that use the same license as OpenSSL),
|
||||
# and distribute linked combinations including the two.
|
||||
#
|
||||
# You must obey the GNU Lesser General Public License in all respects for
|
||||
# all of the code used other than OpenSSL.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
import gc
|
||||
import time
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.errors
|
||||
from psycopg2 import extensions as ext
|
||||
|
||||
from .testutils import (ConnectingTestCase, StringIO, skip_before_postgres,
|
||||
skip_if_crdb, crdb_version, slow)
|
||||
|
||||
|
||||
class PollableStub:
|
||||
"""A 'pollable' wrapper allowing analysis of the `poll()` calls."""
|
||||
def __init__(self, pollable):
|
||||
self.pollable = pollable
|
||||
self.polls = []
|
||||
|
||||
def fileno(self):
|
||||
return self.pollable.fileno()
|
||||
|
||||
def poll(self):
|
||||
rv = self.pollable.poll()
|
||||
self.polls.append(rv)
|
||||
return rv
|
||||
|
||||
|
||||
class AsyncTests(ConnectingTestCase):
|
||||
|
||||
def setUp(self):
|
||||
ConnectingTestCase.setUp(self)
|
||||
|
||||
self.sync_conn = self.conn
|
||||
self.conn = self.connect(async_=True)
|
||||
|
||||
self.wait(self.conn)
|
||||
|
||||
curs = self.conn.cursor()
|
||||
if crdb_version(self.sync_conn) is not None:
|
||||
curs.execute("set experimental_enable_temp_tables = 'on'")
|
||||
self.wait(curs)
|
||||
|
||||
curs.execute('''
|
||||
CREATE TEMPORARY TABLE table1 (
|
||||
id int PRIMARY KEY
|
||||
)''')
|
||||
self.wait(curs)
|
||||
|
||||
def test_connection_setup(self):
|
||||
cur = self.conn.cursor()
|
||||
sync_cur = self.sync_conn.cursor()
|
||||
del cur, sync_cur
|
||||
|
||||
self.assert_(self.conn.async_)
|
||||
self.assert_(not self.sync_conn.async_)
|
||||
|
||||
# the async connection should be autocommit
|
||||
self.assert_(self.conn.autocommit)
|
||||
self.assertEquals(self.conn.isolation_level, ext.ISOLATION_LEVEL_DEFAULT)
|
||||
|
||||
# check other properties to be found on the connection
|
||||
self.assert_(self.conn.server_version)
|
||||
self.assert_(self.conn.protocol_version in (2, 3))
|
||||
self.assert_(self.conn.encoding in ext.encodings)
|
||||
|
||||
def test_async_named_cursor(self):
|
||||
self.assertRaises(psycopg2.ProgrammingError,
|
||||
self.conn.cursor, "name")
|
||||
|
||||
def test_async_select(self):
|
||||
cur = self.conn.cursor()
|
||||
self.assertFalse(self.conn.isexecuting())
|
||||
cur.execute("select 'a'")
|
||||
self.assertTrue(self.conn.isexecuting())
|
||||
|
||||
self.wait(cur)
|
||||
|
||||
self.assertFalse(self.conn.isexecuting())
|
||||
self.assertEquals(cur.fetchone()[0], "a")
|
||||
|
||||
@slow
|
||||
@skip_before_postgres(8, 2)
|
||||
def test_async_callproc(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.callproc("pg_sleep", (0.1, ))
|
||||
self.assertTrue(self.conn.isexecuting())
|
||||
|
||||
self.wait(cur)
|
||||
self.assertFalse(self.conn.isexecuting())
|
||||
|
||||
@slow
|
||||
def test_async_after_async(self):
|
||||
cur = self.conn.cursor()
|
||||
cur2 = self.conn.cursor()
|
||||
del cur2
|
||||
|
||||
cur.execute("insert into table1 values (1)")
|
||||
|
||||
# an async execute after an async one raises an exception
|
||||
self.assertRaises(psycopg2.ProgrammingError,
|
||||
cur.execute, "select * from table1")
|
||||
# same for callproc
|
||||
self.assertRaises(psycopg2.ProgrammingError,
|
||||
cur.callproc, "version")
|
||||
# but after you've waited it should be good
|
||||
self.wait(cur)
|
||||
cur.execute("select * from table1")
|
||||
self.wait(cur)
|
||||
|
||||
self.assertEquals(cur.fetchall()[0][0], 1)
|
||||
|
||||
cur.execute("delete from table1")
|
||||
self.wait(cur)
|
||||
|
||||
cur.execute("select * from table1")
|
||||
self.wait(cur)
|
||||
|
||||
self.assertEquals(cur.fetchone(), None)
|
||||
|
||||
def test_fetch_after_async(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("select 'a'")
|
||||
|
||||
# a fetch after an asynchronous query should raise an error
|
||||
self.assertRaises(psycopg2.ProgrammingError,
|
||||
cur.fetchall)
|
||||
# but after waiting it should work
|
||||
self.wait(cur)
|
||||
self.assertEquals(cur.fetchall()[0][0], "a")
|
||||
|
||||
def test_rollback_while_async(self):
|
||||
cur = self.conn.cursor()
|
||||
|
||||
cur.execute("select 'a'")
|
||||
|
||||
# a rollback should not work in asynchronous mode
|
||||
self.assertRaises(psycopg2.ProgrammingError, self.conn.rollback)
|
||||
|
||||
def test_commit_while_async(self):
|
||||
cur = self.conn.cursor()
|
||||
|
||||
cur.execute("begin")
|
||||
self.wait(cur)
|
||||
|
||||
cur.execute("insert into table1 values (1)")
|
||||
|
||||
# a commit should not work in asynchronous mode
|
||||
self.assertRaises(psycopg2.ProgrammingError, self.conn.commit)
|
||||
self.assertTrue(self.conn.isexecuting())
|
||||
|
||||
# but a manual commit should
|
||||
self.wait(cur)
|
||||
cur.execute("commit")
|
||||
self.wait(cur)
|
||||
|
||||
cur.execute("select * from table1")
|
||||
self.wait(cur)
|
||||
self.assertEquals(cur.fetchall()[0][0], 1)
|
||||
|
||||
cur.execute("delete from table1")
|
||||
self.wait(cur)
|
||||
|
||||
cur.execute("select * from table1")
|
||||
self.wait(cur)
|
||||
self.assertEquals(cur.fetchone(), None)
|
||||
|
||||
def test_set_parameters_while_async(self):
|
||||
cur = self.conn.cursor()
|
||||
|
||||
cur.execute("select 'c'")
|
||||
self.assertTrue(self.conn.isexecuting())
|
||||
|
||||
# getting transaction status works
|
||||
self.assertEquals(self.conn.info.transaction_status,
|
||||
ext.TRANSACTION_STATUS_ACTIVE)
|
||||
self.assertTrue(self.conn.isexecuting())
|
||||
|
||||
# setting connection encoding should fail
|
||||
self.assertRaises(psycopg2.ProgrammingError,
|
||||
self.conn.set_client_encoding, "LATIN1")
|
||||
|
||||
# same for transaction isolation
|
||||
self.assertRaises(psycopg2.ProgrammingError,
|
||||
self.conn.set_isolation_level, 1)
|
||||
|
||||
def test_reset_while_async(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("select 'c'")
|
||||
self.assertTrue(self.conn.isexecuting())
|
||||
|
||||
# a reset should fail
|
||||
self.assertRaises(psycopg2.ProgrammingError, self.conn.reset)
|
||||
|
||||
def test_async_iter(self):
|
||||
cur = self.conn.cursor()
|
||||
|
||||
cur.execute("begin")
|
||||
self.wait(cur)
|
||||
cur.execute("""
|
||||
insert into table1 values (1);
|
||||
insert into table1 values (2);
|
||||
insert into table1 values (3);
|
||||
""")
|
||||
self.wait(cur)
|
||||
cur.execute("select id from table1 order by id")
|
||||
|
||||
# iteration fails if a query is underway
|
||||
self.assertRaises(psycopg2.ProgrammingError, list, cur)
|
||||
|
||||
# but after it's done it should work
|
||||
self.wait(cur)
|
||||
self.assertEquals(list(cur), [(1, ), (2, ), (3, )])
|
||||
self.assertFalse(self.conn.isexecuting())
|
||||
|
||||
def test_copy_while_async(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("select 'a'")
|
||||
|
||||
# copy should fail
|
||||
self.assertRaises(psycopg2.ProgrammingError,
|
||||
cur.copy_from,
|
||||
StringIO("1\n3\n5\n\\.\n"), "table1")
|
||||
|
||||
def test_lobject_while_async(self):
|
||||
# large objects should be prohibited
|
||||
self.assertRaises(psycopg2.ProgrammingError,
|
||||
self.conn.lobject)
|
||||
|
||||
def test_async_executemany(self):
|
||||
cur = self.conn.cursor()
|
||||
self.assertRaises(
|
||||
psycopg2.ProgrammingError,
|
||||
cur.executemany, "insert into table1 values (%s)", [1, 2, 3])
|
||||
|
||||
def test_async_scroll(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("""
|
||||
insert into table1 values (1);
|
||||
insert into table1 values (2);
|
||||
insert into table1 values (3);
|
||||
""")
|
||||
self.wait(cur)
|
||||
cur.execute("select id from table1 order by id")
|
||||
|
||||
# scroll should fail if a query is underway
|
||||
self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 1)
|
||||
self.assertTrue(self.conn.isexecuting())
|
||||
|
||||
# but after it's done it should work
|
||||
self.wait(cur)
|
||||
cur.scroll(1)
|
||||
self.assertEquals(cur.fetchall(), [(2, ), (3, )])
|
||||
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("select id from table1 order by id")
|
||||
self.wait(cur)
|
||||
|
||||
cur2 = self.conn.cursor()
|
||||
self.assertRaises(psycopg2.ProgrammingError, cur2.scroll, 1)
|
||||
|
||||
self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 4)
|
||||
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("select id from table1 order by id")
|
||||
self.wait(cur)
|
||||
cur.scroll(2)
|
||||
cur.scroll(-1)
|
||||
self.assertEquals(cur.fetchall(), [(2, ), (3, )])
|
||||
|
||||
def test_scroll(self):
|
||||
cur = self.sync_conn.cursor()
|
||||
cur.execute("create table table1 (id int)")
|
||||
cur.execute("""
|
||||
insert into table1 values (1);
|
||||
insert into table1 values (2);
|
||||
insert into table1 values (3);
|
||||
""")
|
||||
cur.execute("select id from table1 order by id")
|
||||
cur.scroll(2)
|
||||
cur.scroll(-1)
|
||||
self.assertEquals(cur.fetchall(), [(2, ), (3, )])
|
||||
|
||||
def test_async_dont_read_all(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("select repeat('a', 10000); select repeat('b', 10000)")
|
||||
|
||||
# fetch the result
|
||||
self.wait(cur)
|
||||
|
||||
# it should be the result of the second query
|
||||
self.assertEquals(cur.fetchone()[0], "b" * 10000)
|
||||
|
||||
def test_async_subclass(self):
|
||||
class MyConn(ext.connection):
|
||||
def __init__(self, dsn, async_=0):
|
||||
ext.connection.__init__(self, dsn, async_=async_)
|
||||
|
||||
conn = self.connect(connection_factory=MyConn, async_=True)
|
||||
self.assert_(isinstance(conn, MyConn))
|
||||
self.assert_(conn.async_)
|
||||
conn.close()
|
||||
|
||||
@slow
|
||||
@skip_if_crdb("flush on write flakey")
|
||||
def test_flush_on_write(self):
|
||||
# a very large query requires a flush loop to be sent to the backend
|
||||
curs = self.conn.cursor()
|
||||
for mb in 1, 5, 10, 20, 50:
|
||||
size = mb * 1024 * 1024
|
||||
stub = PollableStub(self.conn)
|
||||
curs.execute("select %s;", ('x' * size,))
|
||||
self.wait(stub)
|
||||
self.assertEqual(size, len(curs.fetchone()[0]))
|
||||
if stub.polls.count(ext.POLL_WRITE) > 1:
|
||||
return
|
||||
|
||||
# This is more a testing glitch than an error: it happens
|
||||
# on high load on linux: probably because the kernel has more
|
||||
# buffers ready. A warning may be useful during development,
|
||||
# but an error is bad during regression testing.
|
||||
warnings.warn("sending a large query didn't trigger block on write.")
|
||||
|
||||
def test_sync_poll(self):
|
||||
cur = self.sync_conn.cursor()
|
||||
cur.execute("select 1")
|
||||
# polling with a sync query works
|
||||
cur.connection.poll()
|
||||
self.assertEquals(cur.fetchone()[0], 1)
|
||||
|
||||
@slow
|
||||
@skip_if_crdb("notify")
|
||||
def test_notify(self):
|
||||
cur = self.conn.cursor()
|
||||
sync_cur = self.sync_conn.cursor()
|
||||
|
||||
sync_cur.execute("listen test_notify")
|
||||
self.sync_conn.commit()
|
||||
cur.execute("notify test_notify")
|
||||
self.wait(cur)
|
||||
|
||||
self.assertEquals(self.sync_conn.notifies, [])
|
||||
|
||||
pid = self.conn.info.backend_pid
|
||||
for _ in range(5):
|
||||
self.wait(self.sync_conn)
|
||||
if not self.sync_conn.notifies:
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
self.assertEquals(len(self.sync_conn.notifies), 1)
|
||||
self.assertEquals(self.sync_conn.notifies.pop(),
|
||||
(pid, "test_notify"))
|
||||
return
|
||||
self.fail("No NOTIFY in 2.5 seconds")
|
||||
|
||||
def test_async_fetch_wrong_cursor(self):
|
||||
cur1 = self.conn.cursor()
|
||||
cur2 = self.conn.cursor()
|
||||
cur1.execute("select 1")
|
||||
|
||||
self.wait(cur1)
|
||||
self.assertFalse(self.conn.isexecuting())
|
||||
# fetching from a cursor with no results is an error
|
||||
self.assertRaises(psycopg2.ProgrammingError, cur2.fetchone)
|
||||
# fetching from the correct cursor works
|
||||
self.assertEquals(cur1.fetchone()[0], 1)
|
||||
|
||||
def test_error(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("insert into table1 values (%s)", (1, ))
|
||||
self.wait(cur)
|
||||
cur.execute("insert into table1 values (%s)", (1, ))
|
||||
# this should fail
|
||||
self.assertRaises(psycopg2.IntegrityError, self.wait, cur)
|
||||
cur.execute("insert into table1 values (%s); "
|
||||
"insert into table1 values (%s)", (2, 2))
|
||||
# this should fail as well (Postgres behaviour)
|
||||
self.assertRaises(psycopg2.IntegrityError, self.wait, cur)
|
||||
# but this should work
|
||||
if crdb_version(self.sync_conn) is None:
|
||||
cur.execute("insert into table1 values (%s)", (2, ))
|
||||
self.wait(cur)
|
||||
# and the cursor should be usable afterwards
|
||||
cur.execute("insert into table1 values (%s)", (3, ))
|
||||
self.wait(cur)
|
||||
cur.execute("select * from table1 order by id")
|
||||
self.wait(cur)
|
||||
self.assertEquals(cur.fetchall(), [(1, ), (2, ), (3, )])
|
||||
cur.execute("delete from table1")
|
||||
self.wait(cur)
|
||||
|
||||
def test_stop_on_first_error(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("select 1; select x; select 1/0; select 2")
|
||||
self.assertRaises(psycopg2.errors.UndefinedColumn, self.wait, cur)
|
||||
|
||||
cur.execute("select 1")
|
||||
self.wait(cur)
|
||||
self.assertEqual(cur.fetchone(), (1,))
|
||||
|
||||
def test_error_two_cursors(self):
|
||||
cur = self.conn.cursor()
|
||||
cur2 = self.conn.cursor()
|
||||
cur.execute("select * from no_such_table")
|
||||
self.assertRaises(psycopg2.ProgrammingError, self.wait, cur)
|
||||
cur2.execute("select 1")
|
||||
self.wait(cur2)
|
||||
self.assertEquals(cur2.fetchone()[0], 1)
|
||||
|
||||
@skip_if_crdb("notice")
|
||||
def test_notices(self):
|
||||
del self.conn.notices[:]
|
||||
cur = self.conn.cursor()
|
||||
if self.conn.info.server_version >= 90300:
|
||||
cur.execute("set client_min_messages=debug1")
|
||||
self.wait(cur)
|
||||
cur.execute("create temp table chatty (id serial primary key);")
|
||||
self.wait(cur)
|
||||
self.assertEqual("CREATE TABLE", cur.statusmessage)
|
||||
self.assert_(self.conn.notices)
|
||||
|
||||
def test_async_cursor_gone(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("select 42;")
|
||||
del cur
|
||||
gc.collect()
|
||||
self.assertRaises(psycopg2.InterfaceError, self.wait, self.conn)
|
||||
|
||||
# The connection is still usable
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("select 42;")
|
||||
self.wait(self.conn)
|
||||
self.assertEqual(cur.fetchone(), (42,))
|
||||
|
||||
@skip_if_crdb("copy")
|
||||
def test_async_connection_error_message(self):
|
||||
try:
|
||||
cnn = psycopg2.connect('dbname=thisdatabasedoesntexist', async_=True)
|
||||
self.wait(cnn)
|
||||
except psycopg2.Error as e:
|
||||
self.assertNotEqual(str(e), "asynchronous connection failed",
|
||||
"connection error reason lost")
|
||||
else:
|
||||
self.fail("no exception raised")
|
||||
|
||||
@skip_before_postgres(8, 2)
|
||||
def test_copy_no_hang(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("copy (select 1) to stdout")
|
||||
self.assertRaises(psycopg2.ProgrammingError, self.wait, self.conn)
|
||||
|
||||
@slow
|
||||
@skip_if_crdb("notice")
|
||||
@skip_before_postgres(9, 0)
|
||||
def test_non_block_after_notification(self):
|
||||
from select import select
|
||||
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("""
|
||||
select 1;
|
||||
do $$
|
||||
begin
|
||||
raise notice 'hello';
|
||||
end
|
||||
$$ language plpgsql;
|
||||
select pg_sleep(1);
|
||||
""")
|
||||
|
||||
polls = 0
|
||||
while True:
|
||||
state = self.conn.poll()
|
||||
if state == psycopg2.extensions.POLL_OK:
|
||||
break
|
||||
elif state == psycopg2.extensions.POLL_READ:
|
||||
select([self.conn], [], [], 0.1)
|
||||
elif state == psycopg2.extensions.POLL_WRITE:
|
||||
select([], [self.conn], [], 0.1)
|
||||
else:
|
||||
raise Exception("Unexpected result from poll: %r", state)
|
||||
polls += 1
|
||||
|
||||
self.assert_(polls >= 8, polls)
|
||||
|
||||
def test_poll_noop(self):
|
||||
self.conn.poll()
|
||||
|
||||
@skip_if_crdb("notify")
|
||||
@skip_before_postgres(9, 0)
|
||||
def test_poll_conn_for_notification(self):
|
||||
with self.conn.cursor() as cur:
|
||||
cur.execute("listen test")
|
||||
self.wait(cur)
|
||||
|
||||
with self.sync_conn.cursor() as cur:
|
||||
cur.execute("notify test, 'hello'")
|
||||
self.sync_conn.commit()
|
||||
|
||||
for i in range(10):
|
||||
self.conn.poll()
|
||||
|
||||
if self.conn.notifies:
|
||||
n = self.conn.notifies.pop()
|
||||
self.assertEqual(n.channel, 'test')
|
||||
self.assertEqual(n.payload, 'hello')
|
||||
break
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
self.fail("No notification received")
|
||||
|
||||
def test_close(self):
|
||||
self.conn.close()
|
||||
self.assertTrue(self.conn.closed)
|
||||
self.assertTrue(self.conn.async_)
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
48
tests/test_bugX000.py
Executable file
48
tests/test_bugX000.py
Executable file
@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python

# bugX000.py - test for DateTime object allocation bug
#
# Copyright (C) 2007-2019 Federico Di Gregorio <fog@debian.org>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.

import psycopg2
import time
import unittest


class DateTimeAllocationBugTestCase(unittest.TestCase):
    def test_date_time_allocation_bug(self):
        d1 = psycopg2.Date(2002, 12, 25)
        d2 = psycopg2.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0)))
        t1 = psycopg2.Time(13, 45, 30)
        t2 = psycopg2.TimeFromTicks(time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0)))
        t1 = psycopg2.Timestamp(2002, 12, 25, 13, 45, 30)
        t2 = psycopg2.TimestampFromTicks(
            time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)))
        del d1, d2, t1, t2


def test_suite():
    return unittest.TestLoader().loadTestsFromName(__name__)


if __name__ == "__main__":
    unittest.main()
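The file above only checks that the DB-API 2.0 date/time constructors allocate and deallocate cleanly. As a side note (not part of the committed file), the objects they return are ordinary datetime values that can be used directly as query parameters; the connection in the usage comment is a placeholder.

# Illustrative sketch only: DB-API constructors used as query parameters.
import time

import psycopg2

ts = psycopg2.TimestampFromTicks(time.time())  # current time as a timestamp
# with an open connection `conn` (placeholder):
# cur = conn.cursor()
# cur.execute("select %s::timestamp", (ts,))
# print(cur.fetchone()[0])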
52
tests/test_bug_gc.py
Executable file
@ -0,0 +1,52 @@
#!/usr/bin/env python

# bug_gc.py - test for refcounting/GC bug
#
# Copyright (C) 2010-2019 Federico Di Gregorio <fog@debian.org>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.

import psycopg2
import psycopg2.extensions
import unittest
import gc

from .testutils import ConnectingTestCase, skip_if_no_uuid


class StolenReferenceTestCase(ConnectingTestCase):
    @skip_if_no_uuid
    def test_stolen_reference_bug(self):
        def fish(val, cur):
            gc.collect()
            return 42
        UUID = psycopg2.extensions.new_type((2950,), "UUID", fish)
        psycopg2.extensions.register_type(UUID, self.conn)
        curs = self.conn.cursor()
        curs.execute("select 'b5219e01-19ab-4994-b71e-149225dc51e4'::uuid")
        curs.fetchone()


def test_suite():
    return unittest.TestLoader().loadTestsFromName(__name__)


if __name__ == "__main__":
    unittest.main()
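The GC test above relies on the new_type()/register_type() extension API. As a reference sketch (not part of the committed file), this is how the same API is typically used to install a real typecaster, here turning uuid columns (OID 2950) into Python uuid.UUID objects for one connection; the connection name is a placeholder.

# Illustrative sketch only: registering a per-connection typecaster.
import uuid

import psycopg2
import psycopg2.extensions


def cast_uuid(value, cur):
    # `value` is the server's text representation, or None for NULL.
    return uuid.UUID(value) if value is not None else None


UUID_TYPE = psycopg2.extensions.new_type((2950,), "UUID", cast_uuid)
# with an open connection `conn` (placeholder):
# psycopg2.extensions.register_type(UUID_TYPE, conn)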
117
tests/test_cancel.py
Executable file
@ -0,0 +1,117 @@
#!/usr/bin/env python

# test_cancel.py - unit test for query cancellation
#
# Copyright (C) 2010-2019 Jan Urbański <wulczer@wulczer.org>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.

import time
import threading

import psycopg2
import psycopg2.extensions
from psycopg2 import extras

from .testconfig import dsn
import unittest
from .testutils import ConnectingTestCase, skip_before_postgres, slow
from .testutils import skip_if_crdb


class CancelTests(ConnectingTestCase):

    def setUp(self):
        ConnectingTestCase.setUp(self)

        skip_if_crdb("cancel", self.conn)

        cur = self.conn.cursor()
        cur.execute('''
            CREATE TEMPORARY TABLE table1 (
              id int PRIMARY KEY
            )''')
        self.conn.commit()

    def test_empty_cancel(self):
        self.conn.cancel()

    @slow
    @skip_before_postgres(8, 2)
    def test_cancel(self):
        errors = []

        def neverending(conn):
            cur = conn.cursor()
            try:
                self.assertRaises(psycopg2.extensions.QueryCanceledError,
                                  cur.execute, "select pg_sleep(60)")
                # make sure the connection still works
                conn.rollback()
                cur.execute("select 1")
                self.assertEqual(cur.fetchall(), [(1, )])
            except Exception as e:
                errors.append(e)
                raise

        def canceller(conn):
            cur = conn.cursor()
            try:
                conn.cancel()
            except Exception as e:
                errors.append(e)
                raise
            del cur

        thread1 = threading.Thread(target=neverending, args=(self.conn, ))
        # wait a bit to make sure that the other thread is already in
        # pg_sleep -- ugly and racy, but the chances are ridiculously low
        thread2 = threading.Timer(0.3, canceller, args=(self.conn, ))
        thread1.start()
        thread2.start()
        thread1.join()
        thread2.join()

        self.assertEqual(errors, [])

    @slow
    @skip_before_postgres(8, 2)
    def test_async_cancel(self):
        async_conn = psycopg2.connect(dsn, async_=True)
        self.assertRaises(psycopg2.OperationalError, async_conn.cancel)
        extras.wait_select(async_conn)
        cur = async_conn.cursor()
        cur.execute("select pg_sleep(10)")
        time.sleep(1)
        self.assertTrue(async_conn.isexecuting())
        async_conn.cancel()
        self.assertRaises(psycopg2.extensions.QueryCanceledError,
                          extras.wait_select, async_conn)
        cur.execute("select 1")
        extras.wait_select(async_conn)
        self.assertEqual(cur.fetchall(), [(1, )])


def test_suite():
    return unittest.TestLoader().loadTestsFromName(__name__)


if __name__ == "__main__":
    unittest.main()
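For readers unfamiliar with conn.cancel(), the scenario test_cancel() exercises looks roughly like the following outside the test harness. This is a sketch only, not part of the committed suite; the DSN argument is a placeholder.

# Illustrative sketch only: cancelling a long-running query from a timer thread.
import threading

import psycopg2
import psycopg2.extensions


def run_and_cancel(dsn):
    conn = psycopg2.connect(dsn)
    cur = conn.cursor()
    # Schedule a cancel request while the main thread blocks in execute().
    timer = threading.Timer(0.5, conn.cancel)
    timer.start()
    try:
        cur.execute("select pg_sleep(60)")
    except psycopg2.extensions.QueryCanceledError:
        conn.rollback()  # the connection stays usable after a cancel
    finally:
        timer.join()
        conn.close()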
1944
tests/test_connection.py
Executable file
File diff suppressed because it is too large
404
tests/test_copy.py
Executable file
@ -0,0 +1,404 @@
#!/usr/bin/env python

# test_copy.py - unit test for COPY support
#
# Copyright (C) 2010-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.

import io
import sys
import string
import unittest
from .testutils import ConnectingTestCase, skip_before_postgres, slow, StringIO
from .testutils import skip_if_crdb
from itertools import cycle
from subprocess import Popen, PIPE

import psycopg2
import psycopg2.extensions
from .testutils import skip_copy_if_green, TextIOBase
from .testconfig import dsn


class MinimalRead(TextIOBase):
    """A file wrapper exposing the minimal interface to copy from."""
    def __init__(self, f):
        self.f = f

    def read(self, size):
        return self.f.read(size)

    def readline(self):
        return self.f.readline()


class MinimalWrite(TextIOBase):
    """A file wrapper exposing the minimal interface to copy to."""
    def __init__(self, f):
        self.f = f

    def write(self, data):
        return self.f.write(data)


@skip_copy_if_green
class CopyTests(ConnectingTestCase):

    def setUp(self):
        ConnectingTestCase.setUp(self)
        self._create_temp_table()

    def _create_temp_table(self):
        skip_if_crdb("copy", self.conn)
        curs = self.conn.cursor()
        curs.execute('''
            CREATE TEMPORARY TABLE tcopy (
              id serial PRIMARY KEY,
              data text
            )''')

    @slow
    def test_copy_from(self):
        curs = self.conn.cursor()
        try:
            self._copy_from(curs, nrecs=1024, srec=10 * 1024, copykw={})
        finally:
            curs.close()

    @slow
    def test_copy_from_insane_size(self):
        # Trying to trigger a "would block" error
        curs = self.conn.cursor()
        try:
            self._copy_from(curs, nrecs=10 * 1024, srec=10 * 1024,
                            copykw={'size': 20 * 1024 * 1024})
        finally:
            curs.close()

    def test_copy_from_cols(self):
        curs = self.conn.cursor()
        f = StringIO()
        for i in range(10):
            f.write(f"{i}\n")

        f.seek(0)
        curs.copy_from(MinimalRead(f), "tcopy", columns=['id'])

        curs.execute("select * from tcopy order by id")
        self.assertEqual([(i, None) for i in range(10)], curs.fetchall())

    def test_copy_from_cols_err(self):
        curs = self.conn.cursor()
        f = StringIO()
        for i in range(10):
            f.write(f"{i}\n")

        f.seek(0)

        def cols():
            raise ZeroDivisionError()
            yield 'id'

        self.assertRaises(ZeroDivisionError,
                          curs.copy_from, MinimalRead(f), "tcopy", columns=cols())

    @slow
    def test_copy_to(self):
        curs = self.conn.cursor()
        try:
            self._copy_from(curs, nrecs=1024, srec=10 * 1024, copykw={})
            self._copy_to(curs, srec=10 * 1024)
        finally:
            curs.close()

    def test_copy_text(self):
        self.conn.set_client_encoding('latin1')
        self._create_temp_table()  # the above call closed the xn

        abin = bytes(list(range(32, 127))
                     + list(range(160, 256))).decode('latin1')
        about = abin.replace('\\', '\\\\')

        curs = self.conn.cursor()
        curs.execute('insert into tcopy values (%s, %s)',
                     (42, abin))

        f = io.StringIO()
        curs.copy_to(f, 'tcopy', columns=('data',))
        f.seek(0)
        self.assertEqual(f.readline().rstrip(), about)

    def test_copy_bytes(self):
        self.conn.set_client_encoding('latin1')
        self._create_temp_table()  # the above call closed the xn

        abin = bytes(list(range(32, 127))
                     + list(range(160, 255))).decode('latin1')
        about = abin.replace('\\', '\\\\').encode('latin1')

        curs = self.conn.cursor()
        curs.execute('insert into tcopy values (%s, %s)',
                     (42, abin))

        f = io.BytesIO()
        curs.copy_to(f, 'tcopy', columns=('data',))
        f.seek(0)
        self.assertEqual(f.readline().rstrip(), about)

    def test_copy_expert_textiobase(self):
        self.conn.set_client_encoding('latin1')
        self._create_temp_table()  # the above call closed the xn

        abin = bytes(list(range(32, 127))
                     + list(range(160, 256))).decode('latin1')
        about = abin.replace('\\', '\\\\')

        f = io.StringIO()
        f.write(about)
        f.seek(0)

        curs = self.conn.cursor()
        psycopg2.extensions.register_type(
            psycopg2.extensions.UNICODE, curs)

        curs.copy_expert('COPY tcopy (data) FROM STDIN', f)
        curs.execute("select data from tcopy;")
        self.assertEqual(curs.fetchone()[0], abin)

        f = io.StringIO()
        curs.copy_expert('COPY tcopy (data) TO STDOUT', f)
        f.seek(0)
        self.assertEqual(f.readline().rstrip(), about)

        # same tests with setting size
        f = io.StringIO()
        f.write(about)
        f.seek(0)
        exp_size = 123
        # hack here to leave file as is, only check size when reading
        real_read = f.read

        def read(_size, f=f, exp_size=exp_size):
            self.assertEqual(_size, exp_size)
            return real_read(_size)

        f.read = read
        curs.copy_expert('COPY tcopy (data) FROM STDIN', f, size=exp_size)
        curs.execute("select data from tcopy;")
        self.assertEqual(curs.fetchone()[0], abin)

    def _copy_from(self, curs, nrecs, srec, copykw):
        f = StringIO()
        for i, c in zip(range(nrecs), cycle(string.ascii_letters)):
            l = c * srec
            f.write(f"{i}\t{l}\n")

        f.seek(0)
        curs.copy_from(MinimalRead(f), "tcopy", **copykw)

        curs.execute("select count(*) from tcopy")
        self.assertEqual(nrecs, curs.fetchone()[0])

        curs.execute("select data from tcopy where id < %s order by id",
                     (len(string.ascii_letters),))
        for i, (l,) in enumerate(curs):
            self.assertEqual(l, string.ascii_letters[i] * srec)

    def _copy_to(self, curs, srec):
        f = StringIO()
        curs.copy_to(MinimalWrite(f), "tcopy")

        f.seek(0)
        ntests = 0
        for line in f:
            n, s = line.split()
            if int(n) < len(string.ascii_letters):
                self.assertEqual(s, string.ascii_letters[int(n)] * srec)
                ntests += 1

        self.assertEqual(ntests, len(string.ascii_letters))

    def test_copy_expert_file_refcount(self):
        class Whatever:
            pass

        f = Whatever()
        curs = self.conn.cursor()
        self.assertRaises(TypeError,
                          curs.copy_expert, 'COPY tcopy (data) FROM STDIN', f)

    def test_copy_no_column_limit(self):
        cols = [f"c{i:050}" for i in range(200)]

        curs = self.conn.cursor()
        curs.execute('CREATE TEMPORARY TABLE manycols (%s)' % ',\n'.join(
            ["%s int" % c for c in cols]))
        curs.execute("INSERT INTO manycols DEFAULT VALUES")

        f = StringIO()
        curs.copy_to(f, "manycols", columns=cols)
        f.seek(0)
        self.assertEqual(f.read().split(), ['\\N'] * len(cols))

        f.seek(0)
        curs.copy_from(f, "manycols", columns=cols)
        curs.execute("select count(*) from manycols;")
        self.assertEqual(curs.fetchone()[0], 2)

    def test_copy_funny_names(self):
        cols = ["select", "insert", "group"]

        curs = self.conn.cursor()
        curs.execute('CREATE TEMPORARY TABLE "select" (%s)' % ',\n'.join(
            ['"%s" int' % c for c in cols]))
        curs.execute('INSERT INTO "select" DEFAULT VALUES')

        f = StringIO()
        curs.copy_to(f, "select", columns=cols)
        f.seek(0)
        self.assertEqual(f.read().split(), ['\\N'] * len(cols))

        f.seek(0)
        curs.copy_from(f, "select", columns=cols)
        curs.execute('select count(*) from "select";')
        self.assertEqual(curs.fetchone()[0], 2)

    @skip_before_postgres(8, 2)  # they don't send the count
    def test_copy_rowcount(self):
        curs = self.conn.cursor()

        curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data'])
        self.assertEqual(curs.rowcount, 3)

        curs.copy_expert(
            "copy tcopy (data) from stdin",
            StringIO('ddd\neee\n'))
        self.assertEqual(curs.rowcount, 2)

        curs.copy_to(StringIO(), "tcopy")
        self.assertEqual(curs.rowcount, 5)

        curs.execute("insert into tcopy (data) values ('fff')")
        curs.copy_expert("copy tcopy to stdout", StringIO())
        self.assertEqual(curs.rowcount, 6)

    def test_copy_rowcount_error(self):
        curs = self.conn.cursor()

        curs.execute("insert into tcopy (data) values ('fff')")
        self.assertEqual(curs.rowcount, 1)

        self.assertRaises(psycopg2.DataError,
                          curs.copy_from, StringIO('aaa\nbbb\nccc\n'), 'tcopy')
        self.assertEqual(curs.rowcount, -1)

    def test_copy_query(self):
        curs = self.conn.cursor()

        curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data'])
        self.assert_(b"copy " in curs.query.lower())
        self.assert_(b" from stdin" in curs.query.lower())

        curs.copy_expert(
            "copy tcopy (data) from stdin",
            StringIO('ddd\neee\n'))
        self.assert_(b"copy " in curs.query.lower())
        self.assert_(b" from stdin" in curs.query.lower())

        curs.copy_to(StringIO(), "tcopy")
        self.assert_(b"copy " in curs.query.lower())
        self.assert_(b" to stdout" in curs.query.lower())

        curs.execute("insert into tcopy (data) values ('fff')")
        curs.copy_expert("copy tcopy to stdout", StringIO())
        self.assert_(b"copy " in curs.query.lower())
        self.assert_(b" to stdout" in curs.query.lower())

    @slow
    def test_copy_from_segfault(self):
        # issue #219
        script = f"""import psycopg2
conn = psycopg2.connect({dsn!r})
curs = conn.cursor()
curs.execute("create table copy_segf (id int)")
try:
    curs.execute("copy copy_segf from stdin")
except psycopg2.ProgrammingError:
    pass
conn.close()
"""

        proc = Popen([sys.executable, '-c', script])
        proc.communicate()
        self.assertEqual(0, proc.returncode)

    @slow
    def test_copy_to_segfault(self):
        # issue #219
        script = f"""import psycopg2
conn = psycopg2.connect({dsn!r})
curs = conn.cursor()
curs.execute("create table copy_segf (id int)")
try:
    curs.execute("copy copy_segf to stdout")
except psycopg2.ProgrammingError:
    pass
conn.close()
"""

        proc = Popen([sys.executable, '-c', script], stdout=PIPE)
        proc.communicate()
        self.assertEqual(0, proc.returncode)

    def test_copy_from_propagate_error(self):
        class BrokenRead(TextIOBase):
            def read(self, size):
                return 1 / 0

            def readline(self):
                return 1 / 0

        curs = self.conn.cursor()
        # It seems we cannot do this, but now at least we propagate the error
        # self.assertRaises(ZeroDivisionError,
        #     curs.copy_from, BrokenRead(), "tcopy")
        try:
            curs.copy_from(BrokenRead(), "tcopy")
        except Exception as e:
            self.assert_('ZeroDivisionError' in str(e))

    def test_copy_to_propagate_error(self):
        class BrokenWrite(TextIOBase):
            def write(self, data):
                return 1 / 0

        curs = self.conn.cursor()
        curs.execute("insert into tcopy values (10, 'hi')")
        self.assertRaises(ZeroDivisionError,
                          curs.copy_to, BrokenWrite(), "tcopy")


def test_suite():
    return unittest.TestLoader().loadTestsFromName(__name__)


if __name__ == "__main__":
    unittest.main()
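The COPY tests above exercise two entry points: copy_from()/copy_to(), which take a file-like object plus a table name, and copy_expert(), which takes a complete COPY statement. A minimal sketch of plain usage follows (not part of the committed suite; the connection, table and column names are placeholders).

# Illustrative sketch only: basic COPY usage with in-memory files.
import io

# with an open connection `conn` (placeholder):
# cur = conn.cursor()
# data = io.StringIO("1\talpha\n2\tbeta\n")
# cur.copy_from(data, "tcopy", columns=("id", "data"))
#
# out = io.StringIO()
# cur.copy_expert("COPY tcopy TO STDOUT", out)
# conn.commit()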
701
tests/test_cursor.py
Executable file
@ -0,0 +1,701 @@
#!/usr/bin/env python

# test_cursor.py - unit test for cursor attributes
#
# Copyright (C) 2010-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.

import gc
import sys
import time
import ctypes
import pickle
import psycopg2
import psycopg2.extensions
import unittest
from datetime import date
from decimal import Decimal
from weakref import ref
from .testutils import (ConnectingTestCase, skip_before_postgres,
                        skip_if_no_getrefcount, slow, skip_if_no_superuser,
                        skip_if_windows, skip_if_crdb, crdb_version)

import psycopg2.extras


class CursorTests(ConnectingTestCase):

    def test_close_idempotent(self):
        cur = self.conn.cursor()
        cur.close()
        cur.close()
        self.assert_(cur.closed)

    def test_empty_query(self):
        cur = self.conn.cursor()
        self.assertRaises(psycopg2.ProgrammingError, cur.execute, "")
        self.assertRaises(psycopg2.ProgrammingError, cur.execute, " ")
        self.assertRaises(psycopg2.ProgrammingError, cur.execute, ";")

    def test_executemany_propagate_exceptions(self):
        conn = self.conn
        cur = conn.cursor()
        cur.execute("create table test_exc (data int);")

        def buggygen():
            yield 1 // 0

        self.assertRaises(ZeroDivisionError,
                          cur.executemany, "insert into test_exc values (%s)", buggygen())
        cur.close()

    def test_mogrify_unicode(self):
        conn = self.conn
        cur = conn.cursor()

        # test consistency between execute and mogrify.

        # unicode query containing only ascii data
        cur.execute("SELECT 'foo';")
        self.assertEqual('foo', cur.fetchone()[0])
        self.assertEqual(b"SELECT 'foo';", cur.mogrify("SELECT 'foo';"))

        conn.set_client_encoding('UTF8')
        snowman = "\u2603"

        def b(s):
            if isinstance(s, str):
                return s.encode('utf8')
            else:
                return s

        # unicode query with non-ascii data
        cur.execute(f"SELECT '{snowman}';")
        self.assertEqual(snowman.encode('utf8'), b(cur.fetchone()[0]))
        self.assertQuotedEqual(f"SELECT '{snowman}';".encode('utf8'),
                               cur.mogrify(f"SELECT '{snowman}';"))

        # unicode args
        cur.execute("SELECT %s;", (snowman,))
        self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0]))
        self.assertQuotedEqual(f"SELECT '{snowman}';".encode('utf8'),
                               cur.mogrify("SELECT %s;", (snowman,)))

        # unicode query and args
        cur.execute("SELECT %s;", (snowman,))
        self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0]))
        self.assertQuotedEqual(f"SELECT '{snowman}';".encode('utf8'),
                               cur.mogrify("SELECT %s;", (snowman,)))

    def test_mogrify_decimal_explodes(self):
        conn = self.conn
        cur = conn.cursor()
        self.assertEqual(b'SELECT 10.3;',
                         cur.mogrify("SELECT %s;", (Decimal("10.3"),)))

    @skip_if_no_getrefcount
    def test_mogrify_leak_on_multiple_reference(self):
        # issue #81: reference leak when a parameter value is referenced
        # more than once from a dict.
        cur = self.conn.cursor()
        foo = (lambda x: x)('foo') * 10
        nref1 = sys.getrefcount(foo)
        cur.mogrify("select %(foo)s, %(foo)s, %(foo)s", {'foo': foo})
        nref2 = sys.getrefcount(foo)
        self.assertEqual(nref1, nref2)

    def test_modify_closed(self):
        cur = self.conn.cursor()
        cur.close()
        sql = cur.mogrify("select %s", (10,))
        self.assertEqual(sql, b"select 10")

    def test_bad_placeholder(self):
        cur = self.conn.cursor()
        self.assertRaises(psycopg2.ProgrammingError,
                          cur.mogrify, "select %(foo", {})
        self.assertRaises(psycopg2.ProgrammingError,
                          cur.mogrify, "select %(foo", {'foo': 1})
        self.assertRaises(psycopg2.ProgrammingError,
                          cur.mogrify, "select %(foo, %(bar)", {'foo': 1})
        self.assertRaises(psycopg2.ProgrammingError,
                          cur.mogrify, "select %(foo, %(bar)", {'foo': 1, 'bar': 2})

    def test_cast(self):
        curs = self.conn.cursor()

        self.assertEqual(42, curs.cast(20, '42'))
        self.assertAlmostEqual(3.14, curs.cast(700, '3.14'))

        self.assertEqual(Decimal('123.45'), curs.cast(1700, '123.45'))

        self.assertEqual(date(2011, 1, 2), curs.cast(1082, '2011-01-02'))
        self.assertEqual("who am i?", curs.cast(705, 'who am i?'))  # unknown

    def test_cast_specificity(self):
        curs = self.conn.cursor()
        self.assertEqual("foo", curs.cast(705, 'foo'))

        D = psycopg2.extensions.new_type((705,), "DOUBLING", lambda v, c: v * 2)
        psycopg2.extensions.register_type(D, self.conn)
        self.assertEqual("foofoo", curs.cast(705, 'foo'))

        T = psycopg2.extensions.new_type((705,), "TREBLING", lambda v, c: v * 3)
        psycopg2.extensions.register_type(T, curs)
        self.assertEqual("foofoofoo", curs.cast(705, 'foo'))

        curs2 = self.conn.cursor()
        self.assertEqual("foofoo", curs2.cast(705, 'foo'))

    def test_weakref(self):
        curs = self.conn.cursor()
        w = ref(curs)
        del curs
        gc.collect()
        self.assert_(w() is None)

    def test_null_name(self):
        curs = self.conn.cursor(None)
        self.assertEqual(curs.name, None)

    def test_description_attribs(self):
        curs = self.conn.cursor()
        curs.execute("""select
            3.14::decimal(10,2) as pi,
            'hello'::text as hi,
            '2010-02-18'::date as now;
            """)
        self.assertEqual(len(curs.description), 3)
        for c in curs.description:
            self.assertEqual(len(c), 7)  # DBAPI happy
            for a in ('name', 'type_code', 'display_size', 'internal_size',
                      'precision', 'scale', 'null_ok'):
                self.assert_(hasattr(c, a), a)

        c = curs.description[0]
        self.assertEqual(c.name, 'pi')
        self.assert_(c.type_code in psycopg2.extensions.DECIMAL.values)
        if crdb_version(self.conn) is None:
            self.assert_(c.internal_size > 0)
        self.assertEqual(c.precision, 10)
        self.assertEqual(c.scale, 2)

        c = curs.description[1]
        self.assertEqual(c.name, 'hi')
        self.assert_(c.type_code in psycopg2.STRING.values)
        self.assert_(c.internal_size < 0)
        self.assertEqual(c.precision, None)
        self.assertEqual(c.scale, None)

        c = curs.description[2]
        self.assertEqual(c.name, 'now')
        self.assert_(c.type_code in psycopg2.extensions.DATE.values)
        self.assert_(c.internal_size > 0)
        self.assertEqual(c.precision, None)
        self.assertEqual(c.scale, None)

    @skip_if_crdb("table oid")
    def test_description_extra_attribs(self):
        curs = self.conn.cursor()
        curs.execute("""
            create table testcol (
                pi decimal(10,2),
                hi text)
            """)
        curs.execute("select oid from pg_class where relname = %s", ('testcol',))
        oid = curs.fetchone()[0]

        curs.execute("insert into testcol values (3.14, 'hello')")
        curs.execute("select hi, pi, 42 from testcol")
        self.assertEqual(curs.description[0].table_oid, oid)
        self.assertEqual(curs.description[0].table_column, 2)

        self.assertEqual(curs.description[1].table_oid, oid)
        self.assertEqual(curs.description[1].table_column, 1)

        self.assertEqual(curs.description[2].table_oid, None)
        self.assertEqual(curs.description[2].table_column, None)

    def test_description_slice(self):
        curs = self.conn.cursor()
        curs.execute("select 1::int4 as a")
        self.assertEqual(curs.description[0][0:2], ('a', 23))

    def test_pickle_description(self):
        curs = self.conn.cursor()
        curs.execute('SELECT 1 AS foo')
        description = curs.description

        pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL)
        unpickled = pickle.loads(pickled)

        self.assertEqual(description, unpickled)

    def test_column_refcount(self):
        # Reproduce crash describe in ticket #1252
        from psycopg2.extensions import Column

        def do_stuff():
            _ = Column(name='my_column')

        for _ in range(1000):
            do_stuff()

    def test_bad_subclass(self):
        # check that we get an error message instead of a segfault
        # for badly written subclasses.
        # see https://stackoverflow.com/questions/22019341/
        class StupidCursor(psycopg2.extensions.cursor):
            def __init__(self, *args, **kwargs):
                # I am stupid so not calling superclass init
                pass

        cur = StupidCursor()
        self.assertRaises(psycopg2.InterfaceError, cur.execute, 'select 1')
        self.assertRaises(psycopg2.InterfaceError, cur.executemany,
                          'select 1', [])

    def test_callproc_badparam(self):
        cur = self.conn.cursor()
        self.assertRaises(TypeError, cur.callproc, 'lower', 42)

    # It would be inappropriate to test callproc's named parameters in the
    # DBAPI2.0 test section because they are a psycopg2 extension.
    @skip_before_postgres(9, 0)
    @skip_if_crdb("stored procedure")
    def test_callproc_dict(self):
        # This parameter name tests for injection and quote escaping
        paramname = '''
            Robert'); drop table "students" --
            '''.strip()
        escaped_paramname = '"%s"' % paramname.replace('"', '""')
        procname = 'pg_temp.randall'

        cur = self.conn.cursor()

        # Set up the temporary function
        cur.execute(f'''
            CREATE FUNCTION {procname}({escaped_paramname} INT)
            RETURNS INT AS
            'SELECT $1 * $1'
            LANGUAGE SQL
            ''')

        # Make sure callproc works right
        cur.callproc(procname, {paramname: 2})
        self.assertEquals(cur.fetchone()[0], 4)

        # Make sure callproc fails right
        failing_cases = [
            ({paramname: 2, 'foo': 'bar'}, psycopg2.ProgrammingError),
            ({paramname: '2'}, psycopg2.ProgrammingError),
            ({paramname: 'two'}, psycopg2.ProgrammingError),
            ({'bj\xc3rn': 2}, psycopg2.ProgrammingError),
            ({3: 2}, TypeError),
            ({self: 2}, TypeError),
        ]
        for parameter_sequence, exception in failing_cases:
            self.assertRaises(exception, cur.callproc, procname, parameter_sequence)
            self.conn.rollback()

    @skip_if_no_superuser
    @skip_if_windows
    @skip_if_crdb("backend pid")
    @skip_before_postgres(8, 4)
    def test_external_close_sync(self):
        # If a "victim" connection is closed by a "control" connection
        # behind psycopg2's back, psycopg2 always handles it correctly:
        # raise OperationalError, set conn.closed to 2. This reproduces
        # issue #443, a race between control_conn closing victim_conn and
        # psycopg2 noticing.
        control_conn = self.conn
        connect_func = self.connect

        def wait_func(conn):
            pass

        self._test_external_close(control_conn, connect_func, wait_func)

    @skip_if_no_superuser
    @skip_if_windows
    @skip_if_crdb("backend pid")
    @skip_before_postgres(8, 4)
    def test_external_close_async(self):
        # Issue #443 is in the async code too. Since the fix is duplicated,
        # so is the test.
        control_conn = self.conn

        def connect_func():
            return self.connect(async_=True)

        wait_func = psycopg2.extras.wait_select
        self._test_external_close(control_conn, connect_func, wait_func)

    def _test_external_close(self, control_conn, connect_func, wait_func):
        # The short sleep before using victim_conn the second time makes it
        # much more likely to lose the race and see the bug. Repeating the
        # test several times makes it even more likely.
        for i in range(10):
            victim_conn = connect_func()
            wait_func(victim_conn)

            with victim_conn.cursor() as cur:
                cur.execute('select pg_backend_pid()')
                wait_func(victim_conn)
                pid1 = cur.fetchall()[0][0]

            with control_conn.cursor() as cur:
                cur.execute('select pg_terminate_backend(%s)', (pid1,))

            time.sleep(0.001)

            def f():
                with victim_conn.cursor() as cur:
                    cur.execute('select 1')
                    wait_func(victim_conn)

            self.assertRaises(psycopg2.OperationalError, f)

            self.assertEqual(victim_conn.closed, 2)

    @skip_before_postgres(8, 2)
    def test_rowcount_on_executemany_returning(self):
        cur = self.conn.cursor()
        cur.execute("create table execmany(id serial primary key, data int)")
        cur.executemany(
            "insert into execmany (data) values (%s)",
            [(i,) for i in range(4)])
        self.assertEqual(cur.rowcount, 4)

        cur.executemany(
            "insert into execmany (data) values (%s) returning data",
            [(i,) for i in range(5)])
        self.assertEqual(cur.rowcount, 5)

    @skip_before_postgres(9)
    def test_pgresult_ptr(self):
        curs = self.conn.cursor()
        self.assert_(curs.pgresult_ptr is None)

        curs.execute("select 'x'")
        self.assert_(curs.pgresult_ptr is not None)

        try:
            f = self.libpq.PQcmdStatus
        except AttributeError:
            pass
        else:
            f.argtypes = [ctypes.c_void_p]
            f.restype = ctypes.c_char_p
            status = f(curs.pgresult_ptr)
            self.assertEqual(status, b'SELECT 1')

        curs.close()
        self.assert_(curs.pgresult_ptr is None)


@skip_if_crdb("named cursor")
class NamedCursorTests(ConnectingTestCase):
    def test_invalid_name(self):
        curs = self.conn.cursor()
        curs.execute("create table invname (data int);")
        for i in (10, 20, 30):
            curs.execute("insert into invname values (%s)", (i,))
        curs.close()

        curs = self.conn.cursor(r'1-2-3 \ "test"')
        curs.execute("select data from invname order by data")
        self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)])

    def _create_withhold_table(self):
        curs = self.conn.cursor()
        try:
            curs.execute("drop table withhold")
        except psycopg2.ProgrammingError:
            self.conn.rollback()
        curs.execute("create table withhold (data int)")
        for i in (10, 20, 30):
            curs.execute("insert into withhold values (%s)", (i,))
        curs.close()

    def test_withhold(self):
        self.assertRaises(psycopg2.ProgrammingError, self.conn.cursor,
                          withhold=True)

        self._create_withhold_table()
        curs = self.conn.cursor("W")
        self.assertEqual(curs.withhold, False)
        curs.withhold = True
        self.assertEqual(curs.withhold, True)
        curs.execute("select data from withhold order by data")
        self.conn.commit()
        self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)])
        curs.close()

        curs = self.conn.cursor("W", withhold=True)
        self.assertEqual(curs.withhold, True)
        curs.execute("select data from withhold order by data")
        self.conn.commit()
        self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)])

        curs = self.conn.cursor()
        curs.execute("drop table withhold")
        self.conn.commit()

    def test_withhold_no_begin(self):
        self._create_withhold_table()
        curs = self.conn.cursor("w", withhold=True)
        curs.execute("select data from withhold order by data")
        self.assertEqual(curs.fetchone(), (10,))
        self.assertEqual(self.conn.status, psycopg2.extensions.STATUS_BEGIN)
        self.assertEqual(self.conn.info.transaction_status,
                         psycopg2.extensions.TRANSACTION_STATUS_INTRANS)

        self.conn.commit()
        self.assertEqual(self.conn.status, psycopg2.extensions.STATUS_READY)
        self.assertEqual(self.conn.info.transaction_status,
                         psycopg2.extensions.TRANSACTION_STATUS_IDLE)

        self.assertEqual(curs.fetchone(), (20,))
        self.assertEqual(self.conn.status, psycopg2.extensions.STATUS_READY)
        self.assertEqual(self.conn.info.transaction_status,
                         psycopg2.extensions.TRANSACTION_STATUS_IDLE)

        curs.close()
        self.assertEqual(self.conn.status, psycopg2.extensions.STATUS_READY)
        self.assertEqual(self.conn.info.transaction_status,
                         psycopg2.extensions.TRANSACTION_STATUS_IDLE)

    def test_withhold_autocommit(self):
        self._create_withhold_table()
        self.conn.commit()
        self.conn.autocommit = True
        curs = self.conn.cursor("w", withhold=True)
        curs.execute("select data from withhold order by data")

        self.assertEqual(curs.fetchone(), (10,))
        self.assertEqual(self.conn.status, psycopg2.extensions.STATUS_READY)
        self.assertEqual(self.conn.info.transaction_status,
                         psycopg2.extensions.TRANSACTION_STATUS_IDLE)

        self.conn.commit()
        self.assertEqual(self.conn.status, psycopg2.extensions.STATUS_READY)
        self.assertEqual(self.conn.info.transaction_status,
                         psycopg2.extensions.TRANSACTION_STATUS_IDLE)

        curs.close()
        self.assertEqual(self.conn.status, psycopg2.extensions.STATUS_READY)
        self.assertEqual(self.conn.info.transaction_status,
                         psycopg2.extensions.TRANSACTION_STATUS_IDLE)

    def test_scrollable(self):
        self.assertRaises(psycopg2.ProgrammingError, self.conn.cursor,
                          scrollable=True)

        curs = self.conn.cursor()
        curs.execute("create table scrollable (data int)")
        curs.executemany("insert into scrollable values (%s)",
                         [(i,) for i in range(100)])
        curs.close()

        for t in range(2):
            if not t:
                curs = self.conn.cursor("S")
                self.assertEqual(curs.scrollable, None)
                curs.scrollable = True
            else:
                curs = self.conn.cursor("S", scrollable=True)

            self.assertEqual(curs.scrollable, True)
            curs.itersize = 10

            # complex enough to make postgres cursors declare without
            # scroll/no scroll to fail
            curs.execute("""
                select x.data
                from scrollable x
                join scrollable y on x.data = y.data
                order by y.data""")
            for i, (n,) in enumerate(curs):
                self.assertEqual(i, n)

            curs.scroll(-1)
            for i in range(99, -1, -1):
                curs.scroll(-1)
                self.assertEqual(i, curs.fetchone()[0])
                curs.scroll(-1)

            curs.close()

    def test_not_scrollable(self):
        self.assertRaises(psycopg2.ProgrammingError, self.conn.cursor,
                          scrollable=False)

        curs = self.conn.cursor()
        curs.execute("create table scrollable (data int)")
        curs.executemany("insert into scrollable values (%s)",
                         [(i,) for i in range(100)])
        curs.close()

        curs = self.conn.cursor("S")  # default scrollability
        curs.execute("select * from scrollable")
        self.assertEqual(curs.scrollable, None)
        curs.scroll(2)
        try:
            curs.scroll(-1)
        except psycopg2.OperationalError:
            return self.skipTest("can't evaluate non-scrollable cursor")
        curs.close()

        curs = self.conn.cursor("S", scrollable=False)
        self.assertEqual(curs.scrollable, False)
        curs.execute("select * from scrollable")
        curs.scroll(2)
        self.assertRaises(psycopg2.OperationalError, curs.scroll, -1)

    @slow
    @skip_before_postgres(8, 2)
    def test_iter_named_cursor_efficient(self):
        curs = self.conn.cursor('tmp')
        # if these records are fetched in the same roundtrip their
        # timestamp will not be influenced by the pause in Python world.
        curs.execute("""select clock_timestamp() from generate_series(1,2)""")
        i = iter(curs)
        t1 = next(i)[0]
        time.sleep(0.2)
        t2 = next(i)[0]
        self.assert_((t2 - t1).microseconds * 1e-6 < 0.1,
                     f"named cursor records fetched in 2 roundtrips (delta: {t2 - t1})")

    @skip_before_postgres(8, 0)
    def test_iter_named_cursor_default_itersize(self):
        curs = self.conn.cursor('tmp')
        curs.execute('select generate_series(1,50)')
        rv = [(r[0], curs.rownumber) for r in curs]
        # everything swallowed in one gulp
        self.assertEqual(rv, [(i, i) for i in range(1, 51)])

    @skip_before_postgres(8, 0)
    def test_iter_named_cursor_itersize(self):
        curs = self.conn.cursor('tmp')
        curs.itersize = 30
        curs.execute('select generate_series(1,50)')
        rv = [(r[0], curs.rownumber) for r in curs]
        # everything swallowed in two gulps
        self.assertEqual(rv, [(i, ((i - 1) % 30) + 1) for i in range(1, 51)])

    @skip_before_postgres(8, 0)
    def test_iter_named_cursor_rownumber(self):
        curs = self.conn.cursor('tmp')
        # note: this fails if itersize < dataset: internally we check
        # rownumber == rowcount to detect when to read anoter page, so we
        # would need an extra attribute to have a monotonic rownumber.
        curs.itersize = 20
        curs.execute('select generate_series(1,10)')
        for i, rec in enumerate(curs):
            self.assertEqual(i + 1, curs.rownumber)

    @skip_before_postgres(8, 0)
    def test_named_cursor_stealing(self):
        # you can use a named cursor to iterate on a refcursor created
        # somewhere else
        cur1 = self.conn.cursor()
        cur1.execute("DECLARE test CURSOR WITHOUT HOLD "
                     " FOR SELECT generate_series(1,7)")

        cur2 = self.conn.cursor('test')
        # can call fetch without execute
        self.assertEqual((1,), cur2.fetchone())
        self.assertEqual([(2,), (3,), (4,)], cur2.fetchmany(3))
        self.assertEqual([(5,), (6,), (7,)], cur2.fetchall())

    @skip_before_postgres(8, 2)
    def test_named_noop_close(self):
        cur = self.conn.cursor('test')
        cur.close()

    @skip_before_postgres(8, 2)
    def test_stolen_named_cursor_close(self):
        cur1 = self.conn.cursor()
        cur1.execute("DECLARE test CURSOR WITHOUT HOLD "
                     " FOR SELECT generate_series(1,7)")
        cur2 = self.conn.cursor('test')
        cur2.close()

        cur1.execute("DECLARE test CURSOR WITHOUT HOLD "
                     " FOR SELECT generate_series(1,7)")
        cur2 = self.conn.cursor('test')
        cur2.close()

    @skip_before_postgres(8, 0)
    def test_scroll(self):
        cur = self.conn.cursor()
        cur.execute("select generate_series(0,9)")
        cur.scroll(2)
        self.assertEqual(cur.fetchone(), (2,))
        cur.scroll(2)
        self.assertEqual(cur.fetchone(), (5,))
        cur.scroll(2, mode='relative')
        self.assertEqual(cur.fetchone(), (8,))
        cur.scroll(-1)
        self.assertEqual(cur.fetchone(), (8,))
        cur.scroll(-2)
        self.assertEqual(cur.fetchone(), (7,))
        cur.scroll(2, mode='absolute')
        self.assertEqual(cur.fetchone(), (2,))

        # on the boundary
        cur.scroll(0, mode='absolute')
        self.assertEqual(cur.fetchone(), (0,))
        self.assertRaises((IndexError, psycopg2.ProgrammingError),
                          cur.scroll, -1, mode='absolute')
        cur.scroll(0, mode='absolute')
        self.assertRaises((IndexError, psycopg2.ProgrammingError),
                          cur.scroll, -1)

        cur.scroll(9, mode='absolute')
        self.assertEqual(cur.fetchone(), (9,))
        self.assertRaises((IndexError, psycopg2.ProgrammingError),
                          cur.scroll, 10, mode='absolute')
        cur.scroll(9, mode='absolute')
        self.assertRaises((IndexError, psycopg2.ProgrammingError),
                          cur.scroll, 1)

    @skip_before_postgres(8, 0)
    def test_scroll_named(self):
        cur = self.conn.cursor('tmp', scrollable=True)
        cur.execute("select generate_series(0,9)")
        cur.scroll(2)
        self.assertEqual(cur.fetchone(), (2,))
        cur.scroll(2)
        self.assertEqual(cur.fetchone(), (5,))
        cur.scroll(2, mode='relative')
        self.assertEqual(cur.fetchone(), (8,))
        cur.scroll(9, mode='absolute')
        self.assertEqual(cur.fetchone(), (9,))


def test_suite():
    return unittest.TestLoader().loadTestsFromName(__name__)


if __name__ == "__main__":
    unittest.main()
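NamedCursorTests above cover the named (server-side) cursor feature: the cursor is declared on the server and rows are streamed in batches of itersize instead of being loaded client-side in one go. A minimal usage sketch follows (not part of the committed suite; the DSN, cursor name and query are placeholders).

# Illustrative sketch only: iterating a large result set with a named cursor.
import psycopg2

# conn = psycopg2.connect("dbname=test")      # placeholder DSN
# cur = conn.cursor(name="big_scan")          # named => server-side cursor
# cur.itersize = 2000                         # rows fetched per round trip
# cur.execute("select generate_series(1, 1000000)")
# for (n,) in cur:
#     pass                                    # process each row
# cur.close()
# conn.close()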
555
tests/test_dates.py
Executable file
@ -0,0 +1,555 @@
#!/usr/bin/env python
|
||||
|
||||
# test_dates.py - unit test for dates handling
|
||||
#
|
||||
# Copyright (C) 2008-2019 James Henstridge <james@jamesh.id.au>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# In addition, as a special exception, the copyright holders give
|
||||
# permission to link this program with the OpenSSL library (or with
|
||||
# modified versions of OpenSSL that use the same license as OpenSSL),
|
||||
# and distribute linked combinations including the two.
|
||||
#
|
||||
# You must obey the GNU Lesser General Public License in all respects for
|
||||
# all of the code used other than OpenSSL.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
import sys
|
||||
import math
|
||||
import pickle
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.tz import FixedOffsetTimezone, ZERO
|
||||
import unittest
|
||||
from .testutils import ConnectingTestCase, skip_before_postgres, skip_if_crdb
|
||||
|
||||
|
||||
def total_seconds(d):
|
||||
"""Return total number of seconds of a timedelta as a float."""
|
||||
return d.days * 24 * 60 * 60 + d.seconds + d.microseconds / 1000000.0
|
||||
|
||||
|
||||
class CommonDatetimeTestsMixin:
|
||||
|
||||
def execute(self, *args):
|
||||
self.curs.execute(*args)
|
||||
return self.curs.fetchone()[0]
|
||||
|
||||
def test_parse_date(self):
|
||||
value = self.DATE('2007-01-01', self.curs)
|
||||
self.assert_(value is not None)
|
||||
self.assertEqual(value.year, 2007)
|
||||
self.assertEqual(value.month, 1)
|
||||
self.assertEqual(value.day, 1)
|
||||
|
||||
def test_parse_null_date(self):
|
||||
value = self.DATE(None, self.curs)
|
||||
self.assertEqual(value, None)
|
||||
|
||||
def test_parse_incomplete_date(self):
|
||||
self.assertRaises(psycopg2.DataError, self.DATE, '2007', self.curs)
|
||||
self.assertRaises(psycopg2.DataError, self.DATE, '2007-01', self.curs)
|
||||
|
||||
def test_parse_time(self):
|
||||
value = self.TIME('13:30:29', self.curs)
|
||||
self.assert_(value is not None)
|
||||
self.assertEqual(value.hour, 13)
|
||||
self.assertEqual(value.minute, 30)
|
||||
self.assertEqual(value.second, 29)
|
||||
|
||||
def test_parse_null_time(self):
|
||||
value = self.TIME(None, self.curs)
|
||||
self.assertEqual(value, None)
|
||||
|
||||
def test_parse_incomplete_time(self):
|
||||
self.assertRaises(psycopg2.DataError, self.TIME, '13', self.curs)
|
||||
self.assertRaises(psycopg2.DataError, self.TIME, '13:30', self.curs)
|
||||
|
||||
def test_parse_datetime(self):
|
||||
value = self.DATETIME('2007-01-01 13:30:29', self.curs)
|
||||
self.assert_(value is not None)
|
||||
self.assertEqual(value.year, 2007)
|
||||
self.assertEqual(value.month, 1)
|
||||
self.assertEqual(value.day, 1)
|
||||
self.assertEqual(value.hour, 13)
|
||||
self.assertEqual(value.minute, 30)
|
||||
self.assertEqual(value.second, 29)
|
||||
|
||||
def test_parse_null_datetime(self):
|
||||
value = self.DATETIME(None, self.curs)
|
||||
self.assertEqual(value, None)
|
||||
|
||||
def test_parse_incomplete_datetime(self):
|
||||
self.assertRaises(psycopg2.DataError,
|
||||
self.DATETIME, '2007', self.curs)
|
||||
self.assertRaises(psycopg2.DataError,
|
||||
self.DATETIME, '2007-01', self.curs)
|
||||
self.assertRaises(psycopg2.DataError,
|
||||
self.DATETIME, '2007-01-01 13', self.curs)
|
||||
self.assertRaises(psycopg2.DataError,
|
||||
self.DATETIME, '2007-01-01 13:30', self.curs)
|
||||
|
||||
def test_parse_null_interval(self):
|
||||
value = self.INTERVAL(None, self.curs)
|
||||
self.assertEqual(value, None)
|
||||
|
||||
|
||||
class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
|
||||
"""Tests for the datetime based date handling in psycopg2."""
|
||||
|
||||
def setUp(self):
|
||||
ConnectingTestCase.setUp(self)
|
||||
self.curs = self.conn.cursor()
|
||||
self.DATE = psycopg2.extensions.PYDATE
|
||||
self.TIME = psycopg2.extensions.PYTIME
|
||||
self.DATETIME = psycopg2.extensions.PYDATETIME
|
||||
self.INTERVAL = psycopg2.extensions.PYINTERVAL
|
||||
|
||||
def test_parse_bc_date(self):
|
||||
# datetime does not support BC dates
|
||||
self.assertRaises(ValueError, self.DATE, '00042-01-01 BC', self.curs)
|
||||
|
||||
def test_parse_bc_datetime(self):
|
||||
# datetime does not support BC dates
|
||||
self.assertRaises(ValueError, self.DATETIME,
|
||||
'00042-01-01 13:30:29 BC', self.curs)
|
||||
|
||||
def test_parse_time_microseconds(self):
|
||||
value = self.TIME('13:30:29.123456', self.curs)
|
||||
self.assertEqual(value.second, 29)
|
||||
self.assertEqual(value.microsecond, 123456)
|
||||
|
||||
def test_parse_datetime_microseconds(self):
|
||||
value = self.DATETIME('2007-01-01 13:30:29.123456', self.curs)
|
||||
self.assertEqual(value.second, 29)
|
||||
self.assertEqual(value.microsecond, 123456)
|
||||
|
||||
def check_time_tz(self, str_offset, offset):
|
||||
base = time(13, 30, 29)
|
||||
base_str = '13:30:29'
|
||||
|
||||
value = self.TIME(base_str + str_offset, self.curs)
|
||||
|
||||
# Value has time zone info and correct UTC offset.
|
||||
self.assertNotEqual(value.tzinfo, None),
|
||||
self.assertEqual(value.utcoffset(), timedelta(seconds=offset))
|
||||
|
||||
# Time portion is correct.
|
||||
self.assertEqual(value.replace(tzinfo=None), base)
|
||||
|
||||
def test_parse_time_timezone(self):
|
||||
self.check_time_tz("+01", 3600)
|
||||
self.check_time_tz("-01", -3600)
|
||||
self.check_time_tz("+01:15", 4500)
|
||||
self.check_time_tz("-01:15", -4500)
|
||||
if sys.version_info < (3, 7):
|
||||
# The Python < 3.7 datetime module does not support time zone
|
||||
# offsets that are not a whole number of minutes.
|
||||
# We round the offset to the nearest minute.
|
||||
self.check_time_tz("+01:15:00", 60 * (60 + 15))
|
||||
self.check_time_tz("+01:15:29", 60 * (60 + 15))
|
||||
self.check_time_tz("+01:15:30", 60 * (60 + 16))
|
||||
self.check_time_tz("+01:15:59", 60 * (60 + 16))
|
||||
self.check_time_tz("-01:15:00", -60 * (60 + 15))
|
||||
self.check_time_tz("-01:15:29", -60 * (60 + 15))
|
||||
self.check_time_tz("-01:15:30", -60 * (60 + 16))
|
||||
self.check_time_tz("-01:15:59", -60 * (60 + 16))
|
||||
else:
|
||||
self.check_time_tz("+01:15:00", 60 * (60 + 15))
|
||||
self.check_time_tz("+01:15:29", 60 * (60 + 15) + 29)
|
||||
self.check_time_tz("+01:15:30", 60 * (60 + 15) + 30)
|
||||
self.check_time_tz("+01:15:59", 60 * (60 + 15) + 59)
|
||||
self.check_time_tz("-01:15:00", -(60 * (60 + 15)))
|
||||
self.check_time_tz("-01:15:29", -(60 * (60 + 15) + 29))
|
||||
self.check_time_tz("-01:15:30", -(60 * (60 + 15) + 30))
|
||||
self.check_time_tz("-01:15:59", -(60 * (60 + 15) + 59))
|
||||
|
||||
def check_datetime_tz(self, str_offset, offset):
|
||||
base = datetime(2007, 1, 1, 13, 30, 29)
|
||||
base_str = '2007-01-01 13:30:29'
|
||||
|
||||
value = self.DATETIME(base_str + str_offset, self.curs)
|
||||
|
||||
# Value has time zone info and correct UTC offset.
|
||||
self.assertNotEqual(value.tzinfo, None),
|
||||
self.assertEqual(value.utcoffset(), timedelta(seconds=offset))
|
||||
|
||||
# Datetime is correct.
|
||||
self.assertEqual(value.replace(tzinfo=None), base)
|
||||
|
||||
# Conversion to UTC produces the expected offset.
|
||||
UTC = timezone(timedelta(0))
|
||||
value_utc = value.astimezone(UTC).replace(tzinfo=None)
|
||||
self.assertEqual(base - value_utc, timedelta(seconds=offset))
|
||||
|
||||
def test_default_tzinfo(self):
|
||||
self.curs.execute("select '2000-01-01 00:00+02:00'::timestamptz")
|
||||
dt = self.curs.fetchone()[0]
|
||||
self.assert_(isinstance(dt.tzinfo, timezone))
|
||||
self.assertEqual(dt,
|
||||
datetime(2000, 1, 1, tzinfo=timezone(timedelta(minutes=120))))
|
||||
|
||||
def test_fotz_tzinfo(self):
|
||||
self.curs.tzinfo_factory = FixedOffsetTimezone
|
||||
self.curs.execute("select '2000-01-01 00:00+02:00'::timestamptz")
|
||||
dt = self.curs.fetchone()[0]
|
||||
self.assert_(not isinstance(dt.tzinfo, timezone))
|
||||
self.assert_(isinstance(dt.tzinfo, FixedOffsetTimezone))
|
||||
self.assertEqual(dt,
|
||||
datetime(2000, 1, 1, tzinfo=timezone(timedelta(minutes=120))))
|
||||
|
||||
def test_parse_datetime_timezone(self):
|
||||
self.check_datetime_tz("+01", 3600)
|
||||
self.check_datetime_tz("-01", -3600)
|
||||
self.check_datetime_tz("+01:15", 4500)
|
||||
self.check_datetime_tz("-01:15", -4500)
|
||||
if sys.version_info < (3, 7):
|
||||
# The Python < 3.7 datetime module does not support time zone
|
||||
# offsets that are not a whole number of minutes.
|
||||
# We round the offset to the nearest minute.
|
||||
self.check_datetime_tz("+01:15:00", 60 * (60 + 15))
|
||||
self.check_datetime_tz("+01:15:29", 60 * (60 + 15))
|
||||
self.check_datetime_tz("+01:15:30", 60 * (60 + 16))
|
||||
self.check_datetime_tz("+01:15:59", 60 * (60 + 16))
|
||||
self.check_datetime_tz("-01:15:00", -60 * (60 + 15))
|
||||
self.check_datetime_tz("-01:15:29", -60 * (60 + 15))
|
||||
self.check_datetime_tz("-01:15:30", -60 * (60 + 16))
|
||||
self.check_datetime_tz("-01:15:59", -60 * (60 + 16))
|
||||
else:
|
||||
self.check_datetime_tz("+01:15:00", 60 * (60 + 15))
|
||||
self.check_datetime_tz("+01:15:29", 60 * (60 + 15) + 29)
|
||||
self.check_datetime_tz("+01:15:30", 60 * (60 + 15) + 30)
|
||||
self.check_datetime_tz("+01:15:59", 60 * (60 + 15) + 59)
|
||||
self.check_datetime_tz("-01:15:00", -(60 * (60 + 15)))
|
||||
self.check_datetime_tz("-01:15:29", -(60 * (60 + 15) + 29))
|
||||
self.check_datetime_tz("-01:15:30", -(60 * (60 + 15) + 30))
|
||||
self.check_datetime_tz("-01:15:59", -(60 * (60 + 15) + 59))
|
||||
|
||||
def test_parse_time_no_timezone(self):
|
||||
self.assertEqual(self.TIME("13:30:29", self.curs).tzinfo, None)
|
||||
self.assertEqual(self.TIME("13:30:29.123456", self.curs).tzinfo, None)
|
||||
|
||||
def test_parse_datetime_no_timezone(self):
|
||||
self.assertEqual(
|
||||
self.DATETIME("2007-01-01 13:30:29", self.curs).tzinfo, None)
|
||||
self.assertEqual(
|
||||
self.DATETIME("2007-01-01 13:30:29.123456", self.curs).tzinfo, None)
|
||||
|
||||
def test_parse_interval(self):
|
||||
value = self.INTERVAL('42 days 12:34:56.123456', self.curs)
|
||||
self.assertNotEqual(value, None)
|
||||
self.assertEqual(value.days, 42)
|
||||
self.assertEqual(value.seconds, 45296)
|
||||
self.assertEqual(value.microseconds, 123456)
|
||||
|
||||
def test_parse_negative_interval(self):
|
||||
value = self.INTERVAL('-42 days -12:34:56.123456', self.curs)
|
||||
self.assertNotEqual(value, None)
|
||||
self.assertEqual(value.days, -43)
|
||||
self.assertEqual(value.seconds, 41103)
|
||||
self.assertEqual(value.microseconds, 876544)
|
||||
|
||||
def test_parse_infinity(self):
|
||||
value = self.DATETIME('-infinity', self.curs)
|
||||
self.assertEqual(str(value), '0001-01-01 00:00:00')
|
||||
value = self.DATETIME('infinity', self.curs)
|
||||
self.assertEqual(str(value), '9999-12-31 23:59:59.999999')
|
||||
value = self.DATE('infinity', self.curs)
|
||||
self.assertEqual(str(value), '9999-12-31')
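
# -- Editor's sketch (not part of the original suite) ---------------------
# The literals asserted above are simply the extremes the standard library
# can represent: PostgreSQL's 'infinity' values get clamped to them.

from datetime import date, datetime

assert str(datetime.max) == '9999-12-31 23:59:59.999999'
assert str(datetime.min) == '0001-01-01 00:00:00'
assert str(date.max) == '9999-12-31'
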
def test_adapt_date(self):
|
||||
value = self.execute('select (%s)::date::text',
|
||||
[date(2007, 1, 1)])
|
||||
self.assertEqual(value, '2007-01-01')
|
||||
|
||||
def test_adapt_time(self):
|
||||
value = self.execute('select (%s)::time::text',
|
||||
[time(13, 30, 29)])
|
||||
self.assertEqual(value, '13:30:29')
|
||||
|
||||
@skip_if_crdb("cast adds tz")
|
||||
def test_adapt_datetime(self):
|
||||
value = self.execute('select (%s)::timestamp::text',
|
||||
[datetime(2007, 1, 1, 13, 30, 29)])
|
||||
self.assertEqual(value, '2007-01-01 13:30:29')
|
||||
|
||||
def test_adapt_timedelta(self):
|
||||
value = self.execute('select extract(epoch from (%s)::interval)',
|
||||
[timedelta(days=42, seconds=45296,
|
||||
microseconds=123456)])
|
||||
seconds = math.floor(value)
|
||||
self.assertEqual(seconds, 3674096)
|
||||
self.assertEqual(int(round((value - seconds) * 1000000)), 123456)
|
||||
|
||||
def test_adapt_negative_timedelta(self):
|
||||
value = self.execute('select extract(epoch from (%s)::interval)',
|
||||
[timedelta(days=-42, seconds=45296,
|
||||
microseconds=123456)])
|
||||
seconds = math.floor(value)
|
||||
self.assertEqual(seconds, -3583504)
|
||||
self.assertEqual(int(round((value - seconds) * 1000000)), 123456)
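
# -- Editor's sketch (not part of the original suite) ---------------------
# Where the expected constants above come from: plain timedelta arithmetic,
# mirroring what extract(epoch from ...) returns for the two intervals.

import math
from datetime import timedelta

for _days, _whole in [(42, 42 * 86400 + 45296),     # 3674096
                      (-42, -42 * 86400 + 45296)]:  # -3583504
    _secs = timedelta(days=_days, seconds=45296,
                      microseconds=123456).total_seconds()
    assert math.floor(_secs) == _whole
    assert int(round((_secs - math.floor(_secs)) * 1000000)) == 123456
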
def _test_type_roundtrip(self, o1):
|
||||
o2 = self.execute("select %s;", (o1,))
|
||||
self.assertEqual(type(o1), type(o2))
|
||||
return o2
|
||||
|
||||
def _test_type_roundtrip_array(self, o1):
|
||||
o1 = [o1]
|
||||
o2 = self.execute("select %s;", (o1,))
|
||||
self.assertEqual(type(o1[0]), type(o2[0]))
|
||||
|
||||
def test_type_roundtrip_date(self):
|
||||
self._test_type_roundtrip(date(2010, 5, 3))
|
||||
|
||||
def test_type_roundtrip_datetime(self):
|
||||
dt = self._test_type_roundtrip(datetime(2010, 5, 3, 10, 20, 30))
|
||||
self.assertEqual(None, dt.tzinfo)
|
||||
|
||||
def test_type_roundtrip_datetimetz(self):
|
||||
tz = timezone(timedelta(minutes=8 * 60))
|
||||
dt1 = datetime(2010, 5, 3, 10, 20, 30, tzinfo=tz)
|
||||
dt2 = self._test_type_roundtrip(dt1)
|
||||
self.assertNotEqual(None, dt2.tzinfo)
|
||||
self.assertEqual(dt1, dt2)
|
||||
|
||||
def test_type_roundtrip_time(self):
|
||||
tm = self._test_type_roundtrip(time(10, 20, 30))
|
||||
self.assertEqual(None, tm.tzinfo)
|
||||
|
||||
def test_type_roundtrip_timetz(self):
|
||||
tz = timezone(timedelta(minutes=8 * 60))
|
||||
tm1 = time(10, 20, 30, tzinfo=tz)
|
||||
tm2 = self._test_type_roundtrip(tm1)
|
||||
self.assertNotEqual(None, tm2.tzinfo)
|
||||
self.assertEqual(tm1, tm2)
|
||||
|
||||
def test_type_roundtrip_interval(self):
|
||||
self._test_type_roundtrip(timedelta(seconds=30))
|
||||
|
||||
def test_type_roundtrip_date_array(self):
|
||||
self._test_type_roundtrip_array(date(2010, 5, 3))
|
||||
|
||||
def test_type_roundtrip_datetime_array(self):
|
||||
self._test_type_roundtrip_array(datetime(2010, 5, 3, 10, 20, 30))
|
||||
|
||||
def test_type_roundtrip_datetimetz_array(self):
|
||||
self._test_type_roundtrip_array(
|
||||
datetime(2010, 5, 3, 10, 20, 30, tzinfo=timezone(timedelta(0))))
|
||||
|
||||
def test_type_roundtrip_time_array(self):
|
||||
self._test_type_roundtrip_array(time(10, 20, 30))
|
||||
|
||||
def test_type_roundtrip_interval_array(self):
|
||||
self._test_type_roundtrip_array(timedelta(seconds=30))
|
||||
|
||||
@skip_before_postgres(8, 1)
|
||||
def test_time_24(self):
|
||||
t = self.execute("select '24:00'::time;")
|
||||
self.assertEqual(t, time(0, 0))
|
||||
|
||||
t = self.execute("select '24:00+05'::timetz;")
|
||||
self.assertEqual(t, time(0, 0, tzinfo=timezone(timedelta(minutes=300))))
|
||||
|
||||
t = self.execute("select '24:00+05:30'::timetz;")
|
||||
self.assertEqual(t, time(0, 0, tzinfo=timezone(timedelta(minutes=330))))
|
||||
|
||||
@skip_before_postgres(8, 1)
|
||||
def test_large_interval(self):
|
||||
t = self.execute("select '999999:00:00'::interval")
|
||||
self.assertEqual(total_seconds(t), 999999 * 60 * 60)
|
||||
|
||||
t = self.execute("select '-999999:00:00'::interval")
|
||||
self.assertEqual(total_seconds(t), -999999 * 60 * 60)
|
||||
|
||||
t = self.execute("select '999999:00:00.1'::interval")
|
||||
self.assertEqual(total_seconds(t), 999999 * 60 * 60 + 0.1)
|
||||
|
||||
t = self.execute("select '999999:00:00.9'::interval")
|
||||
self.assertEqual(total_seconds(t), 999999 * 60 * 60 + 0.9)
|
||||
|
||||
t = self.execute("select '-999999:00:00.1'::interval")
|
||||
self.assertEqual(total_seconds(t), -999999 * 60 * 60 - 0.1)
|
||||
|
||||
t = self.execute("select '-999999:00:00.9'::interval")
|
||||
self.assertEqual(total_seconds(t), -999999 * 60 * 60 - 0.9)
|
||||
|
||||
def test_micros_rounding(self):
|
||||
t = self.execute("select '0.1'::interval")
|
||||
self.assertEqual(total_seconds(t), 0.1)
|
||||
|
||||
t = self.execute("select '0.01'::interval")
|
||||
self.assertEqual(total_seconds(t), 0.01)
|
||||
|
||||
t = self.execute("select '0.000001'::interval")
|
||||
self.assertEqual(total_seconds(t), 1e-6)
|
||||
|
||||
t = self.execute("select '0.0000004'::interval")
|
||||
self.assertEqual(total_seconds(t), 0)
|
||||
|
||||
t = self.execute("select '0.0000006'::interval")
|
||||
self.assertEqual(total_seconds(t), 1e-6)
|
||||
|
||||
def test_interval_overflow(self):
|
||||
cur = self.conn.cursor()
|
||||
# Hack a cursor to receive values too extreme to be represented;
# we still want an error rather than a random number.
|
||||
psycopg2.extensions.register_type(
|
||||
psycopg2.extensions.new_type(
|
||||
psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL),
|
||||
cur)
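# From this point any text value fetched through `cur` is handed to the
# C-level INTERVAL typecaster, so components too large to represent raise
# OverflowError instead of coming back as a garbage number.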
|
||||
|
||||
def f(val):
|
||||
cur.execute(f"select '{val}'::text")
|
||||
return cur.fetchone()[0]
|
||||
|
||||
self.assertRaises(OverflowError, f, '100000000000000000:00:00')
|
||||
self.assertRaises(OverflowError, f, '00:100000000000000000:00:00')
|
||||
self.assertRaises(OverflowError, f, '00:00:100000000000000000:00')
|
||||
self.assertRaises(OverflowError, f, '00:00:00.100000000000000000')
|
||||
|
||||
@skip_if_crdb("infinity date")
|
||||
def test_adapt_infinity_tz(self):
|
||||
t = self.execute("select 'infinity'::timestamp")
|
||||
self.assert_(t.tzinfo is None)
|
||||
self.assert_(t > datetime(4000, 1, 1))
|
||||
|
||||
t = self.execute("select '-infinity'::timestamp")
|
||||
self.assert_(t.tzinfo is None)
|
||||
self.assert_(t < datetime(1000, 1, 1))
|
||||
|
||||
t = self.execute("select 'infinity'::timestamptz")
|
||||
self.assert_(t.tzinfo is not None)
|
||||
self.assert_(t > datetime(4000, 1, 1, tzinfo=timezone(timedelta(0))))
|
||||
|
||||
t = self.execute("select '-infinity'::timestamptz")
|
||||
self.assert_(t.tzinfo is not None)
|
||||
self.assert_(t < datetime(1000, 1, 1, tzinfo=timezone(timedelta(0))))
|
||||
|
||||
def test_redshift_day(self):
|
||||
# Redshift is reported returning 1 day interval as microsec (bug #558)
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extensions.register_type(
|
||||
psycopg2.extensions.new_type(
|
||||
psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL),
|
||||
cur)
|
||||
|
||||
for s, v in [
|
||||
('0', timedelta(0)),
|
||||
('1', timedelta(microseconds=1)),
|
||||
('-1', timedelta(microseconds=-1)),
|
||||
('1000000', timedelta(seconds=1)),
|
||||
('86400000000', timedelta(days=1)),
|
||||
('-86400000000', timedelta(days=-1)),
|
||||
]:
|
||||
cur.execute("select %s::text", (s,))
|
||||
r = cur.fetchone()[0]
|
||||
self.assertEqual(r, v, f"{s} -> {r} != {v}")
|
||||
|
||||
@skip_if_crdb("interval style")
|
||||
@skip_before_postgres(8, 4)
|
||||
def test_interval_iso_8601_not_supported(self):
|
||||
# We may end up supporting it, but there is no pressure for it
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("set local intervalstyle to iso_8601")
|
||||
cur.execute("select '1 day 2 hours'::interval")
|
||||
self.assertRaises(psycopg2.NotSupportedError, cur.fetchone)
|
||||
|
||||
|
||||
class FromTicksTestCase(unittest.TestCase):
|
||||
# bug "TimestampFromTicks() throws ValueError (2-2.0.14)"
|
||||
# reported by Jozsef Szalay on 2010-05-06
|
||||
def test_timestamp_value_error_sec_59_99(self):
|
||||
s = psycopg2.TimestampFromTicks(1273173119.99992)
|
||||
self.assertEqual(s.adapted,
|
||||
datetime(2010, 5, 6, 14, 11, 59, 999920,
|
||||
tzinfo=timezone(timedelta(minutes=-5 * 60))))
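
# -- Editor's sketch (not part of the original suite) ---------------------
# Where the 999920 microseconds above come from, and (roughly) how the
# reported bug arose: naively rounding the fractional part of the ticks
# to whole seconds pushes 59.99992 up to 60, an invalid seconds value,
# whereas keeping it as microseconds stays in range.

_ticks = 1273173119.99992
assert int(round((_ticks % 1) * 1000000)) == 999920   # kept as microseconds
assert round(_ticks % 1) == 1                         # whole-second rounding overflows
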
def test_date_value_error_sec_59_99(self):
|
||||
s = psycopg2.DateFromTicks(1273173119.99992)
|
||||
# The returned date is local
|
||||
self.assert_(s.adapted in [date(2010, 5, 6), date(2010, 5, 7)])
|
||||
|
||||
def test_time_value_error_sec_59_99(self):
|
||||
s = psycopg2.TimeFromTicks(1273173119.99992)
|
||||
self.assertEqual(s.adapted.replace(hour=0),
|
||||
time(0, 11, 59, 999920))
|
||||
|
||||
|
||||
class FixedOffsetTimezoneTests(unittest.TestCase):
|
||||
|
||||
def test_init_with_no_args(self):
|
||||
tzinfo = FixedOffsetTimezone()
|
||||
self.assert_(tzinfo._offset is ZERO)
|
||||
self.assert_(tzinfo._name is None)
|
||||
|
||||
def test_repr_with_positive_offset(self):
|
||||
tzinfo = FixedOffsetTimezone(5 * 60)
|
||||
self.assertEqual(repr(tzinfo),
|
||||
"psycopg2.tz.FixedOffsetTimezone(offset=%r, name=None)"
|
||||
% timedelta(minutes=5 * 60))
|
||||
|
||||
def test_repr_with_negative_offset(self):
|
||||
tzinfo = FixedOffsetTimezone(-5 * 60)
|
||||
self.assertEqual(repr(tzinfo),
|
||||
"psycopg2.tz.FixedOffsetTimezone(offset=%r, name=None)"
|
||||
% timedelta(minutes=-5 * 60))
|
||||
|
||||
def test_init_with_timedelta(self):
|
||||
td = timedelta(minutes=5 * 60)
|
||||
tzinfo = FixedOffsetTimezone(td)
|
||||
self.assertEqual(tzinfo, FixedOffsetTimezone(5 * 60))
|
||||
self.assertEqual(repr(tzinfo),
|
||||
"psycopg2.tz.FixedOffsetTimezone(offset=%r, name=None)" % td)
|
||||
|
||||
def test_repr_with_name(self):
|
||||
tzinfo = FixedOffsetTimezone(name="FOO")
|
||||
self.assertEqual(repr(tzinfo),
|
||||
"psycopg2.tz.FixedOffsetTimezone(offset=%r, name='FOO')"
|
||||
% timedelta(0))
|
||||
|
||||
def test_instance_caching(self):
|
||||
self.assert_(FixedOffsetTimezone(name="FOO")
|
||||
is FixedOffsetTimezone(name="FOO"))
|
||||
self.assert_(FixedOffsetTimezone(7 * 60)
|
||||
is FixedOffsetTimezone(7 * 60))
|
||||
self.assert_(FixedOffsetTimezone(-9 * 60, 'FOO')
|
||||
is FixedOffsetTimezone(-9 * 60, 'FOO'))
|
||||
self.assert_(FixedOffsetTimezone(9 * 60)
|
||||
is not FixedOffsetTimezone(9 * 60, 'FOO'))
|
||||
self.assert_(FixedOffsetTimezone(name='FOO')
|
||||
is not FixedOffsetTimezone(9 * 60, 'FOO'))
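
# -- Editor's sketch (not part of the original suite) ---------------------
# A minimal illustration of the interning behaviour asserted above: a
# class-level cache keyed by (offset, name).  This is not psycopg2's
# actual code, only the general shape of such a cache.

class _InternedTZ:
    _cache = {}

    def __new__(cls, offset=0, name=None):
        key = (offset, name)
        try:
            return cls._cache[key]
        except KeyError:
            self = super().__new__(cls)
            self.offset, self.name = offset, name
            cls._cache[key] = self
            return self

assert _InternedTZ(7 * 60) is _InternedTZ(7 * 60)
assert _InternedTZ(9 * 60) is not _InternedTZ(9 * 60, 'FOO')
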
def test_pickle(self):
|
||||
# ticket #135
|
||||
tz11 = FixedOffsetTimezone(60)
|
||||
tz12 = FixedOffsetTimezone(120)
|
||||
for proto in [-1, 0, 1, 2]:
|
||||
tz21, tz22 = pickle.loads(pickle.dumps([tz11, tz12], proto))
|
||||
self.assertEqual(tz11, tz21)
|
||||
self.assertEqual(tz12, tz22)
|
||||
|
||||
tz11 = FixedOffsetTimezone(60, name='foo')
|
||||
tz12 = FixedOffsetTimezone(120, name='bar')
|
||||
for proto in [-1, 0, 1, 2]:
|
||||
tz21, tz22 = pickle.loads(pickle.dumps([tz11, tz12], proto))
|
||||
self.assertEqual(tz11, tz21)
|
||||
self.assertEqual(tz12, tz22)
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
74
tests/test_errcodes.py
Executable file
@ -0,0 +1,74 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# test_errcodes.py - unit test for psycopg2.errcodes module
|
||||
#
|
||||
# Copyright (C) 2015-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# In addition, as a special exception, the copyright holders give
|
||||
# permission to link this program with the OpenSSL library (or with
|
||||
# modified versions of OpenSSL that use the same license as OpenSSL),
|
||||
# and distribute linked combinations including the two.
|
||||
#
|
||||
# You must obey the GNU Lesser General Public License in all respects for
|
||||
# all of the code used other than OpenSSL.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
import unittest
|
||||
from .testutils import ConnectingTestCase, slow, reload
|
||||
|
||||
from threading import Thread
|
||||
from psycopg2 import errorcodes
|
||||
|
||||
|
||||
class ErrocodeTests(ConnectingTestCase):
|
||||
@slow
|
||||
def test_lookup_threadsafe(self):
|
||||
|
||||
# Increase this value if the test does not fail with KeyError
|
||||
MAX_CYCLES = 2000
|
||||
|
||||
errs = []
|
||||
|
||||
def f(pg_code='40001'):
|
||||
try:
|
||||
errorcodes.lookup(pg_code)
|
||||
except Exception as e:
|
||||
errs.append(e)
|
||||
|
||||
for __ in range(MAX_CYCLES):
|
||||
reload(errorcodes)
|
||||
(t1, t2) = (Thread(target=f), Thread(target=f))
|
||||
(t1.start(), t2.start())
|
||||
(t1.join(), t2.join())
|
||||
|
||||
if errs:
|
||||
self.fail(
|
||||
"raised {} errors in {} cycles (first is {} {})".format(
|
||||
len(errs), MAX_CYCLES,
|
||||
errs[0].__class__.__name__, errs[0]))
|
||||
|
||||
def test_ambiguous_names(self):
|
||||
self.assertEqual(
|
||||
errorcodes.lookup('2F004'), "READING_SQL_DATA_NOT_PERMITTED")
|
||||
self.assertEqual(
|
||||
errorcodes.lookup('38004'), "READING_SQL_DATA_NOT_PERMITTED")
|
||||
self.assertEqual(errorcodes.READING_SQL_DATA_NOT_PERMITTED, '38004')
|
||||
self.assertEqual(errorcodes.READING_SQL_DATA_NOT_PERMITTED_, '2F004')
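
# -- Editor's sketch (not part of the original suite) ---------------------
# Typical application-side use of the lookup tested above: turning a
# SQLSTATE (e.g. the pgcode attribute of a caught psycopg2.Error) into
# its symbolic name.

from psycopg2 import errorcodes as _errorcodes

assert _errorcodes.lookup('40001') == 'SERIALIZATION_FAILURE'
assert _errorcodes.lookup(_errorcodes.UNDEFINED_TABLE) == 'UNDEFINED_TABLE'
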
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
94
tests/test_errors.py
Executable file
@ -0,0 +1,94 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# test_errors.py - unit test for psycopg2.errors module
|
||||
#
|
||||
# Copyright (C) 2018-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# In addition, as a special exception, the copyright holders give
|
||||
# permission to link this program with the OpenSSL library (or with
|
||||
# modified versions of OpenSSL that use the same license as OpenSSL),
|
||||
# and distribute linked combinations including the two.
|
||||
#
|
||||
# You must obey the GNU Lesser General Public License in all respects for
|
||||
# all of the code used other than OpenSSL.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
import unittest
|
||||
from .testutils import ConnectingTestCase
|
||||
|
||||
import psycopg2
|
||||
from psycopg2 import errors
|
||||
from psycopg2._psycopg import sqlstate_errors
|
||||
from psycopg2.errors import UndefinedTable
|
||||
|
||||
|
||||
class ErrorsTests(ConnectingTestCase):
|
||||
def test_exception_class(self):
|
||||
cur = self.conn.cursor()
|
||||
try:
|
||||
cur.execute("select * from nonexist")
|
||||
except psycopg2.Error as exc:
|
||||
e = exc
|
||||
|
||||
self.assert_(isinstance(e, UndefinedTable), type(e))
|
||||
self.assert_(isinstance(e, self.conn.ProgrammingError))
|
||||
|
||||
def test_exception_class_fallback(self):
|
||||
cur = self.conn.cursor()
|
||||
|
||||
x = sqlstate_errors.pop('42P01')
|
||||
try:
|
||||
cur.execute("select * from nonexist")
|
||||
except psycopg2.Error as exc:
|
||||
e = exc
|
||||
finally:
|
||||
sqlstate_errors['42P01'] = x
|
||||
|
||||
self.assertEqual(type(e), self.conn.ProgrammingError)
|
||||
|
||||
def test_lookup(self):
|
||||
self.assertIs(errors.lookup('42P01'), errors.UndefinedTable)
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
errors.lookup('XXXXX')
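
# -- Editor's sketch (not part of the original suite) ---------------------
# lookup() lets application code reference an exception class by SQLSTATE
# without importing it by name, e.g. in an except clause:

from psycopg2 import errors as _errors

try:
    raise _errors.lookup('40001')("pretend serialization failure")
except psycopg2.OperationalError:
    # SerializationFailure is an OperationalError subclass, so this catches it.
    pass
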
def test_connection_exceptions_backwards_compatibility(self):
|
||||
err = errors.lookup('08000')
|
||||
# connection exceptions are classified as operational errors
|
||||
self.assert_(issubclass(err, errors.OperationalError))
|
||||
# previously these errors were classified only as DatabaseError
|
||||
self.assert_(issubclass(err, errors.DatabaseError))
|
||||
|
||||
def test_has_base_exceptions(self):
|
||||
excs = []
|
||||
for n in dir(psycopg2):
|
||||
obj = getattr(psycopg2, n)
|
||||
if isinstance(obj, type) and issubclass(obj, Exception):
|
||||
excs.append(obj)
|
||||
|
||||
self.assert_(len(excs) > 8, str(excs))
|
||||
|
||||
excs.append(psycopg2.extensions.QueryCanceledError)
|
||||
excs.append(psycopg2.extensions.TransactionRollbackError)
|
||||
|
||||
for exc in excs:
|
||||
self.assert_(hasattr(errors, exc.__name__))
|
||||
self.assert_(getattr(errors, exc.__name__) is exc)
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
646
tests/test_extras_dictcursor.py
Executable file
@ -0,0 +1,646 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# extras_dictcursor - test if DictCursor extension class works
|
||||
#
|
||||
# Copyright (C) 2004-2019 Federico Di Gregorio <fog@debian.org>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
import copy
|
||||
import time
|
||||
import pickle
|
||||
import unittest
|
||||
from datetime import timedelta
|
||||
from functools import lru_cache
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
from psycopg2.extras import NamedTupleConnection, NamedTupleCursor
|
||||
|
||||
from .testutils import ConnectingTestCase, skip_before_postgres, \
|
||||
crdb_version, skip_if_crdb
|
||||
|
||||
|
||||
class _DictCursorBase(ConnectingTestCase):
|
||||
def setUp(self):
|
||||
ConnectingTestCase.setUp(self)
|
||||
curs = self.conn.cursor()
|
||||
if crdb_version(self.conn) is not None:
|
||||
curs.execute("SET experimental_enable_temp_tables = 'on'")
|
||||
curs.execute("CREATE TEMPORARY TABLE ExtrasDictCursorTests (foo text)")
|
||||
curs.execute("INSERT INTO ExtrasDictCursorTests VALUES ('bar')")
|
||||
self.conn.commit()
|
||||
|
||||
def _testIterRowNumber(self, curs):
|
||||
# Only checking for dataset < itersize:
|
||||
# see CursorTests.test_iter_named_cursor_rownumber
|
||||
curs.itersize = 20
|
||||
curs.execute("""select * from generate_series(1,10)""")
|
||||
for i, r in enumerate(curs):
|
||||
self.assertEqual(i + 1, curs.rownumber)
|
||||
|
||||
def _testNamedCursorNotGreedy(self, curs):
|
||||
curs.itersize = 2
|
||||
curs.execute("""select clock_timestamp() as ts from generate_series(1,3)""")
|
||||
recs = []
|
||||
for t in curs:
|
||||
time.sleep(0.01)
|
||||
recs.append(t)
|
||||
|
||||
# check that the dataset was not fetched in a single gulp
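# With itersize = 2 the first two rows are produced by the same server-side
# fetch, so their clock_timestamp() values are nearly identical; the third
# row is only generated after the client has slept through the first batch,
# hence the larger gap asserted below.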
|
||||
self.assert_(recs[1]['ts'] - recs[0]['ts'] < timedelta(seconds=0.005))
|
||||
self.assert_(recs[2]['ts'] - recs[1]['ts'] > timedelta(seconds=0.0099))
|
||||
|
||||
|
||||
class ExtrasDictCursorTests(_DictCursorBase):
|
||||
"""Test if DictCursor extension class works."""
|
||||
|
||||
@skip_if_crdb("named cursor")
|
||||
def testDictConnCursorArgs(self):
|
||||
self.conn.close()
|
||||
self.conn = self.connect(connection_factory=psycopg2.extras.DictConnection)
|
||||
cur = self.conn.cursor()
|
||||
self.assert_(isinstance(cur, psycopg2.extras.DictCursor))
|
||||
self.assertEqual(cur.name, None)
|
||||
# overridable
|
||||
cur = self.conn.cursor('foo',
|
||||
cursor_factory=psycopg2.extras.NamedTupleCursor)
|
||||
self.assertEqual(cur.name, 'foo')
|
||||
self.assert_(isinstance(cur, psycopg2.extras.NamedTupleCursor))
|
||||
|
||||
def testDictCursorWithPlainCursorFetchOne(self):
|
||||
self._testWithPlainCursor(lambda curs: curs.fetchone())
|
||||
|
||||
def testDictCursorWithPlainCursorFetchMany(self):
|
||||
self._testWithPlainCursor(lambda curs: curs.fetchmany(100)[0])
|
||||
|
||||
def testDictCursorWithPlainCursorFetchManyNoarg(self):
|
||||
self._testWithPlainCursor(lambda curs: curs.fetchmany()[0])
|
||||
|
||||
def testDictCursorWithPlainCursorFetchAll(self):
|
||||
self._testWithPlainCursor(lambda curs: curs.fetchall()[0])
|
||||
|
||||
def testDictCursorWithPlainCursorIter(self):
|
||||
def getter(curs):
|
||||
for row in curs:
|
||||
return row
|
||||
self._testWithPlainCursor(getter)
|
||||
|
||||
def testUpdateRow(self):
|
||||
row = self._testWithPlainCursor(lambda curs: curs.fetchone())
|
||||
row['foo'] = 'qux'
|
||||
self.failUnless(row['foo'] == 'qux')
|
||||
self.failUnless(row[0] == 'qux')
|
||||
|
||||
@skip_before_postgres(8, 0)
|
||||
def testDictCursorWithPlainCursorIterRowNumber(self):
|
||||
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
|
||||
self._testIterRowNumber(curs)
|
||||
|
||||
def _testWithPlainCursor(self, getter):
|
||||
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
|
||||
curs.execute("SELECT * FROM ExtrasDictCursorTests")
|
||||
row = getter(curs)
|
||||
self.failUnless(row['foo'] == 'bar')
|
||||
self.failUnless(row[0] == 'bar')
|
||||
return row
|
||||
|
||||
def testDictCursorWithNamedCursorFetchOne(self):
|
||||
self._testWithNamedCursor(lambda curs: curs.fetchone())
|
||||
|
||||
def testDictCursorWithNamedCursorFetchMany(self):
|
||||
self._testWithNamedCursor(lambda curs: curs.fetchmany(100)[0])
|
||||
|
||||
def testDictCursorWithNamedCursorFetchManyNoarg(self):
|
||||
self._testWithNamedCursor(lambda curs: curs.fetchmany()[0])
|
||||
|
||||
def testDictCursorWithNamedCursorFetchAll(self):
|
||||
self._testWithNamedCursor(lambda curs: curs.fetchall()[0])
|
||||
|
||||
def testDictCursorWithNamedCursorIter(self):
|
||||
def getter(curs):
|
||||
for row in curs:
|
||||
return row
|
||||
self._testWithNamedCursor(getter)
|
||||
|
||||
@skip_if_crdb("named cursor")
|
||||
@skip_before_postgres(8, 2)
|
||||
def testDictCursorWithNamedCursorNotGreedy(self):
|
||||
curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.DictCursor)
|
||||
self._testNamedCursorNotGreedy(curs)
|
||||
|
||||
@skip_if_crdb("named cursor")
|
||||
@skip_before_postgres(8, 0)
|
||||
def testDictCursorWithNamedCursorIterRowNumber(self):
|
||||
curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.DictCursor)
|
||||
self._testIterRowNumber(curs)
|
||||
|
||||
@skip_if_crdb("named cursor")
|
||||
def _testWithNamedCursor(self, getter):
|
||||
curs = self.conn.cursor('aname', cursor_factory=psycopg2.extras.DictCursor)
|
||||
curs.execute("SELECT * FROM ExtrasDictCursorTests")
|
||||
row = getter(curs)
|
||||
self.failUnless(row['foo'] == 'bar')
|
||||
self.failUnless(row[0] == 'bar')
|
||||
|
||||
def testPickleDictRow(self):
|
||||
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
|
||||
curs.execute("select 10 as a, 20 as b")
|
||||
r = curs.fetchone()
|
||||
d = pickle.dumps(r)
|
||||
r1 = pickle.loads(d)
|
||||
self.assertEqual(r, r1)
|
||||
self.assertEqual(r[0], r1[0])
|
||||
self.assertEqual(r[1], r1[1])
|
||||
self.assertEqual(r['a'], r1['a'])
|
||||
self.assertEqual(r['b'], r1['b'])
|
||||
self.assertEqual(r._index, r1._index)
|
||||
|
||||
def test_copy(self):
|
||||
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
|
||||
curs.execute("select 10 as foo, 'hi' as bar")
|
||||
rv = curs.fetchone()
|
||||
self.assertEqual(len(rv), 2)
|
||||
|
||||
rv2 = copy.copy(rv)
|
||||
self.assertEqual(len(rv2), 2)
|
||||
self.assertEqual(len(rv), 2)
|
||||
|
||||
rv3 = copy.deepcopy(rv)
|
||||
self.assertEqual(len(rv3), 2)
|
||||
self.assertEqual(len(rv), 2)
|
||||
|
||||
def test_iter_methods(self):
|
||||
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
|
||||
curs.execute("select 10 as a, 20 as b")
|
||||
r = curs.fetchone()
|
||||
self.assert_(not isinstance(r.keys(), list))
|
||||
self.assertEqual(len(list(r.keys())), 2)
|
||||
self.assert_(not isinstance(r.values(), list))
|
||||
self.assertEqual(len(list(r.values())), 2)
|
||||
self.assert_(not isinstance(r.items(), list))
|
||||
self.assertEqual(len(list(r.items())), 2)
|
||||
|
||||
def test_order(self):
|
||||
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
|
||||
curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux")
|
||||
r = curs.fetchone()
|
||||
self.assertEqual(list(r), [5, 4, 33, 2])
|
||||
self.assertEqual(list(r.keys()), ['foo', 'bar', 'baz', 'qux'])
|
||||
self.assertEqual(list(r.values()), [5, 4, 33, 2])
|
||||
self.assertEqual(list(r.items()),
|
||||
[('foo', 5), ('bar', 4), ('baz', 33), ('qux', 2)])
|
||||
|
||||
r1 = pickle.loads(pickle.dumps(r))
|
||||
self.assertEqual(list(r1), list(r))
|
||||
self.assertEqual(list(r1.keys()), list(r.keys()))
|
||||
self.assertEqual(list(r1.values()), list(r.values()))
|
||||
self.assertEqual(list(r1.items()), list(r.items()))
|
||||
|
||||
|
||||
class ExtrasDictCursorRealTests(_DictCursorBase):
|
||||
def testRealMeansReal(self):
|
||||
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
|
||||
curs.execute("SELECT * FROM ExtrasDictCursorTests")
|
||||
row = curs.fetchone()
|
||||
self.assert_(isinstance(row, dict))
|
||||
|
||||
def testDictCursorWithPlainCursorRealFetchOne(self):
|
||||
self._testWithPlainCursorReal(lambda curs: curs.fetchone())
|
||||
|
||||
def testDictCursorWithPlainCursorRealFetchMany(self):
|
||||
self._testWithPlainCursorReal(lambda curs: curs.fetchmany(100)[0])
|
||||
|
||||
def testDictCursorWithPlainCursorRealFetchManyNoarg(self):
|
||||
self._testWithPlainCursorReal(lambda curs: curs.fetchmany()[0])
|
||||
|
||||
def testDictCursorWithPlainCursorRealFetchAll(self):
|
||||
self._testWithPlainCursorReal(lambda curs: curs.fetchall()[0])
|
||||
|
||||
def testDictCursorWithPlainCursorRealIter(self):
|
||||
def getter(curs):
|
||||
for row in curs:
|
||||
return row
|
||||
self._testWithPlainCursorReal(getter)
|
||||
|
||||
@skip_before_postgres(8, 0)
|
||||
def testDictCursorWithPlainCursorRealIterRowNumber(self):
|
||||
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
|
||||
self._testIterRowNumber(curs)
|
||||
|
||||
def _testWithPlainCursorReal(self, getter):
|
||||
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
|
||||
curs.execute("SELECT * FROM ExtrasDictCursorTests")
|
||||
row = getter(curs)
|
||||
self.failUnless(row['foo'] == 'bar')
|
||||
|
||||
def testPickleRealDictRow(self):
|
||||
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
|
||||
curs.execute("select 10 as a, 20 as b")
|
||||
r = curs.fetchone()
|
||||
d = pickle.dumps(r)
|
||||
r1 = pickle.loads(d)
|
||||
self.assertEqual(r, r1)
|
||||
self.assertEqual(r['a'], r1['a'])
|
||||
self.assertEqual(r['b'], r1['b'])
|
||||
|
||||
def test_copy(self):
|
||||
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
|
||||
curs.execute("select 10 as foo, 'hi' as bar")
|
||||
rv = curs.fetchone()
|
||||
self.assertEqual(len(rv), 2)
|
||||
|
||||
rv2 = copy.copy(rv)
|
||||
self.assertEqual(len(rv2), 2)
|
||||
self.assertEqual(len(rv), 2)
|
||||
|
||||
rv3 = copy.deepcopy(rv)
|
||||
self.assertEqual(len(rv3), 2)
|
||||
self.assertEqual(len(rv), 2)
|
||||
|
||||
def testDictCursorRealWithNamedCursorFetchOne(self):
|
||||
self._testWithNamedCursorReal(lambda curs: curs.fetchone())
|
||||
|
||||
def testDictCursorRealWithNamedCursorFetchMany(self):
|
||||
self._testWithNamedCursorReal(lambda curs: curs.fetchmany(100)[0])
|
||||
|
||||
def testDictCursorRealWithNamedCursorFetchManyNoarg(self):
|
||||
self._testWithNamedCursorReal(lambda curs: curs.fetchmany()[0])
|
||||
|
||||
def testDictCursorRealWithNamedCursorFetchAll(self):
|
||||
self._testWithNamedCursorReal(lambda curs: curs.fetchall()[0])
|
||||
|
||||
def testDictCursorRealWithNamedCursorIter(self):
|
||||
def getter(curs):
|
||||
for row in curs:
|
||||
return row
|
||||
self._testWithNamedCursorReal(getter)
|
||||
|
||||
@skip_if_crdb("named cursor")
|
||||
@skip_before_postgres(8, 2)
|
||||
def testDictCursorRealWithNamedCursorNotGreedy(self):
|
||||
curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.RealDictCursor)
|
||||
self._testNamedCursorNotGreedy(curs)
|
||||
|
||||
@skip_if_crdb("named cursor")
|
||||
@skip_before_postgres(8, 0)
|
||||
def testDictCursorRealWithNamedCursorIterRowNumber(self):
|
||||
curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.RealDictCursor)
|
||||
self._testIterRowNumber(curs)
|
||||
|
||||
@skip_if_crdb("named cursor")
|
||||
def _testWithNamedCursorReal(self, getter):
|
||||
curs = self.conn.cursor('aname',
|
||||
cursor_factory=psycopg2.extras.RealDictCursor)
|
||||
curs.execute("SELECT * FROM ExtrasDictCursorTests")
|
||||
row = getter(curs)
|
||||
self.failUnless(row['foo'] == 'bar')
|
||||
|
||||
def test_iter_methods(self):
|
||||
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
|
||||
curs.execute("select 10 as a, 20 as b")
|
||||
r = curs.fetchone()
|
||||
self.assert_(not isinstance(r.keys(), list))
|
||||
self.assertEqual(len(list(r.keys())), 2)
|
||||
self.assert_(not isinstance(r.values(), list))
|
||||
self.assertEqual(len(list(r.values())), 2)
|
||||
self.assert_(not isinstance(r.items(), list))
|
||||
self.assertEqual(len(list(r.items())), 2)
|
||||
|
||||
def test_order(self):
|
||||
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
|
||||
curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux")
|
||||
r = curs.fetchone()
|
||||
self.assertEqual(list(r), ['foo', 'bar', 'baz', 'qux'])
|
||||
self.assertEqual(list(r.keys()), ['foo', 'bar', 'baz', 'qux'])
|
||||
self.assertEqual(list(r.values()), [5, 4, 33, 2])
|
||||
self.assertEqual(list(r.items()),
|
||||
[('foo', 5), ('bar', 4), ('baz', 33), ('qux', 2)])
|
||||
|
||||
r1 = pickle.loads(pickle.dumps(r))
|
||||
self.assertEqual(list(r1), list(r))
|
||||
self.assertEqual(list(r1.keys()), list(r.keys()))
|
||||
self.assertEqual(list(r1.values()), list(r.values()))
|
||||
self.assertEqual(list(r1.items()), list(r.items()))
|
||||
|
||||
def test_pop(self):
|
||||
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
|
||||
curs.execute("select 1 as a, 2 as b, 3 as c")
|
||||
r = curs.fetchone()
|
||||
self.assertEqual(r.pop('b'), 2)
|
||||
self.assertEqual(list(r), ['a', 'c'])
|
||||
self.assertEqual(list(r.keys()), ['a', 'c'])
|
||||
self.assertEqual(list(r.values()), [1, 3])
|
||||
self.assertEqual(list(r.items()), [('a', 1), ('c', 3)])
|
||||
|
||||
self.assertEqual(r.pop('b', None), None)
|
||||
self.assertRaises(KeyError, r.pop, 'b')
|
||||
|
||||
def test_mod(self):
|
||||
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
|
||||
curs.execute("select 1 as a, 2 as b, 3 as c")
|
||||
r = curs.fetchone()
|
||||
r['d'] = 4
|
||||
self.assertEqual(list(r), ['a', 'b', 'c', 'd'])
|
||||
self.assertEqual(list(r.keys()), ['a', 'b', 'c', 'd'])
|
||||
self.assertEqual(list(r.values()), [1, 2, 3, 4])
|
||||
self.assertEqual(list(
|
||||
r.items()), [('a', 1), ('b', 2), ('c', 3), ('d', 4)])
|
||||
|
||||
assert r['a'] == 1
|
||||
assert r['b'] == 2
|
||||
assert r['c'] == 3
|
||||
assert r['d'] == 4
|
||||
|
||||
|
||||
class NamedTupleCursorTest(ConnectingTestCase):
|
||||
def setUp(self):
|
||||
ConnectingTestCase.setUp(self)
|
||||
|
||||
self.conn = self.connect(connection_factory=NamedTupleConnection)
|
||||
curs = self.conn.cursor()
|
||||
if crdb_version(self.conn) is not None:
|
||||
curs.execute("SET experimental_enable_temp_tables = 'on'")
|
||||
curs.execute("CREATE TEMPORARY TABLE nttest (i int, s text)")
|
||||
curs.execute("INSERT INTO nttest VALUES (1, 'foo')")
|
||||
curs.execute("INSERT INTO nttest VALUES (2, 'bar')")
|
||||
curs.execute("INSERT INTO nttest VALUES (3, 'baz')")
|
||||
self.conn.commit()
|
||||
|
||||
@skip_if_crdb("named cursor")
|
||||
def test_cursor_args(self):
|
||||
cur = self.conn.cursor('foo', cursor_factory=psycopg2.extras.DictCursor)
|
||||
self.assertEqual(cur.name, 'foo')
|
||||
self.assert_(isinstance(cur, psycopg2.extras.DictCursor))
|
||||
|
||||
def test_fetchone(self):
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("select * from nttest order by 1")
|
||||
t = curs.fetchone()
|
||||
self.assertEqual(t[0], 1)
|
||||
self.assertEqual(t.i, 1)
|
||||
self.assertEqual(t[1], 'foo')
|
||||
self.assertEqual(t.s, 'foo')
|
||||
self.assertEqual(curs.rownumber, 1)
|
||||
self.assertEqual(curs.rowcount, 3)
|
||||
|
||||
def test_fetchmany_noarg(self):
|
||||
curs = self.conn.cursor()
|
||||
curs.arraysize = 2
|
||||
curs.execute("select * from nttest order by 1")
|
||||
res = curs.fetchmany()
|
||||
self.assertEqual(2, len(res))
|
||||
self.assertEqual(res[0].i, 1)
|
||||
self.assertEqual(res[0].s, 'foo')
|
||||
self.assertEqual(res[1].i, 2)
|
||||
self.assertEqual(res[1].s, 'bar')
|
||||
self.assertEqual(curs.rownumber, 2)
|
||||
self.assertEqual(curs.rowcount, 3)
|
||||
|
||||
def test_fetchmany(self):
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("select * from nttest order by 1")
|
||||
res = curs.fetchmany(2)
|
||||
self.assertEqual(2, len(res))
|
||||
self.assertEqual(res[0].i, 1)
|
||||
self.assertEqual(res[0].s, 'foo')
|
||||
self.assertEqual(res[1].i, 2)
|
||||
self.assertEqual(res[1].s, 'bar')
|
||||
self.assertEqual(curs.rownumber, 2)
|
||||
self.assertEqual(curs.rowcount, 3)
|
||||
|
||||
def test_fetchall(self):
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("select * from nttest order by 1")
|
||||
res = curs.fetchall()
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertEqual(res[0].i, 1)
|
||||
self.assertEqual(res[0].s, 'foo')
|
||||
self.assertEqual(res[1].i, 2)
|
||||
self.assertEqual(res[1].s, 'bar')
|
||||
self.assertEqual(res[2].i, 3)
|
||||
self.assertEqual(res[2].s, 'baz')
|
||||
self.assertEqual(curs.rownumber, 3)
|
||||
self.assertEqual(curs.rowcount, 3)
|
||||
|
||||
def test_executemany(self):
|
||||
curs = self.conn.cursor()
|
||||
curs.executemany("delete from nttest where i = %s",
|
||||
[(1,), (2,)])
|
||||
curs.execute("select * from nttest order by 1")
|
||||
res = curs.fetchall()
|
||||
self.assertEqual(1, len(res))
|
||||
self.assertEqual(res[0].i, 3)
|
||||
self.assertEqual(res[0].s, 'baz')
|
||||
|
||||
def test_iter(self):
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("select * from nttest order by 1")
|
||||
i = iter(curs)
|
||||
self.assertEqual(curs.rownumber, 0)
|
||||
|
||||
t = next(i)
|
||||
self.assertEqual(t.i, 1)
|
||||
self.assertEqual(t.s, 'foo')
|
||||
self.assertEqual(curs.rownumber, 1)
|
||||
self.assertEqual(curs.rowcount, 3)
|
||||
|
||||
t = next(i)
|
||||
self.assertEqual(t.i, 2)
|
||||
self.assertEqual(t.s, 'bar')
|
||||
self.assertEqual(curs.rownumber, 2)
|
||||
self.assertEqual(curs.rowcount, 3)
|
||||
|
||||
t = next(i)
|
||||
self.assertEqual(t.i, 3)
|
||||
self.assertEqual(t.s, 'baz')
|
||||
self.assertRaises(StopIteration, next, i)
|
||||
self.assertEqual(curs.rownumber, 3)
|
||||
self.assertEqual(curs.rowcount, 3)
|
||||
|
||||
def test_record_updated(self):
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("select 1 as foo;")
|
||||
r = curs.fetchone()
|
||||
self.assertEqual(r.foo, 1)
|
||||
|
||||
curs.execute("select 2 as bar;")
|
||||
r = curs.fetchone()
|
||||
self.assertEqual(r.bar, 2)
|
||||
self.assertRaises(AttributeError, getattr, r, 'foo')
|
||||
|
||||
def test_no_result_no_surprise(self):
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("update nttest set s = s")
|
||||
self.assertRaises(psycopg2.ProgrammingError, curs.fetchone)
|
||||
|
||||
curs.execute("update nttest set s = s")
|
||||
self.assertRaises(psycopg2.ProgrammingError, curs.fetchall)
|
||||
|
||||
def test_bad_col_names(self):
|
||||
curs = self.conn.cursor()
|
||||
curs.execute('select 1 as "foo.bar_baz", 2 as "?column?", 3 as "3"')
|
||||
rv = curs.fetchone()
|
||||
self.assertEqual(rv.foo_bar_baz, 1)
|
||||
self.assertEqual(rv.f_column_, 2)
|
||||
self.assertEqual(rv.f3, 3)
|
||||
|
||||
@skip_before_postgres(8)
|
||||
def test_nonascii_name(self):
|
||||
curs = self.conn.cursor()
|
||||
curs.execute('select 1 as \xe5h\xe9')
|
||||
rv = curs.fetchone()
|
||||
self.assertEqual(getattr(rv, '\xe5h\xe9'), 1)
|
||||
|
||||
def test_minimal_generation(self):
|
||||
# Instrument the class to verify it gets called the minimum number of times.
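# Expectation: the namedtuple class is rebuilt at most once per execute(),
# no matter how many fetch* calls follow.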
|
||||
f_orig = NamedTupleCursor._make_nt
|
||||
calls = [0]
|
||||
|
||||
def f_patched(self_):
|
||||
calls[0] += 1
|
||||
return f_orig(self_)
|
||||
|
||||
NamedTupleCursor._make_nt = f_patched
|
||||
|
||||
try:
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("select * from nttest order by 1")
|
||||
curs.fetchone()
|
||||
curs.fetchone()
|
||||
curs.fetchone()
|
||||
self.assertEqual(1, calls[0])
|
||||
|
||||
curs.execute("select * from nttest order by 1")
|
||||
curs.fetchone()
|
||||
curs.fetchall()
|
||||
self.assertEqual(2, calls[0])
|
||||
|
||||
curs.execute("select * from nttest order by 1")
|
||||
curs.fetchone()
|
||||
curs.fetchmany(1)
|
||||
self.assertEqual(3, calls[0])
|
||||
|
||||
finally:
|
||||
NamedTupleCursor._make_nt = f_orig
|
||||
|
||||
@skip_if_crdb("named cursor")
|
||||
@skip_before_postgres(8, 0)
|
||||
def test_named(self):
|
||||
curs = self.conn.cursor('tmp')
|
||||
curs.execute("""select i from generate_series(0,9) i""")
|
||||
recs = []
|
||||
recs.extend(curs.fetchmany(5))
|
||||
recs.append(curs.fetchone())
|
||||
recs.extend(curs.fetchall())
|
||||
self.assertEqual(list(range(10)), [t.i for t in recs])
|
||||
|
||||
@skip_if_crdb("named cursor")
|
||||
def test_named_fetchone(self):
|
||||
curs = self.conn.cursor('tmp')
|
||||
curs.execute("""select 42 as i""")
|
||||
t = curs.fetchone()
|
||||
self.assertEqual(t.i, 42)
|
||||
|
||||
@skip_if_crdb("named cursor")
|
||||
def test_named_fetchmany(self):
|
||||
curs = self.conn.cursor('tmp')
|
||||
curs.execute("""select 42 as i""")
|
||||
recs = curs.fetchmany(10)
|
||||
self.assertEqual(recs[0].i, 42)
|
||||
|
||||
@skip_if_crdb("named cursor")
|
||||
def test_named_fetchall(self):
|
||||
curs = self.conn.cursor('tmp')
|
||||
curs.execute("""select 42 as i""")
|
||||
recs = curs.fetchall()
|
||||
self.assertEqual(recs[0].i, 42)
|
||||
|
||||
@skip_if_crdb("named cursor")
|
||||
@skip_before_postgres(8, 2)
|
||||
def test_not_greedy(self):
|
||||
curs = self.conn.cursor('tmp')
|
||||
curs.itersize = 2
|
||||
curs.execute("""select clock_timestamp() as ts from generate_series(1,3)""")
|
||||
recs = []
|
||||
for t in curs:
|
||||
time.sleep(0.01)
|
||||
recs.append(t)
|
||||
|
||||
# check that the dataset was not fetched in a single gulp
|
||||
self.assert_(recs[1].ts - recs[0].ts < timedelta(seconds=0.005))
|
||||
self.assert_(recs[2].ts - recs[1].ts > timedelta(seconds=0.0099))
|
||||
|
||||
@skip_if_crdb("named cursor")
|
||||
@skip_before_postgres(8, 0)
|
||||
def test_named_rownumber(self):
|
||||
curs = self.conn.cursor('tmp')
|
||||
# Only checking for dataset < itersize:
|
||||
# see CursorTests.test_iter_named_cursor_rownumber
|
||||
curs.itersize = 4
|
||||
curs.execute("""select * from generate_series(1,3)""")
|
||||
for i, t in enumerate(curs):
|
||||
self.assertEqual(i + 1, curs.rownumber)
|
||||
|
||||
def test_cache(self):
|
||||
NamedTupleCursor._cached_make_nt.cache_clear()
|
||||
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("select 10 as a, 20 as b")
|
||||
r1 = curs.fetchone()
|
||||
curs.execute("select 10 as a, 20 as c")
|
||||
r2 = curs.fetchone()
|
||||
|
||||
# Get a new cursor to check that the cache works across multiple ones
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("select 10 as a, 30 as b")
|
||||
r3 = curs.fetchone()
|
||||
|
||||
self.assert_(type(r1) is type(r3))
|
||||
self.assert_(type(r1) is not type(r2))
|
||||
|
||||
cache_info = NamedTupleCursor._cached_make_nt.cache_info()
|
||||
self.assertEqual(cache_info.hits, 1)
|
||||
self.assertEqual(cache_info.misses, 2)
|
||||
self.assertEqual(cache_info.currsize, 2)
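
# -- Editor's sketch (not part of the original suite) ---------------------
# Not psycopg2's actual implementation, just the general shape of an
# lru_cache keyed on the column names, which is what the cache_info()
# figures above describe.

from collections import namedtuple
from functools import lru_cache

@lru_cache(512)
def _nt_for(fields):            # fields must be hashable, e.g. a tuple
    return namedtuple('Record', fields)

assert _nt_for(('a', 'b')) is _nt_for(('a', 'b'))       # cache hit
assert _nt_for(('a', 'b')) is not _nt_for(('a', 'c'))   # different columns
assert _nt_for.cache_info().misses == 2
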
def test_max_cache(self):
|
||||
old_func = NamedTupleCursor._cached_make_nt
|
||||
NamedTupleCursor._cached_make_nt = \
|
||||
lru_cache(8)(NamedTupleCursor._cached_make_nt.__wrapped__)
|
||||
try:
|
||||
recs = []
|
||||
curs = self.conn.cursor()
|
||||
for i in range(10):
|
||||
curs.execute(f"select 1 as f{i}")
|
||||
recs.append(curs.fetchone())
|
||||
|
||||
# Still in cache
|
||||
curs.execute("select 1 as f9")
|
||||
rec = curs.fetchone()
|
||||
self.assert_(any(type(r) is type(rec) for r in recs))
|
||||
|
||||
# Gone from cache
|
||||
curs.execute("select 1 as f0")
|
||||
rec = curs.fetchone()
|
||||
self.assert_(all(type(r) is not type(rec) for r in recs))
|
||||
|
||||
finally:
|
||||
NamedTupleCursor._cached_make_nt = old_func
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
269
tests/test_fast_executemany.py
Executable file
@ -0,0 +1,269 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# test_fast_executemany.py - tests for fast executemany implementations
|
||||
#
|
||||
# Copyright (C) 2017-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
from datetime import date
|
||||
|
||||
from . import testutils
|
||||
import unittest
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
import psycopg2.extensions as ext
|
||||
from psycopg2 import sql
|
||||
|
||||
|
||||
class TestPaginate(unittest.TestCase):
|
||||
def test_paginate(self):
|
||||
def pag(seq):
|
||||
return psycopg2.extras._paginate(seq, 100)
|
||||
|
||||
self.assertEqual(list(pag([])), [])
|
||||
self.assertEqual(list(pag([1])), [[1]])
|
||||
self.assertEqual(list(pag(range(99))), [list(range(99))])
|
||||
self.assertEqual(list(pag(range(100))), [list(range(100))])
|
||||
self.assertEqual(list(pag(range(101))), [list(range(100)), [100]])
|
||||
self.assertEqual(
|
||||
list(pag(range(200))), [list(range(100)), list(range(100, 200))])
|
||||
self.assertEqual(
|
||||
list(pag(range(1000))),
|
||||
[list(range(i * 100, (i + 1) * 100)) for i in range(10)])
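
# -- Editor's sketch (not part of the original suite) ---------------------
# psycopg2.extras._paginate is private; the behaviour pinned down above is
# simply "yield the input in lists of at most page_size items".  One way
# such a helper can be written:

from itertools import islice

def _pages(seq, page_size):
    it = iter(seq)
    while True:
        page = list(islice(it, page_size))
        if not page:
            return
        yield page

assert list(_pages(range(101), 100)) == [list(range(100)), [100]]
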
class FastExecuteTestMixin:
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("""create table testfast (
|
||||
id serial primary key, date date, val int, data text)""")
|
||||
|
||||
|
||||
class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase):
|
||||
def test_empty(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.execute_batch(cur,
|
||||
"insert into testfast (id, val) values (%s, %s)",
|
||||
[])
|
||||
cur.execute("select * from testfast order by id")
|
||||
self.assertEqual(cur.fetchall(), [])
|
||||
|
||||
def test_one(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.execute_batch(cur,
|
||||
"insert into testfast (id, val) values (%s, %s)",
|
||||
iter([(1, 10)]))
|
||||
cur.execute("select id, val from testfast order by id")
|
||||
self.assertEqual(cur.fetchall(), [(1, 10)])
|
||||
|
||||
def test_tuples(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.execute_batch(cur,
|
||||
"insert into testfast (id, date, val) values (%s, %s, %s)",
|
||||
((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
|
||||
cur.execute("select id, date, val from testfast order by id")
|
||||
self.assertEqual(cur.fetchall(),
|
||||
[(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
|
||||
|
||||
def test_many(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.execute_batch(cur,
|
||||
"insert into testfast (id, val) values (%s, %s)",
|
||||
((i, i * 10) for i in range(1000)))
|
||||
cur.execute("select id, val from testfast order by id")
|
||||
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
|
||||
|
||||
def test_composed(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.execute_batch(cur,
|
||||
sql.SQL("insert into {0} (id, val) values (%s, %s)")
|
||||
.format(sql.Identifier('testfast')),
|
||||
((i, i * 10) for i in range(1000)))
|
||||
cur.execute("select id, val from testfast order by id")
|
||||
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
|
||||
|
||||
def test_pages(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.execute_batch(cur,
|
||||
"insert into testfast (id, val) values (%s, %s)",
|
||||
((i, i * 10) for i in range(25)),
|
||||
page_size=10)
|
||||
|
||||
# last command was 5 statements
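# (25 rows with page_size=10 give pages of 10, 10 and 5; the last page is
# 5 statements joined with ';', hence 4 separators.)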
|
||||
self.assertEqual(sum(c == ';' for c in cur.query.decode('ascii')), 4)
|
||||
|
||||
cur.execute("select id, val from testfast order by id")
|
||||
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])
|
||||
|
||||
@testutils.skip_before_postgres(8, 0)
|
||||
def test_unicode(self):
|
||||
cur = self.conn.cursor()
|
||||
ext.register_type(ext.UNICODE, cur)
|
||||
snowman = "\u2603"
|
||||
|
||||
# unicode in statement
|
||||
psycopg2.extras.execute_batch(cur,
|
||||
"insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman,
|
||||
[(1, 'x')])
|
||||
cur.execute("select id, data from testfast where id = 1")
|
||||
self.assertEqual(cur.fetchone(), (1, 'x'))
|
||||
|
||||
# unicode in data
|
||||
psycopg2.extras.execute_batch(cur,
|
||||
"insert into testfast (id, data) values (%s, %s)",
|
||||
[(2, snowman)])
|
||||
cur.execute("select id, data from testfast where id = 2")
|
||||
self.assertEqual(cur.fetchone(), (2, snowman))
|
||||
|
||||
# unicode in both
|
||||
psycopg2.extras.execute_batch(cur,
|
||||
"insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman,
|
||||
[(3, snowman)])
|
||||
cur.execute("select id, data from testfast where id = 3")
|
||||
self.assertEqual(cur.fetchone(), (3, snowman))
|
||||
|
||||
|
||||
@testutils.skip_before_postgres(8, 2)
|
||||
class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase):
|
||||
def test_empty(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.execute_values(cur,
|
||||
"insert into testfast (id, val) values %s",
|
||||
[])
|
||||
cur.execute("select * from testfast order by id")
|
||||
self.assertEqual(cur.fetchall(), [])
|
||||
|
||||
def test_one(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.execute_values(cur,
|
||||
"insert into testfast (id, val) values %s",
|
||||
iter([(1, 10)]))
|
||||
cur.execute("select id, val from testfast order by id")
|
||||
self.assertEqual(cur.fetchall(), [(1, 10)])
|
||||
|
||||
def test_tuples(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.execute_values(cur,
|
||||
"insert into testfast (id, date, val) values %s",
|
||||
((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
|
||||
cur.execute("select id, date, val from testfast order by id")
|
||||
self.assertEqual(cur.fetchall(),
|
||||
[(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
|
||||
|
||||
def test_dicts(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.execute_values(cur,
|
||||
"insert into testfast (id, date, val) values %s",
|
||||
(dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar")
|
||||
for i in range(10)),
|
||||
template='(%(id)s, %(date)s, %(val)s)')
|
||||
cur.execute("select id, date, val from testfast order by id")
|
||||
self.assertEqual(cur.fetchall(),
|
||||
[(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
|
||||
|
||||
def test_many(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.execute_values(cur,
|
||||
"insert into testfast (id, val) values %s",
|
||||
((i, i * 10) for i in range(1000)))
|
||||
cur.execute("select id, val from testfast order by id")
|
||||
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
|
||||
|
||||
def test_composed(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.execute_values(cur,
|
||||
sql.SQL("insert into {0} (id, val) values %s")
|
||||
.format(sql.Identifier('testfast')),
|
||||
((i, i * 10) for i in range(1000)))
|
||||
cur.execute("select id, val from testfast order by id")
|
||||
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
|
||||
|
||||
def test_pages(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.execute_values(cur,
|
||||
"insert into testfast (id, val) values %s",
|
||||
((i, i * 10) for i in range(25)),
|
||||
page_size=10)
|
||||
|
||||
# last statement was 5 tuples (one extra paren comes from the field list)
|
||||
self.assertEqual(sum(c == '(' for c in cur.query.decode('ascii')), 6)
|
||||
|
||||
cur.execute("select id, val from testfast order by id")
|
||||
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])
|
||||
|
||||
def test_unicode(self):
|
||||
cur = self.conn.cursor()
|
||||
ext.register_type(ext.UNICODE, cur)
|
||||
snowman = "\u2603"
|
||||
|
||||
# unicode in statement
|
||||
psycopg2.extras.execute_values(cur,
|
||||
"insert into testfast (id, data) values %%s -- %s" % snowman,
|
||||
[(1, 'x')])
|
||||
cur.execute("select id, data from testfast where id = 1")
|
||||
self.assertEqual(cur.fetchone(), (1, 'x'))
|
||||
|
||||
# unicode in data
|
||||
psycopg2.extras.execute_values(cur,
|
||||
"insert into testfast (id, data) values %s",
|
||||
[(2, snowman)])
|
||||
cur.execute("select id, data from testfast where id = 2")
|
||||
self.assertEqual(cur.fetchone(), (2, snowman))
|
||||
|
||||
# unicode in both
|
||||
psycopg2.extras.execute_values(cur,
|
||||
"insert into testfast (id, data) values %%s -- %s" % snowman,
|
||||
[(3, snowman)])
|
||||
cur.execute("select id, data from testfast where id = 3")
|
||||
self.assertEqual(cur.fetchone(), (3, snowman))
|
||||
|
||||
def test_returning(self):
|
||||
cur = self.conn.cursor()
|
||||
result = psycopg2.extras.execute_values(cur,
|
||||
"insert into testfast (id, val) values %s returning id",
|
||||
((i, i * 10) for i in range(25)),
|
||||
page_size=10, fetch=True)
|
||||
# result contains all returned pages
|
||||
self.assertEqual([r[0] for r in result], list(range(25)))
|
||||
|
||||
def test_invalid_sql(self):
|
||||
cur = self.conn.cursor()
|
||||
self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
|
||||
"insert", [])
|
||||
self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
|
||||
"insert %s and %s", [])
|
||||
self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
|
||||
"insert %f", [])
|
||||
self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
|
||||
"insert %f %s", [])
|
||||
|
||||
def test_percent_escape(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.execute_values(cur,
|
||||
"insert into testfast (id, data) values %s -- a%%b",
|
||||
[(1, 'hi')])
|
||||
self.assert_(b'a%%b' not in cur.query)
|
||||
self.assert_(b'a%b' in cur.query)
|
||||
|
||||
cur.execute("select id, data from testfast")
|
||||
self.assertEqual(cur.fetchall(), [(1, 'hi')])
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
246
tests/test_green.py
Executable file
@ -0,0 +1,246 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# test_green.py - unit test for async wait callback
|
||||
#
|
||||
# Copyright (C) 2010-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# In addition, as a special exception, the copyright holders give
|
||||
# permission to link this program with the OpenSSL library (or with
|
||||
# modified versions of OpenSSL that use the same license as OpenSSL),
|
||||
# and distribute linked combinations including the two.
|
||||
#
|
||||
# You must obey the GNU Lesser General Public License in all respects for
|
||||
# all of the code used other than OpenSSL.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
import select
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extensions
|
||||
import psycopg2.extras
|
||||
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE
|
||||
|
||||
from .testutils import ConnectingTestCase, skip_before_postgres, slow
|
||||
from .testutils import skip_if_crdb
|
||||
|
||||
|
||||
class ConnectionStub:
|
||||
"""A `connection` wrapper allowing analysis of the `poll()` calls."""
|
||||
def __init__(self, conn):
|
||||
self.conn = conn
|
||||
self.polls = []
|
||||
|
||||
def fileno(self):
|
||||
return self.conn.fileno()
|
||||
|
||||
def poll(self):
|
||||
rv = self.conn.poll()
|
||||
self.polls.append(rv)
|
||||
return rv
|
||||
|
||||
|
||||
class GreenTestCase(ConnectingTestCase):
|
||||
def setUp(self):
|
||||
self._cb = psycopg2.extensions.get_wait_callback()
|
||||
psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select)
|
||||
ConnectingTestCase.setUp(self)
|
||||
|
||||
def tearDown(self):
|
||||
ConnectingTestCase.tearDown(self)
|
||||
psycopg2.extensions.set_wait_callback(self._cb)
|
||||
|
||||
def set_stub_wait_callback(self, conn, cb=None):
|
||||
stub = ConnectionStub(conn)
|
||||
psycopg2.extensions.set_wait_callback(
|
||||
lambda conn: (cb or psycopg2.extras.wait_select)(stub))
|
||||
return stub
|
||||
|
||||
@slow
|
||||
@skip_if_crdb("flush on write flakey")
|
||||
def test_flush_on_write(self):
|
||||
# a very large query requires a flush loop to be sent to the backend
|
||||
conn = self.conn
|
||||
stub = self.set_stub_wait_callback(conn)
|
||||
curs = conn.cursor()
|
||||
for mb in 1, 5, 10, 20, 50:
|
||||
size = mb * 1024 * 1024
|
||||
del stub.polls[:]
|
||||
curs.execute("select %s;", ('x' * size,))
|
||||
self.assertEqual(size, len(curs.fetchone()[0]))
|
||||
if stub.polls.count(psycopg2.extensions.POLL_WRITE) > 1:
|
||||
return
|
||||
|
||||
# This is more a testing glitch than an error: it happens
|
||||
# under high load on Linux, probably because the kernel has more
|
||||
# buffers ready. A warning may be useful during development,
|
||||
# but an error is bad during regression testing.
|
||||
warnings.warn("sending a large query didn't trigger block on write.")
|
||||
|
||||
def test_error_in_callback(self):
|
||||
# behaviour changed after issue #113: if there is an error in the
|
||||
# callback, for the moment we don't have a way to reset the connection
|
||||
# without blocking (ticket #113), so just close it.
|
||||
conn = self.conn
|
||||
curs = conn.cursor()
|
||||
curs.execute("select 1") # have a BEGIN
|
||||
curs.fetchone()
|
||||
|
||||
# now try to do something that will fail in the callback
|
||||
psycopg2.extensions.set_wait_callback(lambda conn: 1 // 0)
|
||||
self.assertRaises(ZeroDivisionError, curs.execute, "select 2")
|
||||
|
||||
self.assert_(conn.closed)
|
||||
|
||||
def test_dont_freak_out(self):
|
||||
# if there is an error in a green query, don't freak out and close
|
||||
# the connection
|
||||
conn = self.conn
|
||||
curs = conn.cursor()
|
||||
self.assertRaises(psycopg2.ProgrammingError,
|
||||
curs.execute, "select the unselectable")
|
||||
|
||||
# check that the connection is left in a usable state
|
||||
self.assert_(not conn.closed)
|
||||
conn.rollback()
|
||||
curs.execute("select 1")
|
||||
self.assertEqual(curs.fetchone()[0], 1)
|
||||
|
||||
@skip_before_postgres(8, 2)
|
||||
def test_copy_no_hang(self):
|
||||
cur = self.conn.cursor()
|
||||
self.assertRaises(psycopg2.ProgrammingError,
|
||||
cur.execute, "copy (select 1) to stdout")
|
||||
|
||||
@slow
|
||||
@skip_if_crdb("notice")
|
||||
@skip_before_postgres(9, 0)
|
||||
def test_non_block_after_notice(self):
|
||||
def wait(conn):
|
||||
while 1:
|
||||
state = conn.poll()
|
||||
if state == POLL_OK:
|
||||
break
|
||||
elif state == POLL_READ:
|
||||
select.select([conn.fileno()], [], [], 0.1)
|
||||
elif state == POLL_WRITE:
|
||||
select.select([], [conn.fileno()], [], 0.1)
|
||||
else:
|
||||
raise conn.OperationalError(f"bad state from poll: {state}")
|
||||
|
||||
stub = self.set_stub_wait_callback(self.conn, wait)
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("""
|
||||
select 1;
|
||||
do $$
|
||||
begin
|
||||
raise notice 'hello';
|
||||
end
|
||||
$$ language plpgsql;
|
||||
select pg_sleep(1);
|
||||
""")
|
||||
|
||||
polls = stub.polls.count(POLL_READ)
|
||||
self.assert_(polls > 8, polls)
|
||||
|
||||
|
||||
class CallbackErrorTestCase(ConnectingTestCase):
|
||||
def setUp(self):
|
||||
self._cb = psycopg2.extensions.get_wait_callback()
|
||||
psycopg2.extensions.set_wait_callback(self.crappy_callback)
|
||||
ConnectingTestCase.setUp(self)
|
||||
self.to_error = None
|
||||
|
||||
def tearDown(self):
|
||||
ConnectingTestCase.tearDown(self)
|
||||
psycopg2.extensions.set_wait_callback(self._cb)
|
||||
|
||||
def crappy_callback(self, conn):
|
||||
"""green callback failing after `self.to_error` time it is called"""
|
||||
while True:
|
||||
if self.to_error is not None:
|
||||
self.to_error -= 1
|
||||
if self.to_error <= 0:
|
||||
raise ZeroDivisionError("I accidentally the connection")
|
||||
try:
|
||||
state = conn.poll()
|
||||
if state == POLL_OK:
|
||||
break
|
||||
elif state == POLL_READ:
|
||||
select.select([conn.fileno()], [], [])
|
||||
elif state == POLL_WRITE:
|
||||
select.select([], [conn.fileno()], [])
|
||||
else:
|
||||
raise conn.OperationalError(f"bad state from poll: {state}")
|
||||
except KeyboardInterrupt:
|
||||
conn.cancel()
|
||||
# the loop will be broken by a server error
|
||||
continue
|
||||
|
||||
def test_errors_on_connection(self):
|
||||
# Test error propagation in the different stages of the connection
|
||||
for i in range(100):
|
||||
self.to_error = i
|
||||
try:
|
||||
self.connect()
|
||||
except ZeroDivisionError:
|
||||
pass
|
||||
else:
|
||||
# We managed to connect
|
||||
return
|
||||
|
||||
self.fail("you should have had a success or an error by now")
|
||||
|
||||
def test_errors_on_query(self):
|
||||
for i in range(100):
|
||||
self.to_error = None
|
||||
cnn = self.connect()
|
||||
cur = cnn.cursor()
|
||||
self.to_error = i
|
||||
try:
|
||||
cur.execute("select 1")
|
||||
cur.fetchone()
|
||||
except ZeroDivisionError:
|
||||
pass
|
||||
else:
|
||||
# The query completed
|
||||
return
|
||||
|
||||
self.fail("you should have had a success or an error by now")
|
||||
|
||||
@skip_if_crdb("named cursor")
|
||||
def test_errors_named_cursor(self):
|
||||
for i in range(100):
|
||||
self.to_error = None
|
||||
cnn = self.connect()
|
||||
cur = cnn.cursor('foo')
|
||||
self.to_error = i
|
||||
try:
|
||||
cur.execute("select 1")
|
||||
cur.fetchone()
|
||||
except ZeroDivisionError:
|
||||
pass
|
||||
else:
|
||||
# The query completed
|
||||
return
|
||||
|
||||
self.fail("you should have had a success or an error by now")
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
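The wait-callback machinery these tests exercise can be summarised in a few lines. A minimal sketch, assuming a reachable database at the hypothetical DSN 'dbname=test':

# Illustrative sketch only; the DSN is an assumption, not part of the commit.
import select

import psycopg2
import psycopg2.extensions
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE


def wait_select(conn):
    """Block with select() until the connection is ready (the same idea as
    psycopg2.extras.wait_select, which the tests install in setUp)."""
    while True:
        state = conn.poll()
        if state == POLL_OK:
            break
        elif state == POLL_READ:
            select.select([conn.fileno()], [], [])
        elif state == POLL_WRITE:
            select.select([], [conn.fileno()], [])
        else:
            raise conn.OperationalError(f"bad state from poll: {state}")


# Once registered, every blocking libpq operation goes through the callback.
psycopg2.extensions.set_wait_callback(wait_select)
conn = psycopg2.connect("dbname=test")  # hypothetical DSN
cur = conn.cursor()
cur.execute("select 1")
print(cur.fetchone()[0])
conn.close()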
120
tests/test_ipaddress.py
Executable file
@ -0,0 +1,120 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# test_ipaddress.py - tests for ipaddress support
|
||||
#
|
||||
# Copyright (C) 2016-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
|
||||
from . import testutils
|
||||
import unittest
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
|
||||
try:
|
||||
import ipaddress as ip
|
||||
except ImportError:
|
||||
# Python 2
|
||||
ip = None
|
||||
|
||||
|
||||
@unittest.skipIf(ip is None, "'ipaddress' module not available")
|
||||
class NetworkingTestCase(testutils.ConnectingTestCase):
|
||||
def test_inet_cast(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.register_ipaddress(cur)
|
||||
|
||||
cur.execute("select null::inet")
|
||||
self.assert_(cur.fetchone()[0] is None)
|
||||
|
||||
cur.execute("select '127.0.0.1/24'::inet")
|
||||
obj = cur.fetchone()[0]
|
||||
self.assert_(isinstance(obj, ip.IPv4Interface), repr(obj))
|
||||
self.assertEquals(obj, ip.ip_interface('127.0.0.1/24'))
|
||||
|
||||
cur.execute("select '::ffff:102:300/128'::inet")
|
||||
obj = cur.fetchone()[0]
|
||||
self.assert_(isinstance(obj, ip.IPv6Interface), repr(obj))
|
||||
self.assertEquals(obj, ip.ip_interface('::ffff:102:300/128'))
|
||||
|
||||
@testutils.skip_before_postgres(8, 2)
|
||||
def test_inet_array_cast(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.register_ipaddress(cur)
|
||||
cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::inet[]")
|
||||
l = cur.fetchone()[0]
|
||||
self.assert_(l[0] is None)
|
||||
self.assertEquals(l[1], ip.ip_interface('127.0.0.1'))
|
||||
self.assertEquals(l[2], ip.ip_interface('::ffff:102:300/128'))
|
||||
self.assert_(isinstance(l[1], ip.IPv4Interface), l)
|
||||
self.assert_(isinstance(l[2], ip.IPv6Interface), l)
|
||||
|
||||
def test_inet_adapt(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.register_ipaddress(cur)
|
||||
|
||||
cur.execute("select %s", [ip.ip_interface('127.0.0.1/24')])
|
||||
self.assertEquals(cur.fetchone()[0], '127.0.0.1/24')
|
||||
|
||||
cur.execute("select %s", [ip.ip_interface('::ffff:102:300/128')])
|
||||
self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128')
|
||||
|
||||
@testutils.skip_if_crdb("cidr")
|
||||
def test_cidr_cast(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.register_ipaddress(cur)
|
||||
|
||||
cur.execute("select null::cidr")
|
||||
self.assert_(cur.fetchone()[0] is None)
|
||||
|
||||
cur.execute("select '127.0.0.0/24'::cidr")
|
||||
obj = cur.fetchone()[0]
|
||||
self.assert_(isinstance(obj, ip.IPv4Network), repr(obj))
|
||||
self.assertEquals(obj, ip.ip_network('127.0.0.0/24'))
|
||||
|
||||
cur.execute("select '::ffff:102:300/128'::cidr")
|
||||
obj = cur.fetchone()[0]
|
||||
self.assert_(isinstance(obj, ip.IPv6Network), repr(obj))
|
||||
self.assertEquals(obj, ip.ip_network('::ffff:102:300/128'))
|
||||
|
||||
@testutils.skip_if_crdb("cidr")
|
||||
@testutils.skip_before_postgres(8, 2)
|
||||
def test_cidr_array_cast(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.register_ipaddress(cur)
|
||||
cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::cidr[]")
|
||||
l = cur.fetchone()[0]
|
||||
self.assert_(l[0] is None)
|
||||
self.assertEquals(l[1], ip.ip_network('127.0.0.1'))
|
||||
self.assertEquals(l[2], ip.ip_network('::ffff:102:300/128'))
|
||||
self.assert_(isinstance(l[1], ip.IPv4Network), l)
|
||||
self.assert_(isinstance(l[2], ip.IPv6Network), l)
|
||||
|
||||
def test_cidr_adapt(self):
|
||||
cur = self.conn.cursor()
|
||||
psycopg2.extras.register_ipaddress(cur)
|
||||
|
||||
cur.execute("select %s", [ip.ip_network('127.0.0.0/24')])
|
||||
self.assertEquals(cur.fetchone()[0], '127.0.0.0/24')
|
||||
|
||||
cur.execute("select %s", [ip.ip_network('::ffff:102:300/128')])
|
||||
self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128')
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
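The adapter exercised above is enabled with one registration call. A minimal sketch, assuming a reachable database at the hypothetical DSN 'dbname=test':

# Illustrative sketch only; the DSN is an assumption, not part of the commit.
import ipaddress

import psycopg2
import psycopg2.extras

conn = psycopg2.connect("dbname=test")  # hypothetical DSN
cur = conn.cursor()

# After registration, inet/cidr values come back as ipaddress objects, and
# ipaddress objects passed as parameters are adapted to inet/cidr literals.
psycopg2.extras.register_ipaddress(cur)

cur.execute("select '127.0.0.1/24'::inet, '127.0.0.0/24'::cidr")
iface, net = cur.fetchone()
print(type(iface).__name__, iface)  # IPv4Interface 127.0.0.1/24
print(type(net).__name__, net)      # IPv4Network 127.0.0.0/24

cur.execute("select %s", [ipaddress.ip_interface("10.0.0.1/8")])
print(cur.fetchone()[0])            # '10.0.0.1/8' as text

conn.close()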
530
tests/test_lobject.py
Executable file
@ -0,0 +1,530 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# test_lobject.py - unit test for large objects support
|
||||
#
|
||||
# Copyright (C) 2008-2019 James Henstridge <james@jamesh.id.au>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# In addition, as a special exception, the copyright holders give
|
||||
# permission to link this program with the OpenSSL library (or with
|
||||
# modified versions of OpenSSL that use the same license as OpenSSL),
|
||||
# and distribute linked combinations including the two.
|
||||
#
|
||||
# You must obey the GNU Lesser General Public License in all respects for
|
||||
# all of the code used other than OpenSSL.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from functools import wraps
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extensions
|
||||
import unittest
|
||||
from .testutils import (decorate_all_tests, skip_if_tpc_disabled,
|
||||
skip_before_postgres, ConnectingTestCase, skip_if_green, skip_if_crdb, slow)
|
||||
|
||||
|
||||
def skip_if_no_lo(f):
|
||||
f = skip_before_postgres(8, 1, "large objects only supported from PG 8.1")(f)
|
||||
f = skip_if_green("libpq doesn't support LO in async mode")(f)
|
||||
f = skip_if_crdb("large objects")(f)
|
||||
return f
|
||||
|
||||
|
||||
class LargeObjectTestCase(ConnectingTestCase):
|
||||
def setUp(self):
|
||||
ConnectingTestCase.setUp(self)
|
||||
self.lo_oid = None
|
||||
self.tmpdir = None
|
||||
|
||||
def tearDown(self):
|
||||
if self.tmpdir:
|
||||
shutil.rmtree(self.tmpdir, ignore_errors=True)
|
||||
|
||||
if self.conn.closed:
|
||||
return
|
||||
|
||||
if self.lo_oid is not None:
|
||||
self.conn.rollback()
|
||||
try:
|
||||
lo = self.conn.lobject(self.lo_oid, "n")
|
||||
except psycopg2.OperationalError:
|
||||
pass
|
||||
else:
|
||||
lo.unlink()
|
||||
|
||||
ConnectingTestCase.tearDown(self)
|
||||
|
||||
|
||||
@skip_if_no_lo
|
||||
class LargeObjectTests(LargeObjectTestCase):
|
||||
def test_create(self):
|
||||
lo = self.conn.lobject()
|
||||
self.assertNotEqual(lo, None)
|
||||
self.assertEqual(lo.mode[0], "w")
|
||||
|
||||
def test_connection_needed(self):
|
||||
self.assertRaises(TypeError,
|
||||
psycopg2.extensions.lobject, [])
|
||||
|
||||
def test_open_non_existent(self):
|
||||
# By creating then removing a large object, we get an Oid that
|
||||
# should be unused.
|
||||
lo = self.conn.lobject()
|
||||
lo.unlink()
|
||||
self.assertRaises(psycopg2.OperationalError, self.conn.lobject, lo.oid)
|
||||
|
||||
def test_open_existing(self):
|
||||
lo = self.conn.lobject()
|
||||
lo2 = self.conn.lobject(lo.oid)
|
||||
self.assertNotEqual(lo2, None)
|
||||
self.assertEqual(lo2.oid, lo.oid)
|
||||
self.assertEqual(lo2.mode[0], "r")
|
||||
|
||||
def test_open_for_write(self):
|
||||
lo = self.conn.lobject()
|
||||
lo2 = self.conn.lobject(lo.oid, "w")
|
||||
self.assertEqual(lo2.mode[0], "w")
|
||||
lo2.write(b"some data")
|
||||
|
||||
def test_open_mode_n(self):
|
||||
# Opening an object in mode "n" gives us a closed lobject.
|
||||
lo = self.conn.lobject()
|
||||
lo.close()
|
||||
|
||||
lo2 = self.conn.lobject(lo.oid, "n")
|
||||
self.assertEqual(lo2.oid, lo.oid)
|
||||
self.assertEqual(lo2.closed, True)
|
||||
|
||||
def test_mode_defaults(self):
|
||||
lo = self.conn.lobject()
|
||||
lo2 = self.conn.lobject(mode=None)
|
||||
lo3 = self.conn.lobject(mode="")
|
||||
self.assertEqual(lo.mode, lo2.mode)
|
||||
self.assertEqual(lo.mode, lo3.mode)
|
||||
|
||||
def test_close_connection_gone(self):
|
||||
lo = self.conn.lobject()
|
||||
self.conn.close()
|
||||
lo.close()
|
||||
|
||||
def test_create_with_oid(self):
|
||||
# Create and delete a large object to get an unused Oid.
|
||||
lo = self.conn.lobject()
|
||||
oid = lo.oid
|
||||
lo.unlink()
|
||||
|
||||
lo = self.conn.lobject(0, "w", oid)
|
||||
self.assertEqual(lo.oid, oid)
|
||||
|
||||
def test_create_with_existing_oid(self):
|
||||
lo = self.conn.lobject()
|
||||
lo.close()
|
||||
|
||||
self.assertRaises(psycopg2.OperationalError,
|
||||
self.conn.lobject, 0, "w", lo.oid)
|
||||
self.assert_(not self.conn.closed)
|
||||
|
||||
def test_import(self):
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
filename = os.path.join(self.tmpdir, "data.txt")
|
||||
fp = open(filename, "wb")
|
||||
fp.write(b"some data")
|
||||
fp.close()
|
||||
|
||||
lo = self.conn.lobject(0, "r", 0, filename)
|
||||
self.assertEqual(lo.read(), "some data")
|
||||
|
||||
def test_close(self):
|
||||
lo = self.conn.lobject()
|
||||
self.assertEqual(lo.closed, False)
|
||||
lo.close()
|
||||
self.assertEqual(lo.closed, True)
|
||||
|
||||
def test_write(self):
|
||||
lo = self.conn.lobject()
|
||||
self.assertEqual(lo.write(b"some data"), len("some data"))
|
||||
|
||||
def test_write_large(self):
|
||||
lo = self.conn.lobject()
|
||||
data = "data" * 1000000
|
||||
self.assertEqual(lo.write(data), len(data))
|
||||
|
||||
def test_read(self):
|
||||
lo = self.conn.lobject()
|
||||
lo.write(b"some data")
|
||||
lo.close()
|
||||
|
||||
lo = self.conn.lobject(lo.oid)
|
||||
x = lo.read(4)
|
||||
self.assertEqual(type(x), type(''))
|
||||
self.assertEqual(x, "some")
|
||||
self.assertEqual(lo.read(), " data")
|
||||
|
||||
def test_read_binary(self):
|
||||
lo = self.conn.lobject()
|
||||
lo.write(b"some data")
|
||||
lo.close()
|
||||
|
||||
lo = self.conn.lobject(lo.oid, "rb")
|
||||
x = lo.read(4)
|
||||
self.assertEqual(type(x), type(b''))
|
||||
self.assertEqual(x, b"some")
|
||||
self.assertEqual(lo.read(), b" data")
|
||||
|
||||
def test_read_text(self):
|
||||
lo = self.conn.lobject()
|
||||
snowman = "\u2603"
|
||||
lo.write("some data " + snowman)
|
||||
lo.close()
|
||||
|
||||
lo = self.conn.lobject(lo.oid, "rt")
|
||||
x = lo.read(4)
|
||||
self.assertEqual(type(x), type(''))
|
||||
self.assertEqual(x, "some")
|
||||
self.assertEqual(lo.read(), " data " + snowman)
|
||||
|
||||
@slow
|
||||
def test_read_large(self):
|
||||
lo = self.conn.lobject()
|
||||
data = "data" * 1000000
|
||||
lo.write("some" + data)
|
||||
lo.close()
|
||||
|
||||
lo = self.conn.lobject(lo.oid)
|
||||
self.assertEqual(lo.read(4), "some")
|
||||
data1 = lo.read()
|
||||
# avoid dumping megacraps in the console in case of error
|
||||
self.assert_(data == data1,
|
||||
f"{data[:100]!r}... != {data1[:100]!r}...")
|
||||
|
||||
def test_seek_tell(self):
|
||||
lo = self.conn.lobject()
|
||||
length = lo.write(b"some data")
|
||||
self.assertEqual(lo.tell(), length)
|
||||
lo.close()
|
||||
lo = self.conn.lobject(lo.oid)
|
||||
|
||||
self.assertEqual(lo.seek(5, 0), 5)
|
||||
self.assertEqual(lo.tell(), 5)
|
||||
self.assertEqual(lo.read(), "data")
|
||||
|
||||
# SEEK_CUR: relative current location
|
||||
lo.seek(5)
|
||||
self.assertEqual(lo.seek(2, 1), 7)
|
||||
self.assertEqual(lo.tell(), 7)
|
||||
self.assertEqual(lo.read(), "ta")
|
||||
|
||||
# SEEK_END: relative to end of file
|
||||
self.assertEqual(lo.seek(-2, 2), length - 2)
|
||||
self.assertEqual(lo.read(), "ta")
|
||||
|
||||
def test_unlink(self):
|
||||
lo = self.conn.lobject()
|
||||
lo.unlink()
|
||||
|
||||
# the object doesn't exist now, so we can't reopen it.
|
||||
self.assertRaises(psycopg2.OperationalError, self.conn.lobject, lo.oid)
|
||||
# And the object has been closed.
|
||||
self.assertEquals(lo.closed, True)
|
||||
|
||||
def test_export(self):
|
||||
lo = self.conn.lobject()
|
||||
lo.write(b"some data")
|
||||
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
filename = os.path.join(self.tmpdir, "data.txt")
|
||||
lo.export(filename)
|
||||
self.assertTrue(os.path.exists(filename))
|
||||
f = open(filename, "rb")
|
||||
try:
|
||||
self.assertEqual(f.read(), b"some data")
|
||||
finally:
|
||||
f.close()
|
||||
|
||||
def test_close_twice(self):
|
||||
lo = self.conn.lobject()
|
||||
lo.close()
|
||||
lo.close()
|
||||
|
||||
def test_write_after_close(self):
|
||||
lo = self.conn.lobject()
|
||||
lo.close()
|
||||
self.assertRaises(psycopg2.InterfaceError, lo.write, b"some data")
|
||||
|
||||
def test_read_after_close(self):
|
||||
lo = self.conn.lobject()
|
||||
lo.close()
|
||||
self.assertRaises(psycopg2.InterfaceError, lo.read, 5)
|
||||
|
||||
def test_seek_after_close(self):
|
||||
lo = self.conn.lobject()
|
||||
lo.close()
|
||||
self.assertRaises(psycopg2.InterfaceError, lo.seek, 0)
|
||||
|
||||
def test_tell_after_close(self):
|
||||
lo = self.conn.lobject()
|
||||
lo.close()
|
||||
self.assertRaises(psycopg2.InterfaceError, lo.tell)
|
||||
|
||||
def test_unlink_after_close(self):
|
||||
lo = self.conn.lobject()
|
||||
lo.close()
|
||||
# Unlink works on closed files.
|
||||
lo.unlink()
|
||||
|
||||
def test_export_after_close(self):
|
||||
lo = self.conn.lobject()
|
||||
lo.write(b"some data")
|
||||
lo.close()
|
||||
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
filename = os.path.join(self.tmpdir, "data.txt")
|
||||
lo.export(filename)
|
||||
self.assertTrue(os.path.exists(filename))
|
||||
f = open(filename, "rb")
|
||||
try:
|
||||
self.assertEqual(f.read(), b"some data")
|
||||
finally:
|
||||
f.close()
|
||||
|
||||
def test_close_after_commit(self):
|
||||
lo = self.conn.lobject()
|
||||
self.lo_oid = lo.oid
|
||||
self.conn.commit()
|
||||
|
||||
# Closing outside of the transaction is okay.
|
||||
lo.close()
|
||||
|
||||
def test_write_after_commit(self):
|
||||
lo = self.conn.lobject()
|
||||
self.lo_oid = lo.oid
|
||||
self.conn.commit()
|
||||
|
||||
self.assertRaises(psycopg2.ProgrammingError, lo.write, b"some data")
|
||||
|
||||
def test_read_after_commit(self):
|
||||
lo = self.conn.lobject()
|
||||
self.lo_oid = lo.oid
|
||||
self.conn.commit()
|
||||
|
||||
self.assertRaises(psycopg2.ProgrammingError, lo.read, 5)
|
||||
|
||||
def test_seek_after_commit(self):
|
||||
lo = self.conn.lobject()
|
||||
self.lo_oid = lo.oid
|
||||
self.conn.commit()
|
||||
|
||||
self.assertRaises(psycopg2.ProgrammingError, lo.seek, 0)
|
||||
|
||||
def test_tell_after_commit(self):
|
||||
lo = self.conn.lobject()
|
||||
self.lo_oid = lo.oid
|
||||
self.conn.commit()
|
||||
|
||||
self.assertRaises(psycopg2.ProgrammingError, lo.tell)
|
||||
|
||||
def test_unlink_after_commit(self):
|
||||
lo = self.conn.lobject()
|
||||
self.lo_oid = lo.oid
|
||||
self.conn.commit()
|
||||
|
||||
# Unlink of stale lobject is okay
|
||||
lo.unlink()
|
||||
|
||||
def test_export_after_commit(self):
|
||||
lo = self.conn.lobject()
|
||||
lo.write(b"some data")
|
||||
self.conn.commit()
|
||||
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
filename = os.path.join(self.tmpdir, "data.txt")
|
||||
lo.export(filename)
|
||||
self.assertTrue(os.path.exists(filename))
|
||||
f = open(filename, "rb")
|
||||
try:
|
||||
self.assertEqual(f.read(), b"some data")
|
||||
finally:
|
||||
f.close()
|
||||
|
||||
@skip_if_tpc_disabled
|
||||
def test_read_after_tpc_commit(self):
|
||||
self.conn.tpc_begin('test_lobject')
|
||||
lo = self.conn.lobject()
|
||||
self.lo_oid = lo.oid
|
||||
self.conn.tpc_commit()
|
||||
|
||||
self.assertRaises(psycopg2.ProgrammingError, lo.read, 5)
|
||||
|
||||
@skip_if_tpc_disabled
|
||||
def test_read_after_tpc_prepare(self):
|
||||
self.conn.tpc_begin('test_lobject')
|
||||
lo = self.conn.lobject()
|
||||
self.lo_oid = lo.oid
|
||||
self.conn.tpc_prepare()
|
||||
|
||||
try:
|
||||
self.assertRaises(psycopg2.ProgrammingError, lo.read, 5)
|
||||
finally:
|
||||
self.conn.tpc_commit()
|
||||
|
||||
def test_large_oid(self):
|
||||
# Test we don't overflow with an oid not fitting a signed int
|
||||
try:
|
||||
self.conn.lobject(0xFFFFFFFE)
|
||||
except psycopg2.OperationalError:
|
||||
pass
|
||||
|
||||
def test_factory(self):
|
||||
class lobject_subclass(psycopg2.extensions.lobject):
|
||||
pass
|
||||
|
||||
lo = self.conn.lobject(lobject_factory=lobject_subclass)
|
||||
self.assert_(isinstance(lo, lobject_subclass))
|
||||
|
||||
|
||||
@decorate_all_tests
|
||||
def skip_if_no_truncate(f):
|
||||
@wraps(f)
|
||||
def skip_if_no_truncate_(self):
|
||||
if self.conn.info.server_version < 80300:
|
||||
return self.skipTest(
|
||||
"the server doesn't support large object truncate")
|
||||
|
||||
if not hasattr(psycopg2.extensions.lobject, 'truncate'):
|
||||
return self.skipTest(
|
||||
"psycopg2 has been built against a libpq "
|
||||
"without large object truncate support.")
|
||||
|
||||
return f(self)
|
||||
|
||||
return skip_if_no_truncate_
|
||||
|
||||
|
||||
@skip_if_no_lo
|
||||
@skip_if_no_truncate
|
||||
class LargeObjectTruncateTests(LargeObjectTestCase):
|
||||
def test_truncate(self):
|
||||
lo = self.conn.lobject()
|
||||
lo.write("some data")
|
||||
lo.close()
|
||||
|
||||
lo = self.conn.lobject(lo.oid, "w")
|
||||
lo.truncate(4)
|
||||
|
||||
# seek position unchanged
|
||||
self.assertEqual(lo.tell(), 0)
|
||||
# data truncated
|
||||
self.assertEqual(lo.read(), "some")
|
||||
|
||||
lo.truncate(6)
|
||||
lo.seek(0)
|
||||
# large object extended with zeroes
|
||||
self.assertEqual(lo.read(), "some\x00\x00")
|
||||
|
||||
lo.truncate()
|
||||
lo.seek(0)
|
||||
# large object empty
|
||||
self.assertEqual(lo.read(), "")
|
||||
|
||||
def test_truncate_after_close(self):
|
||||
lo = self.conn.lobject()
|
||||
lo.close()
|
||||
self.assertRaises(psycopg2.InterfaceError, lo.truncate)
|
||||
|
||||
def test_truncate_after_commit(self):
|
||||
lo = self.conn.lobject()
|
||||
self.lo_oid = lo.oid
|
||||
self.conn.commit()
|
||||
|
||||
self.assertRaises(psycopg2.ProgrammingError, lo.truncate)
|
||||
|
||||
|
||||
def _has_lo64(conn):
|
||||
"""Return (bool, msg) about the lo64 support"""
|
||||
if conn.info.server_version < 90300:
|
||||
return (False, "server version %s doesn't support the lo64 API"
|
||||
% conn.info.server_version)
|
||||
|
||||
if 'lo64' not in psycopg2.__version__:
|
||||
return False, "this psycopg build doesn't support the lo64 API"
|
||||
|
||||
return True, "this server and build support the lo64 API"
|
||||
|
||||
|
||||
@decorate_all_tests
|
||||
def skip_if_no_lo64(f):
|
||||
@wraps(f)
|
||||
def skip_if_no_lo64_(self):
|
||||
lo64, msg = _has_lo64(self.conn)
|
||||
if not lo64:
|
||||
return self.skipTest(msg)
|
||||
else:
|
||||
return f(self)
|
||||
|
||||
return skip_if_no_lo64_
|
||||
|
||||
|
||||
@skip_if_no_lo
|
||||
@skip_if_no_truncate
|
||||
@skip_if_no_lo64
|
||||
class LargeObject64Tests(LargeObjectTestCase):
|
||||
def test_seek_tell_truncate_greater_than_2gb(self):
|
||||
lo = self.conn.lobject()
|
||||
|
||||
length = (1 << 31) + (1 << 30) # 2gb + 1gb = 3gb
|
||||
lo.truncate(length)
|
||||
|
||||
self.assertEqual(lo.seek(length, 0), length)
|
||||
self.assertEqual(lo.tell(), length)
|
||||
|
||||
|
||||
@decorate_all_tests
|
||||
def skip_if_lo64(f):
|
||||
@wraps(f)
|
||||
def skip_if_lo64_(self):
|
||||
lo64, msg = _has_lo64(self.conn)
|
||||
if lo64:
|
||||
return self.skipTest(msg)
|
||||
else:
|
||||
return f(self)
|
||||
|
||||
return skip_if_lo64_
|
||||
|
||||
|
||||
@skip_if_no_lo
|
||||
@skip_if_no_truncate
|
||||
@skip_if_lo64
|
||||
class LargeObjectNot64Tests(LargeObjectTestCase):
|
||||
def test_seek_larger_than_2gb(self):
|
||||
lo = self.conn.lobject()
|
||||
offset = 1 << 32 # 4gb
|
||||
self.assertRaises(
|
||||
(OverflowError, psycopg2.InterfaceError, psycopg2.NotSupportedError),
|
||||
lo.seek, offset, 0)
|
||||
|
||||
def test_truncate_larger_than_2gb(self):
|
||||
lo = self.conn.lobject()
|
||||
length = 1 << 32 # 4gb
|
||||
self.assertRaises(
|
||||
(OverflowError, psycopg2.InterfaceError, psycopg2.NotSupportedError),
|
||||
lo.truncate, length)
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
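Typical use of the large-object API covered by these tests fits in a short round trip. A minimal sketch, assuming a reachable database at the hypothetical DSN 'dbname=test':

# Illustrative sketch only; the DSN is an assumption, not part of the commit.
import psycopg2

conn = psycopg2.connect("dbname=test")  # hypothetical DSN

# Large objects live inside a transaction: create and write, then reopen to read.
lo = conn.lobject()           # new object, opened for writing
oid = lo.oid
lo.write(b"some data")
lo.close()

lo = conn.lobject(oid, "rb")  # reopen read-only in binary mode
print(lo.read())              # b'some data'
lo.seek(5, 0)
print(lo.read())              # b'data'
lo.close()

conn.lobject(oid, "n").unlink()  # mode "n" returns a closed lobject; unlink still works
conn.commit()
conn.close()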
367
tests/test_module.py
Executable file
@ -0,0 +1,367 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# test_module.py - unit test for the module interface
|
||||
#
|
||||
# Copyright (C) 2011-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# In addition, as a special exception, the copyright holders give
|
||||
# permission to link this program with the OpenSSL library (or with
|
||||
# modified versions of OpenSSL that use the same license as OpenSSL),
|
||||
# and distribute linked combinations including the two.
|
||||
#
|
||||
# You must obey the GNU Lesser General Public License in all respects for
|
||||
# all of the code used other than OpenSSL.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import pickle
|
||||
from subprocess import Popen
|
||||
from weakref import ref
|
||||
|
||||
import unittest
|
||||
from .testutils import (skip_before_postgres,
|
||||
ConnectingTestCase, skip_copy_if_green, skip_if_crdb, slow, StringIO)
|
||||
|
||||
import psycopg2
|
||||
|
||||
|
||||
class ConnectTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.args = None
|
||||
|
||||
def connect_stub(dsn, connection_factory=None, async_=False):
|
||||
self.args = (dsn, connection_factory, async_)
|
||||
|
||||
self._connect_orig = psycopg2._connect
|
||||
psycopg2._connect = connect_stub
|
||||
|
||||
def tearDown(self):
|
||||
psycopg2._connect = self._connect_orig
|
||||
|
||||
def test_there_might_be_nothing(self):
|
||||
psycopg2.connect()
|
||||
self.assertEqual(self.args[0], '')
|
||||
self.assertEqual(self.args[1], None)
|
||||
self.assertEqual(self.args[2], False)
|
||||
|
||||
psycopg2.connect(
|
||||
connection_factory=lambda dsn, async_=False: None)
|
||||
self.assertEqual(self.args[0], '')
|
||||
self.assertNotEqual(self.args[1], None)
|
||||
self.assertEqual(self.args[2], False)
|
||||
|
||||
psycopg2.connect(async_=True)
|
||||
self.assertEqual(self.args[0], '')
|
||||
self.assertEqual(self.args[1], None)
|
||||
self.assertEqual(self.args[2], True)
|
||||
|
||||
def test_no_keywords(self):
|
||||
psycopg2.connect('')
|
||||
self.assertEqual(self.args[0], '')
|
||||
self.assertEqual(self.args[1], None)
|
||||
self.assertEqual(self.args[2], False)
|
||||
|
||||
def test_dsn(self):
|
||||
psycopg2.connect('dbname=blah host=y')
|
||||
self.assertEqual(self.args[0], 'dbname=blah host=y')
|
||||
self.assertEqual(self.args[1], None)
|
||||
self.assertEqual(self.args[2], False)
|
||||
|
||||
def test_supported_keywords(self):
|
||||
psycopg2.connect(database='foo')
|
||||
self.assertEqual(self.args[0], 'dbname=foo')
|
||||
psycopg2.connect(user='postgres')
|
||||
self.assertEqual(self.args[0], 'user=postgres')
|
||||
psycopg2.connect(password='secret')
|
||||
self.assertEqual(self.args[0], 'password=secret')
|
||||
psycopg2.connect(port=5432)
|
||||
self.assertEqual(self.args[0], 'port=5432')
|
||||
psycopg2.connect(sslmode='require')
|
||||
self.assertEqual(self.args[0], 'sslmode=require')
|
||||
|
||||
psycopg2.connect(database='foo',
|
||||
user='postgres', password='secret', port=5432)
|
||||
self.assert_('dbname=foo' in self.args[0])
|
||||
self.assert_('user=postgres' in self.args[0])
|
||||
self.assert_('password=secret' in self.args[0])
|
||||
self.assert_('port=5432' in self.args[0])
|
||||
self.assertEqual(len(self.args[0].split()), 4)
|
||||
|
||||
def test_generic_keywords(self):
|
||||
psycopg2.connect(options='stuff')
|
||||
self.assertEqual(self.args[0], 'options=stuff')
|
||||
|
||||
def test_factory(self):
|
||||
def f(dsn, async_=False):
|
||||
pass
|
||||
|
||||
psycopg2.connect(database='foo', host='baz', connection_factory=f)
|
||||
self.assertDsnEqual(self.args[0], 'dbname=foo host=baz')
|
||||
self.assertEqual(self.args[1], f)
|
||||
self.assertEqual(self.args[2], False)
|
||||
|
||||
psycopg2.connect("dbname=foo host=baz", connection_factory=f)
|
||||
self.assertDsnEqual(self.args[0], 'dbname=foo host=baz')
|
||||
self.assertEqual(self.args[1], f)
|
||||
self.assertEqual(self.args[2], False)
|
||||
|
||||
def test_async(self):
|
||||
psycopg2.connect(database='foo', host='baz', async_=1)
|
||||
self.assertDsnEqual(self.args[0], 'dbname=foo host=baz')
|
||||
self.assertEqual(self.args[1], None)
|
||||
self.assert_(self.args[2])
|
||||
|
||||
psycopg2.connect("dbname=foo host=baz", async_=True)
|
||||
self.assertDsnEqual(self.args[0], 'dbname=foo host=baz')
|
||||
self.assertEqual(self.args[1], None)
|
||||
self.assert_(self.args[2])
|
||||
|
||||
def test_int_port_param(self):
|
||||
psycopg2.connect(database='sony', port=6543)
|
||||
dsn = f" {self.args[0]} "
|
||||
self.assert_(" dbname=sony " in dsn, dsn)
|
||||
self.assert_(" port=6543 " in dsn, dsn)
|
||||
|
||||
def test_empty_param(self):
|
||||
psycopg2.connect(database='sony', password='')
|
||||
self.assertDsnEqual(self.args[0], "dbname=sony password=''")
|
||||
|
||||
def test_escape(self):
|
||||
psycopg2.connect(database='hello world')
|
||||
self.assertEqual(self.args[0], "dbname='hello world'")
|
||||
|
||||
psycopg2.connect(database=r'back\slash')
|
||||
self.assertEqual(self.args[0], r"dbname=back\\slash")
|
||||
|
||||
psycopg2.connect(database="quo'te")
|
||||
self.assertEqual(self.args[0], r"dbname=quo\'te")
|
||||
|
||||
psycopg2.connect(database="with\ttab")
|
||||
self.assertEqual(self.args[0], "dbname='with\ttab'")
|
||||
|
||||
psycopg2.connect(database=r"\every thing'")
|
||||
self.assertEqual(self.args[0], r"dbname='\\every thing\''")
|
||||
|
||||
def test_params_merging(self):
|
||||
psycopg2.connect('dbname=foo', database='bar')
|
||||
self.assertEqual(self.args[0], 'dbname=bar')
|
||||
|
||||
psycopg2.connect('dbname=foo', user='postgres')
|
||||
self.assertDsnEqual(self.args[0], 'dbname=foo user=postgres')
|
||||
|
||||
|
||||
class ExceptionsTestCase(ConnectingTestCase):
|
||||
def test_attributes(self):
|
||||
cur = self.conn.cursor()
|
||||
try:
|
||||
cur.execute("select * from nonexist")
|
||||
except psycopg2.Error as exc:
|
||||
e = exc
|
||||
|
||||
self.assertEqual(e.pgcode, '42P01')
|
||||
self.assert_(e.pgerror)
|
||||
self.assert_(e.cursor is cur)
|
||||
|
||||
def test_diagnostics_attributes(self):
|
||||
cur = self.conn.cursor()
|
||||
try:
|
||||
cur.execute("select * from nonexist")
|
||||
except psycopg2.Error as exc:
|
||||
e = exc
|
||||
|
||||
diag = e.diag
|
||||
self.assert_(isinstance(diag, psycopg2.extensions.Diagnostics))
|
||||
for attr in [
|
||||
'column_name', 'constraint_name', 'context', 'datatype_name',
|
||||
'internal_position', 'internal_query', 'message_detail',
|
||||
'message_hint', 'message_primary', 'schema_name', 'severity',
|
||||
'severity_nonlocalized', 'source_file', 'source_function',
|
||||
'source_line', 'sqlstate', 'statement_position', 'table_name', ]:
|
||||
v = getattr(diag, attr)
|
||||
if v is not None:
|
||||
self.assert_(isinstance(v, str))
|
||||
|
||||
def test_diagnostics_values(self):
|
||||
cur = self.conn.cursor()
|
||||
try:
|
||||
cur.execute("select * from nonexist")
|
||||
except psycopg2.Error as exc:
|
||||
e = exc
|
||||
|
||||
self.assertEqual(e.diag.sqlstate, '42P01')
|
||||
self.assertEqual(e.diag.severity, 'ERROR')
|
||||
|
||||
def test_diagnostics_life(self):
|
||||
def tmp():
|
||||
cur = self.conn.cursor()
|
||||
try:
|
||||
cur.execute("select * from nonexist")
|
||||
except psycopg2.Error as exc:
|
||||
return cur, exc
|
||||
|
||||
cur, e = tmp()
|
||||
diag = e.diag
|
||||
w = ref(cur)
|
||||
|
||||
del e, cur
|
||||
gc.collect()
|
||||
assert(w() is not None)
|
||||
|
||||
self.assertEqual(diag.sqlstate, '42P01')
|
||||
|
||||
del diag
|
||||
gc.collect()
|
||||
gc.collect()
|
||||
assert(w() is None)
|
||||
|
||||
@skip_if_crdb("copy")
|
||||
@skip_copy_if_green
|
||||
def test_diagnostics_copy(self):
|
||||
f = StringIO()
|
||||
cur = self.conn.cursor()
|
||||
try:
|
||||
cur.copy_to(f, 'nonexist')
|
||||
except psycopg2.Error as exc:
|
||||
diag = exc.diag
|
||||
|
||||
self.assertEqual(diag.sqlstate, '42P01')
|
||||
|
||||
def test_diagnostics_independent(self):
|
||||
cur = self.conn.cursor()
|
||||
try:
|
||||
cur.execute("l'acqua e' poca e 'a papera nun galleggia")
|
||||
except Exception as exc:
|
||||
diag1 = exc.diag
|
||||
|
||||
self.conn.rollback()
|
||||
|
||||
try:
|
||||
cur.execute("select level from water where ducks > 1")
|
||||
except psycopg2.Error as exc:
|
||||
diag2 = exc.diag
|
||||
|
||||
self.assertEqual(diag1.sqlstate, '42601')
|
||||
self.assertEqual(diag2.sqlstate, '42P01')
|
||||
|
||||
@skip_if_crdb("deferrable")
|
||||
def test_diagnostics_from_commit(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("""
|
||||
create temp table test_deferred (
|
||||
data int primary key,
|
||||
ref int references test_deferred (data)
|
||||
deferrable initially deferred)
|
||||
""")
|
||||
cur.execute("insert into test_deferred values (1,2)")
|
||||
try:
|
||||
self.conn.commit()
|
||||
except psycopg2.Error as exc:
|
||||
e = exc
|
||||
self.assertEqual(e.diag.sqlstate, '23503')
|
||||
|
||||
@skip_if_crdb("diagnostic")
|
||||
@skip_before_postgres(9, 3)
|
||||
def test_9_3_diagnostics(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("""
|
||||
create temp table test_exc (
|
||||
data int constraint chk_eq1 check (data = 1)
|
||||
)""")
|
||||
try:
|
||||
cur.execute("insert into test_exc values(2)")
|
||||
except psycopg2.Error as exc:
|
||||
e = exc
|
||||
self.assertEqual(e.pgcode, '23514')
|
||||
self.assertEqual(e.diag.schema_name[:7], "pg_temp")
|
||||
self.assertEqual(e.diag.table_name, "test_exc")
|
||||
self.assertEqual(e.diag.column_name, None)
|
||||
self.assertEqual(e.diag.constraint_name, "chk_eq1")
|
||||
self.assertEqual(e.diag.datatype_name, None)
|
||||
|
||||
@skip_if_crdb("diagnostic")
|
||||
@skip_before_postgres(9, 6)
|
||||
def test_9_6_diagnostics(self):
|
||||
cur = self.conn.cursor()
|
||||
try:
|
||||
cur.execute("select 1 from nosuchtable")
|
||||
except psycopg2.Error as exc:
|
||||
e = exc
|
||||
self.assertEqual(e.diag.severity_nonlocalized, 'ERROR')
|
||||
|
||||
def test_pickle(self):
|
||||
cur = self.conn.cursor()
|
||||
try:
|
||||
cur.execute("select * from nonexist")
|
||||
except psycopg2.Error as exc:
|
||||
e = exc
|
||||
|
||||
e1 = pickle.loads(pickle.dumps(e))
|
||||
|
||||
self.assertEqual(e.pgerror, e1.pgerror)
|
||||
self.assertEqual(e.pgcode, e1.pgcode)
|
||||
self.assert_(e1.cursor is None)
|
||||
|
||||
@skip_if_crdb("connect any db")
|
||||
def test_pickle_connection_error(self):
|
||||
# segfaults on psycopg 2.5.1 - see ticket #170
|
||||
try:
|
||||
psycopg2.connect('dbname=nosuchdatabasemate')
|
||||
except psycopg2.Error as exc:
|
||||
e = exc
|
||||
|
||||
e1 = pickle.loads(pickle.dumps(e))
|
||||
|
||||
self.assertEqual(e.pgerror, e1.pgerror)
|
||||
self.assertEqual(e.pgcode, e1.pgcode)
|
||||
self.assert_(e1.cursor is None)
|
||||
|
||||
|
||||
class TestExtensionModule(unittest.TestCase):
|
||||
@slow
|
||||
def test_import_internal(self):
|
||||
# check that the internal package can be imported "naked"
|
||||
# we may break this property if there is a compelling reason to do so,
|
||||
# however having it allows for some import juggling such as the one
|
||||
# required in ticket #201.
|
||||
pkgdir = os.path.dirname(psycopg2.__file__)
|
||||
pardir = os.path.dirname(pkgdir)
|
||||
self.assert_(pardir in sys.path)
|
||||
script = f"""
|
||||
import sys
|
||||
sys.path.remove({pardir!r})
|
||||
sys.path.insert(0, {pkgdir!r})
|
||||
import _psycopg
|
||||
"""
|
||||
|
||||
proc = Popen([sys.executable, '-c', script])
|
||||
proc.communicate()
|
||||
self.assertEqual(0, proc.returncode)
|
||||
|
||||
|
||||
class TestVersionDiscovery(unittest.TestCase):
|
||||
def test_libpq_version(self):
|
||||
self.assertTrue(type(psycopg2.__libpq_version__) is int)
|
||||
try:
|
||||
self.assertTrue(type(psycopg2.extensions.libpq_version()) is int)
|
||||
except psycopg2.NotSupportedError:
|
||||
self.assertTrue(psycopg2.__libpq_version__ < 90100)
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
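The keyword-to-DSN merging asserted in ConnectTestCase can be observed directly with make_dsn, without connecting to any server. A minimal sketch:

# Illustrative sketch only; make_dsn needs no server connection.
from psycopg2.extensions import make_dsn

# Keywords are rendered as libpq key=value pairs; 'database' maps to dbname.
print(make_dsn(database="foo", user="postgres", port=5432))
# e.g. 'user=postgres port=5432 dbname=foo' (order is not guaranteed)

# Values needing quoting are escaped the same way the tests expect.
print(make_dsn(database="hello world"))        # dbname='hello world'

# An explicit dsn and keywords are merged, with keywords taking precedence.
print(make_dsn("dbname=foo", database="bar"))  # dbname=bar
print(make_dsn("dbname=foo", user="postgres")) # dbname=foo user=postgres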
231
tests/test_notify.py
Executable file
@ -0,0 +1,231 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# test_notify.py - unit test for async notifications
|
||||
#
|
||||
# Copyright (C) 2010-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# In addition, as a special exception, the copyright holders give
|
||||
# permission to link this program with the OpenSSL library (or with
|
||||
# modified versions of OpenSSL that use the same license as OpenSSL),
|
||||
# and distribute linked combinations including the two.
|
||||
#
|
||||
# You must obey the GNU Lesser General Public License in all respects for
|
||||
# all of the code used other than OpenSSL.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
import unittest
|
||||
from collections import deque
|
||||
|
||||
import psycopg2
|
||||
from psycopg2 import extensions
|
||||
from psycopg2.extensions import Notify
|
||||
from .testutils import ConnectingTestCase, skip_if_crdb, slow
|
||||
from .testconfig import dsn
|
||||
|
||||
import sys
|
||||
import time
|
||||
import select
|
||||
from subprocess import Popen, PIPE
|
||||
|
||||
|
||||
@skip_if_crdb("notify")
|
||||
class NotifiesTests(ConnectingTestCase):
|
||||
|
||||
def autocommit(self, conn):
|
||||
"""Set a connection in autocommit mode."""
|
||||
conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT)
|
||||
|
||||
def listen(self, name):
|
||||
"""Start listening for a name on self.conn."""
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("LISTEN " + name)
|
||||
curs.close()
|
||||
|
||||
def notify(self, name, sec=0, payload=None):
|
||||
"""Send a notification to the database, eventually after some time."""
|
||||
if payload is None:
|
||||
payload = ''
|
||||
else:
|
||||
payload = f", {payload!r}"
|
||||
|
||||
script = ("""\
|
||||
import time
|
||||
time.sleep({sec})
|
||||
import {module} as psycopg2
|
||||
import {module}.extensions as ext
|
||||
conn = psycopg2.connect({dsn!r})
|
||||
conn.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT)
|
||||
print(conn.info.backend_pid)
|
||||
curs = conn.cursor()
|
||||
curs.execute("NOTIFY " {name!r} {payload!r})
|
||||
curs.close()
|
||||
conn.close()
|
||||
""".format(
|
||||
module=psycopg2.__name__,
|
||||
dsn=dsn, sec=sec, name=name, payload=payload))
|
||||
|
||||
return Popen([sys.executable, '-c', script], stdout=PIPE)
|
||||
|
||||
@slow
|
||||
def test_notifies_received_on_poll(self):
|
||||
self.autocommit(self.conn)
|
||||
self.listen('foo')
|
||||
|
||||
proc = self.notify('foo', 1)
|
||||
|
||||
t0 = time.time()
|
||||
select.select([self.conn], [], [], 5)
|
||||
t1 = time.time()
|
||||
self.assert_(0.99 < t1 - t0 < 4, t1 - t0)
|
||||
|
||||
pid = int(proc.communicate()[0])
|
||||
self.assertEqual(0, len(self.conn.notifies))
|
||||
self.assertEqual(extensions.POLL_OK, self.conn.poll())
|
||||
self.assertEqual(1, len(self.conn.notifies))
|
||||
self.assertEqual(pid, self.conn.notifies[0][0])
|
||||
self.assertEqual('foo', self.conn.notifies[0][1])
|
||||
|
||||
@slow
|
||||
def test_many_notifies(self):
|
||||
self.autocommit(self.conn)
|
||||
for name in ['foo', 'bar', 'baz']:
|
||||
self.listen(name)
|
||||
|
||||
pids = {}
|
||||
for name in ['foo', 'bar', 'baz', 'qux']:
|
||||
pids[name] = int(self.notify(name).communicate()[0])
|
||||
|
||||
self.assertEqual(0, len(self.conn.notifies))
|
||||
for i in range(10):
|
||||
self.assertEqual(extensions.POLL_OK, self.conn.poll())
|
||||
self.assertEqual(3, len(self.conn.notifies))
|
||||
|
||||
names = dict.fromkeys(['foo', 'bar', 'baz'])
|
||||
for (pid, name) in self.conn.notifies:
|
||||
self.assertEqual(pids[name], pid)
|
||||
names.pop(name) # raise if name found twice
|
||||
|
||||
@slow
|
||||
def test_notifies_received_on_execute(self):
|
||||
self.autocommit(self.conn)
|
||||
self.listen('foo')
|
||||
pid = int(self.notify('foo').communicate()[0])
|
||||
self.assertEqual(0, len(self.conn.notifies))
|
||||
self.conn.cursor().execute('select 1;')
|
||||
self.assertEqual(1, len(self.conn.notifies))
|
||||
self.assertEqual(pid, self.conn.notifies[0][0])
|
||||
self.assertEqual('foo', self.conn.notifies[0][1])
|
||||
|
||||
@slow
|
||||
def test_notify_object(self):
|
||||
self.autocommit(self.conn)
|
||||
self.listen('foo')
|
||||
self.notify('foo').communicate()
|
||||
time.sleep(0.5)
|
||||
self.conn.poll()
|
||||
notify = self.conn.notifies[0]
|
||||
self.assert_(isinstance(notify, psycopg2.extensions.Notify))
|
||||
|
||||
@slow
|
||||
def test_notify_attributes(self):
|
||||
self.autocommit(self.conn)
|
||||
self.listen('foo')
|
||||
pid = int(self.notify('foo').communicate()[0])
|
||||
time.sleep(0.5)
|
||||
self.conn.poll()
|
||||
self.assertEqual(1, len(self.conn.notifies))
|
||||
notify = self.conn.notifies[0]
|
||||
self.assertEqual(pid, notify.pid)
|
||||
self.assertEqual('foo', notify.channel)
|
||||
self.assertEqual('', notify.payload)
|
||||
|
||||
@slow
|
||||
def test_notify_payload(self):
|
||||
if self.conn.info.server_version < 90000:
|
||||
return self.skipTest("server version %s doesn't support notify payload"
|
||||
% self.conn.info.server_version)
|
||||
self.autocommit(self.conn)
|
||||
self.listen('foo')
|
||||
pid = int(self.notify('foo', payload="Hello, world!").communicate()[0])
|
||||
time.sleep(0.5)
|
||||
self.conn.poll()
|
||||
self.assertEqual(1, len(self.conn.notifies))
|
||||
notify = self.conn.notifies[0]
|
||||
self.assertEqual(pid, notify.pid)
|
||||
self.assertEqual('foo', notify.channel)
|
||||
self.assertEqual('Hello, world!', notify.payload)
|
||||
|
||||
@slow
|
||||
def test_notify_deque(self):
|
||||
self.autocommit(self.conn)
|
||||
self.conn.notifies = deque()
|
||||
self.listen('foo')
|
||||
self.notify('foo').communicate()
|
||||
time.sleep(0.5)
|
||||
self.conn.poll()
|
||||
notify = self.conn.notifies.popleft()
|
||||
self.assert_(isinstance(notify, psycopg2.extensions.Notify))
|
||||
self.assertEqual(len(self.conn.notifies), 0)
|
||||
|
||||
@slow
|
||||
def test_notify_noappend(self):
|
||||
self.autocommit(self.conn)
|
||||
self.conn.notifies = None
|
||||
self.listen('foo')
|
||||
self.notify('foo').communicate()
|
||||
time.sleep(0.5)
|
||||
self.conn.poll()
|
||||
self.assertEqual(self.conn.notifies, None)
|
||||
|
||||
def test_notify_init(self):
|
||||
n = psycopg2.extensions.Notify(10, 'foo')
|
||||
self.assertEqual(10, n.pid)
|
||||
self.assertEqual('foo', n.channel)
|
||||
self.assertEqual('', n.payload)
|
||||
(pid, channel) = n
|
||||
self.assertEqual((pid, channel), (10, 'foo'))
|
||||
|
||||
n = psycopg2.extensions.Notify(42, 'bar', 'baz')
|
||||
self.assertEqual(42, n.pid)
|
||||
self.assertEqual('bar', n.channel)
|
||||
self.assertEqual('baz', n.payload)
|
||||
(pid, channel) = n
|
||||
self.assertEqual((pid, channel), (42, 'bar'))
|
||||
|
||||
def test_compare(self):
|
||||
data = [(10, 'foo'), (20, 'foo'), (10, 'foo', 'bar'), (10, 'foo', 'baz')]
|
||||
for d1 in data:
|
||||
for d2 in data:
|
||||
n1 = psycopg2.extensions.Notify(*d1)
|
||||
n2 = psycopg2.extensions.Notify(*d2)
|
||||
self.assertEqual((n1 == n2), (d1 == d2))
|
||||
self.assertEqual((n1 != n2), (d1 != d2))
|
||||
|
||||
def test_compare_tuple(self):
|
||||
self.assertEqual((10, 'foo'), Notify(10, 'foo'))
|
||||
self.assertEqual((10, 'foo'), Notify(10, 'foo', 'bar'))
|
||||
self.assertNotEqual((10, 'foo'), Notify(20, 'foo'))
|
||||
self.assertNotEqual((10, 'foo'), Notify(10, 'bar'))
|
||||
|
||||
def test_hash(self):
|
||||
self.assertEqual(hash((10, 'foo')), hash(Notify(10, 'foo')))
|
||||
self.assertNotEqual(hash(Notify(10, 'foo', 'bar')),
|
||||
hash(Notify(10, 'foo')))
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
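Stripped of the helper subprocesses, the notification flow tested above is just LISTEN on one connection and NOTIFY on another. A minimal sketch, assuming a reachable database at the hypothetical DSN 'dbname=test':

# Illustrative sketch only; the DSN is an assumption, not part of the commit.
import select

import psycopg2
import psycopg2.extensions

listener = psycopg2.connect("dbname=test")  # hypothetical DSN
listener.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
listener.cursor().execute("LISTEN foo")

sender = psycopg2.connect("dbname=test")
sender.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
sender.cursor().execute("NOTIFY foo, 'Hello, world!'")

# Wait until the listener socket is readable, then poll() to collect notifies.
select.select([listener], [], [], 5)
listener.poll()
for notify in listener.notifies:
    print(notify.pid, notify.channel, notify.payload)

listener.close()
sender.close()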
86
tests/test_psycopg2_dbapi20.py
Executable file
@ -0,0 +1,86 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# test_psycopg2_dbapi20.py - DB API conformance test for psycopg2
|
||||
#
|
||||
# Copyright (C) 2006-2019 Federico Di Gregorio <fog@debian.org>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# In addition, as a special exception, the copyright holders give
|
||||
# permission to link this program with the OpenSSL library (or with
|
||||
# modified versions of OpenSSL that use the same license as OpenSSL),
|
||||
# and distribute linked combinations including the two.
|
||||
#
|
||||
# You must obey the GNU Lesser General Public License in all respects for
|
||||
# all of the code used other than OpenSSL.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
from . import dbapi20
|
||||
from . import dbapi20_tpc
|
||||
from .testutils import skip_if_tpc_disabled
|
||||
import unittest
|
||||
import psycopg2
|
||||
|
||||
from .testconfig import dsn
|
||||
|
||||
|
||||
class Psycopg2Tests(dbapi20.DatabaseAPI20Test):
|
||||
driver = psycopg2
|
||||
connect_args = ()
|
||||
connect_kw_args = {'dsn': dsn}
|
||||
|
||||
lower_func = 'lower' # For stored procedure test
|
||||
|
||||
def test_callproc(self):
|
||||
# Until DBAPI 2.0 compliance, callproc should return None or it's just
|
||||
# misleading. Therefore, we will skip the return value test for
|
||||
# callproc and only perform the fetch test.
|
||||
#
|
||||
# For what it's worth, the DBAPI2.0 test_callproc doesn't actually
|
||||
# test for DBAPI2.0 compliance! It doesn't check for modified OUT and
|
||||
# IN/OUT parameters in the return values!
|
||||
con = self._connect()
|
||||
try:
|
||||
cur = con.cursor()
|
||||
if self.lower_func and hasattr(cur, 'callproc'):
|
||||
cur.callproc(self.lower_func, ('FOO',))
|
||||
r = cur.fetchall()
|
||||
self.assertEqual(len(r), 1, 'callproc produced no result set')
|
||||
self.assertEqual(len(r[0]), 1,
|
||||
'callproc produced invalid result set')
|
||||
self.assertEqual(r[0][0], 'foo',
|
||||
'callproc produced invalid results')
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
def test_setoutputsize(self):
|
||||
# psycopg2's setoutputsize() is a no-op
|
||||
pass
|
||||
|
||||
def test_nextset(self):
|
||||
# psycopg2 does not implement nextset()
|
||||
pass
|
||||
|
||||
|
||||
@skip_if_tpc_disabled
|
||||
class Psycopg2TPCTests(dbapi20_tpc.TwoPhaseCommitTests, unittest.TestCase):
|
||||
driver = psycopg2
|
||||
|
||||
def connect(self):
|
||||
return psycopg2.connect(dsn=dsn)
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
229
tests/test_quote.py
Executable file
@ -0,0 +1,229 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# test_quote.py - unit test for strings quoting
|
||||
#
|
||||
# Copyright (C) 2007-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# In addition, as a special exception, the copyright holders give
|
||||
# permission to link this program with the OpenSSL library (or with
|
||||
# modified versions of OpenSSL that use the same license as OpenSSL),
|
||||
# and distribute linked combinations including the two.
|
||||
#
|
||||
# You must obey the GNU Lesser General Public License in all respects for
|
||||
# all of the code used other than OpenSSL.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
from . import testutils
|
||||
import unittest
|
||||
from .testutils import ConnectingTestCase, skip_if_crdb
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extensions
|
||||
from psycopg2.extensions import adapt, quote_ident
|
||||
|
||||
|
||||
class QuotingTestCase(ConnectingTestCase):
|
||||
r"""Checks the correct quoting of strings and binary objects.
|
||||
|
||||
Since ver. 8.1, PostgreSQL is moving towards SQL standard conforming
|
||||
strings, where the backslash (\) is treated as a literal character,
|
||||
not as an escape. To treat the backslash as a C-style escape, PG supports
|
||||
the E'' quotes.
|
||||
|
||||
This test case checks that the E'' quotes are used whenever they are
|
||||
needed. The tests are expected to pass with all PostgreSQL server versions
|
||||
(currently tested with 7.4 <= PG <= 8.3beta) and with any
|
||||
'standard_conforming_strings' server parameter value.
|
||||
The tests also check that no warning is raised ('escape_string_warning'
|
||||
should be on).
|
||||
|
||||
https://www.postgresql.org/docs/current/static/sql-syntax-lexical.html#SQL-SYNTAX-STRINGS
|
||||
https://www.postgresql.org/docs/current/static/runtime-config-compatible.html
|
||||
"""
|
||||
def test_string(self):
|
||||
data = """some data with \t chars
|
||||
to escape into, 'quotes' and \\ a backslash too.
|
||||
"""
|
||||
data += "".join(map(chr, range(1, 127)))
|
||||
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("SELECT %s;", (data,))
|
||||
res = curs.fetchone()[0]
|
||||
|
||||
self.assertEqual(res, data)
|
||||
self.assert_(not self.conn.notices)
|
||||
|
||||
def test_string_null_terminator(self):
|
||||
curs = self.conn.cursor()
|
||||
data = 'abcd\x01\x00cdefg'
|
||||
|
||||
try:
|
||||
curs.execute("SELECT %s", (data,))
|
||||
except ValueError as e:
|
||||
self.assertEquals(str(e),
|
||||
'A string literal cannot contain NUL (0x00) characters.')
|
||||
else:
|
||||
self.fail("ValueError not raised")
|
||||
|
||||
def test_binary(self):
|
||||
data = b"""some data with \000\013 binary
|
||||
stuff into, 'quotes' and \\ a backslash too.
|
||||
"""
|
||||
data += bytes(list(range(256)))
|
||||
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("SELECT %s::bytea;", (psycopg2.Binary(data),))
|
||||
res = curs.fetchone()[0].tobytes()
|
||||
|
||||
if res[0] in (b'x', ord(b'x')) and self.conn.info.server_version >= 90000:
|
||||
return self.skipTest(
|
||||
"bytea broken with server >= 9.0, libpq < 9")
|
||||
|
||||
self.assertEqual(res, data)
|
||||
self.assert_(not self.conn.notices)
|
||||
|
||||
def test_unicode(self):
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("SHOW server_encoding")
|
||||
server_encoding = curs.fetchone()[0]
|
||||
if server_encoding != "UTF8":
|
||||
return self.skipTest(
|
||||
f"Unicode test skipped since server encoding is {server_encoding}")
|
||||
|
||||
data = """some data with \t chars
|
||||
to escape into, 'quotes', \u20ac euro sign and \\ a backslash too.
|
||||
"""
|
||||
data += "".join(map(chr, [u for u in range(1, 65536)
|
||||
if not 0xD800 <= u <= 0xDFFF])) # surrogate area
|
||||
self.conn.set_client_encoding('UNICODE')
|
||||
|
||||
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn)
|
||||
curs.execute("SELECT %s::text;", (data,))
|
||||
res = curs.fetchone()[0]
|
||||
|
||||
self.assertEqual(res, data)
|
||||
self.assert_(not self.conn.notices)
|
||||
|
||||
@skip_if_crdb("encoding")
|
||||
def test_latin1(self):
|
||||
self.conn.set_client_encoding('LATIN1')
|
||||
curs = self.conn.cursor()
|
||||
data = bytes(list(range(32, 127))
|
||||
+ list(range(160, 256))).decode('latin1')
|
||||
|
||||
# as string
|
||||
curs.execute("SELECT %s::text;", (data,))
|
||||
res = curs.fetchone()[0]
|
||||
self.assertEqual(res, data)
|
||||
self.assert_(not self.conn.notices)
|
||||
|
||||
|
||||
@skip_if_crdb("encoding")
|
||||
def test_koi8(self):
|
||||
self.conn.set_client_encoding('KOI8')
|
||||
curs = self.conn.cursor()
|
||||
data = bytes(list(range(32, 127))
|
||||
+ list(range(128, 256))).decode('koi8_r')
|
||||
|
||||
# as string
|
||||
curs.execute("SELECT %s::text;", (data,))
|
||||
res = curs.fetchone()[0]
|
||||
self.assertEqual(res, data)
|
||||
self.assert_(not self.conn.notices)
|
||||
|
||||
def test_bytes(self):
|
||||
snowman = "\u2603"
|
||||
conn = self.connect()
|
||||
conn.set_client_encoding('UNICODE')
|
||||
psycopg2.extensions.register_type(psycopg2.extensions.BYTES, conn)
|
||||
curs = conn.cursor()
|
||||
curs.execute("select %s::text", (snowman,))
|
||||
x = curs.fetchone()[0]
|
||||
self.assert_(isinstance(x, bytes))
|
||||
self.assertEqual(x, snowman.encode('utf8'))
|
||||
|
||||
|
||||
class TestQuotedString(ConnectingTestCase):
|
||||
def test_encoding_from_conn(self):
|
||||
q = psycopg2.extensions.QuotedString('hi')
|
||||
self.assertEqual(q.encoding, 'latin1')
|
||||
|
||||
self.conn.set_client_encoding('utf_8')
|
||||
q.prepare(self.conn)
|
||||
self.assertEqual(q.encoding, 'utf_8')
|
||||
|
||||
|
||||
class TestQuotedIdentifier(ConnectingTestCase):
|
||||
def test_identifier(self):
|
||||
self.assertEqual(quote_ident('blah-blah', self.conn), '"blah-blah"')
|
||||
self.assertEqual(quote_ident('quote"inside', self.conn), '"quote""inside"')
|
||||
|
||||
@testutils.skip_before_postgres(8, 0)
|
||||
def test_unicode_ident(self):
|
||||
snowman = "\u2603"
|
||||
quoted = '"' + snowman + '"'
|
||||
self.assertEqual(quote_ident(snowman, self.conn), quoted)
|
||||
|
||||
|
||||
class TestStringAdapter(ConnectingTestCase):
|
||||
def test_encoding_default(self):
|
||||
a = adapt("hello")
|
||||
self.assertEqual(a.encoding, 'latin1')
|
||||
self.assertEqual(a.getquoted(), b"'hello'")
|
||||
|
||||
# NOTE: we can't really test an encoding different from utf8, because
|
||||
# when encoding without a connection the libpq will use parameters from
|
||||
# a previous one, so what would happen depends on the tests run order.
|
||||
# egrave = u'\xe8'
|
||||
# self.assertEqual(adapt(egrave).getquoted(), "'\xe8'")
|
||||
|
||||
def test_encoding_error(self):
|
||||
snowman = "\u2603"
|
||||
a = adapt(snowman)
|
||||
self.assertRaises(UnicodeEncodeError, a.getquoted)
|
||||
|
||||
def test_set_encoding(self):
|
||||
# Note: this works-ish mostly when the standard db connection
|
||||
# we test with is utf8, otherwise the encoding chosen by PQescapeString
|
||||
# may give bad results.
|
||||
snowman = "\u2603"
|
||||
a = adapt(snowman)
|
||||
a.encoding = 'utf8'
|
||||
self.assertEqual(a.encoding, 'utf8')
|
||||
self.assertEqual(a.getquoted(), b"'\xe2\x98\x83'")
|
||||
|
||||
def test_connection_wins_anyway(self):
|
||||
snowman = "\u2603"
|
||||
a = adapt(snowman)
|
||||
a.encoding = 'latin9'
|
||||
|
||||
self.conn.set_client_encoding('utf8')
|
||||
a.prepare(self.conn)
|
||||
|
||||
self.assertEqual(a.encoding, 'utf_8')
|
||||
self.assertQuotedEqual(a.getquoted(), b"'\xe2\x98\x83'")
|
||||
|
||||
def test_adapt_bytes(self):
|
||||
snowman = "\u2603"
|
||||
self.conn.set_client_encoding('utf8')
|
||||
a = psycopg2.extensions.QuotedString(snowman.encode('utf8'))
|
||||
a.prepare(self.conn)
|
||||
self.assertQuotedEqual(a.getquoted(), b"'\xe2\x98\x83'")
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
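For reference, a minimal sketch of the client-side escaping machinery the quoting and adapter tests above exercise. It assumes only that psycopg2 is importable; the exact byte output depends on the libpq and connection settings, and `conn` in the comments stands for a hypothetical open connection.

from psycopg2.extensions import adapt, quote_ident

value = "O'Reilly \\ backslash"
a = adapt(value)            # returns a QuotedString adapter for str values
a.encoding = 'utf8'         # default is 'latin1' until prepare(conn) is called
print(a.getquoted())        # e.g. b" E'O''Reilly \\\\ backslash'": quotes doubled, E'' used for the backslash

# Identifiers are quoted separately; with an open connection `conn`:
#     quote_ident('quote"inside', conn)   ->   '"quote""inside"'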
276
tests/test_replication.py
Executable file
@ -0,0 +1,276 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# test_replication.py - unit test for replication protocol
|
||||
#
|
||||
# Copyright (C) 2015-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# In addition, as a special exception, the copyright holders give
|
||||
# permission to link this program with the OpenSSL library (or with
|
||||
# modified versions of OpenSSL that use the same license as OpenSSL),
|
||||
# and distribute linked combinations including the two.
|
||||
#
|
||||
# You must obey the GNU Lesser General Public License in all respects for
|
||||
# all of the code used other than OpenSSL.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
import time
|
||||
from select import select
|
||||
|
||||
import psycopg2
|
||||
from psycopg2 import sql
|
||||
from psycopg2.extras import (
|
||||
PhysicalReplicationConnection, LogicalReplicationConnection, StopReplication)
|
||||
|
||||
from . import testconfig
|
||||
import unittest
|
||||
from .testutils import ConnectingTestCase
|
||||
from .testutils import skip_before_postgres, skip_if_green
|
||||
|
||||
skip_repl_if_green = skip_if_green("replication not supported in green mode")
|
||||
|
||||
|
||||
class ReplicationTestCase(ConnectingTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.slot = testconfig.repl_slot
|
||||
self._slots = []
|
||||
|
||||
def tearDown(self):
|
||||
# first close all connections, as they might keep the slot(s) active
|
||||
super().tearDown()
|
||||
|
||||
time.sleep(0.025) # sometimes the slot is still active, wait a little
|
||||
|
||||
if self._slots:
|
||||
kill_conn = self.connect()
|
||||
if kill_conn:
|
||||
kill_cur = kill_conn.cursor()
|
||||
for slot in self._slots:
|
||||
kill_cur.execute("SELECT pg_drop_replication_slot(%s)", (slot,))
|
||||
kill_conn.commit()
|
||||
kill_conn.close()
|
||||
|
||||
def create_replication_slot(self, cur, slot_name=testconfig.repl_slot, **kwargs):
|
||||
cur.create_replication_slot(slot_name, **kwargs)
|
||||
self._slots.append(slot_name)
|
||||
|
||||
def drop_replication_slot(self, cur, slot_name=testconfig.repl_slot):
|
||||
cur.drop_replication_slot(slot_name)
|
||||
self._slots.remove(slot_name)
|
||||
|
||||
# generate some events for our replication stream
|
||||
def make_replication_events(self):
|
||||
conn = self.connect()
|
||||
if conn is None:
|
||||
return
|
||||
cur = conn.cursor()
|
||||
|
||||
try:
|
||||
cur.execute("DROP TABLE dummy1")
|
||||
except psycopg2.ProgrammingError:
|
||||
conn.rollback()
|
||||
cur.execute(
|
||||
"CREATE TABLE dummy1 AS SELECT * FROM generate_series(1, 5) AS id")
|
||||
conn.commit()
|
||||
|
||||
|
||||
class ReplicationTest(ReplicationTestCase):
|
||||
@skip_before_postgres(9, 0)
|
||||
def test_physical_replication_connection(self):
|
||||
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
|
||||
if conn is None:
|
||||
return
|
||||
cur = conn.cursor()
|
||||
cur.execute("IDENTIFY_SYSTEM")
|
||||
cur.fetchall()
|
||||
|
||||
@skip_before_postgres(9, 0)
|
||||
def test_datestyle(self):
|
||||
if testconfig.repl_dsn is None:
|
||||
return self.skipTest("replication tests disabled by default")
|
||||
|
||||
conn = self.repl_connect(
|
||||
dsn=testconfig.repl_dsn, options='-cdatestyle=german',
|
||||
connection_factory=PhysicalReplicationConnection)
|
||||
if conn is None:
|
||||
return
|
||||
cur = conn.cursor()
|
||||
cur.execute("IDENTIFY_SYSTEM")
|
||||
cur.fetchall()
|
||||
|
||||
@skip_before_postgres(9, 4)
|
||||
def test_logical_replication_connection(self):
|
||||
conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
|
||||
if conn is None:
|
||||
return
|
||||
cur = conn.cursor()
|
||||
cur.execute("IDENTIFY_SYSTEM")
|
||||
cur.fetchall()
|
||||
|
||||
@skip_before_postgres(9, 4) # slots require 9.4
|
||||
def test_create_replication_slot(self):
|
||||
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
|
||||
if conn is None:
|
||||
return
|
||||
cur = conn.cursor()
|
||||
|
||||
self.create_replication_slot(cur)
|
||||
self.assertRaises(
|
||||
psycopg2.ProgrammingError, self.create_replication_slot, cur)
|
||||
|
||||
@skip_before_postgres(9, 4) # slots require 9.4
|
||||
@skip_repl_if_green
|
||||
def test_start_on_missing_replication_slot(self):
|
||||
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
|
||||
if conn is None:
|
||||
return
|
||||
cur = conn.cursor()
|
||||
|
||||
self.assertRaises(psycopg2.ProgrammingError,
|
||||
cur.start_replication, self.slot)
|
||||
|
||||
self.create_replication_slot(cur)
|
||||
cur.start_replication(self.slot)
|
||||
|
||||
@skip_before_postgres(9, 4) # slots require 9.4
|
||||
@skip_repl_if_green
|
||||
def test_start_replication_expert_sql(self):
|
||||
conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
|
||||
if conn is None:
|
||||
return
|
||||
cur = conn.cursor()
|
||||
|
||||
self.create_replication_slot(cur, output_plugin='test_decoding')
|
||||
cur.start_replication_expert(
|
||||
sql.SQL("START_REPLICATION SLOT {slot} LOGICAL 0/00000000").format(
|
||||
slot=sql.Identifier(self.slot)))
|
||||
|
||||
@skip_before_postgres(9, 4) # slots require 9.4
|
||||
@skip_repl_if_green
|
||||
def test_start_and_recover_from_error(self):
|
||||
conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
|
||||
if conn is None:
|
||||
return
|
||||
cur = conn.cursor()
|
||||
|
||||
self.create_replication_slot(cur, output_plugin='test_decoding')
|
||||
self.make_replication_events()
|
||||
|
||||
def consume(msg):
|
||||
raise StopReplication()
|
||||
|
||||
with self.assertRaises(psycopg2.DataError):
|
||||
# try with invalid options
|
||||
cur.start_replication(
|
||||
slot_name=self.slot, options={'invalid_param': 'value'})
|
||||
cur.consume_stream(consume)
|
||||
|
||||
# try with correct command
|
||||
cur.start_replication(slot_name=self.slot)
|
||||
self.assertRaises(StopReplication, cur.consume_stream, consume)
|
||||
|
||||
@skip_before_postgres(9, 4) # slots require 9.4
|
||||
@skip_repl_if_green
|
||||
def test_keepalive(self):
|
||||
conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
|
||||
if conn is None:
|
||||
return
|
||||
|
||||
cur = conn.cursor()
|
||||
|
||||
self.create_replication_slot(cur, output_plugin='test_decoding')
|
||||
|
||||
self.make_replication_events()
|
||||
|
||||
cur.start_replication(self.slot)
|
||||
|
||||
def consume(msg):
|
||||
raise StopReplication()
|
||||
|
||||
self.assertRaises(StopReplication,
|
||||
cur.consume_stream, consume, keepalive_interval=2)
|
||||
|
||||
conn.close()
|
||||
|
||||
@skip_before_postgres(9, 4) # slots require 9.4
|
||||
@skip_repl_if_green
|
||||
def test_stop_replication(self):
|
||||
conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
|
||||
if conn is None:
|
||||
return
|
||||
cur = conn.cursor()
|
||||
|
||||
self.create_replication_slot(cur, output_plugin='test_decoding')
|
||||
|
||||
self.make_replication_events()
|
||||
|
||||
cur.start_replication(self.slot)
|
||||
|
||||
def consume(msg):
|
||||
raise StopReplication()
|
||||
self.assertRaises(StopReplication, cur.consume_stream, consume)
|
||||
|
||||
|
||||
class AsyncReplicationTest(ReplicationTestCase):
|
||||
@skip_before_postgres(9, 4) # slots require 9.4
|
||||
@skip_repl_if_green
|
||||
def test_async_replication(self):
|
||||
conn = self.repl_connect(
|
||||
connection_factory=LogicalReplicationConnection, async_=1)
|
||||
if conn is None:
|
||||
return
|
||||
|
||||
cur = conn.cursor()
|
||||
|
||||
self.create_replication_slot(cur, output_plugin='test_decoding')
|
||||
self.wait(cur)
|
||||
|
||||
cur.start_replication(self.slot)
|
||||
self.wait(cur)
|
||||
|
||||
self.make_replication_events()
|
||||
|
||||
self.msg_count = 0
|
||||
|
||||
def consume(msg):
|
||||
# just check the methods
|
||||
f"{cur.io_timestamp}: {repr(msg)}"
|
||||
f"{cur.feedback_timestamp}: {repr(msg)}"
|
||||
f"{cur.wal_end}: {repr(msg)}"
|
||||
|
||||
self.msg_count += 1
|
||||
if self.msg_count > 3:
|
||||
cur.send_feedback(reply=True)
|
||||
raise StopReplication()
|
||||
|
||||
cur.send_feedback(flush_lsn=msg.data_start)
|
||||
|
||||
# cannot be used in asynchronous mode
|
||||
self.assertRaises(psycopg2.ProgrammingError, cur.consume_stream, consume)
|
||||
|
||||
def process_stream():
|
||||
while True:
|
||||
msg = cur.read_message()
|
||||
if msg:
|
||||
consume(msg)
|
||||
else:
|
||||
select([cur], [], [])
|
||||
self.assertRaises(StopReplication, process_stream)
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
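For reference, a minimal synchronous consumer matching the pattern these replication tests exercise. The DSN and slot name are placeholders; the server must accept replication connections and provide the test_decoding output plugin.

import psycopg2
from psycopg2.extras import LogicalReplicationConnection

conn = psycopg2.connect(
    "dbname=test",  # placeholder DSN
    connection_factory=LogicalReplicationConnection)
cur = conn.cursor()
cur.create_replication_slot("example_slot", output_plugin="test_decoding")
cur.start_replication(slot_name="example_slot")

def consume(msg):
    print(msg.payload)                           # change record decoded by test_decoding
    cur.send_feedback(flush_lsn=msg.data_start)  # acknowledge so the server can recycle WAL

cur.consume_stream(consume)  # blocks; raise StopReplication inside consume() to stop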
412
tests/test_sql.py
Executable file
@ -0,0 +1,412 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# test_sql.py - tests for the psycopg2.sql module
|
||||
#
|
||||
# Copyright (C) 2016-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# In addition, as a special exception, the copyright holders give
|
||||
# permission to link this program with the OpenSSL library (or with
|
||||
# modified versions of OpenSSL that use the same license as OpenSSL),
|
||||
# and distribute linked combinations including the two.
|
||||
#
|
||||
# You must obey the GNU Lesser General Public License in all respects for
|
||||
# all of the code used other than OpenSSL.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
import datetime as dt
|
||||
import unittest
|
||||
from .testutils import (
|
||||
ConnectingTestCase, skip_before_postgres, skip_copy_if_green, StringIO,
|
||||
skip_if_crdb)
|
||||
|
||||
import psycopg2
|
||||
from psycopg2 import sql
|
||||
|
||||
|
||||
class SqlFormatTests(ConnectingTestCase):
|
||||
def test_pos(self):
|
||||
s = sql.SQL("select {} from {}").format(
|
||||
sql.Identifier('field'), sql.Identifier('table'))
|
||||
s1 = s.as_string(self.conn)
|
||||
self.assert_(isinstance(s1, str))
|
||||
self.assertEqual(s1, 'select "field" from "table"')
|
||||
|
||||
def test_pos_spec(self):
|
||||
s = sql.SQL("select {0} from {1}").format(
|
||||
sql.Identifier('field'), sql.Identifier('table'))
|
||||
s1 = s.as_string(self.conn)
|
||||
self.assert_(isinstance(s1, str))
|
||||
self.assertEqual(s1, 'select "field" from "table"')
|
||||
|
||||
s = sql.SQL("select {1} from {0}").format(
|
||||
sql.Identifier('table'), sql.Identifier('field'))
|
||||
s1 = s.as_string(self.conn)
|
||||
self.assert_(isinstance(s1, str))
|
||||
self.assertEqual(s1, 'select "field" from "table"')
|
||||
|
||||
def test_dict(self):
|
||||
s = sql.SQL("select {f} from {t}").format(
|
||||
f=sql.Identifier('field'), t=sql.Identifier('table'))
|
||||
s1 = s.as_string(self.conn)
|
||||
self.assert_(isinstance(s1, str))
|
||||
self.assertEqual(s1, 'select "field" from "table"')
|
||||
|
||||
def test_compose_literal(self):
|
||||
s = sql.SQL("select {0};").format(sql.Literal(dt.date(2016, 12, 31)))
|
||||
s1 = s.as_string(self.conn)
|
||||
self.assertEqual(s1, "select '2016-12-31'::date;")
|
||||
|
||||
def test_compose_empty(self):
|
||||
s = sql.SQL("select foo;").format()
|
||||
s1 = s.as_string(self.conn)
|
||||
self.assertEqual(s1, "select foo;")
|
||||
|
||||
def test_percent_escape(self):
|
||||
s = sql.SQL("42 % {0}").format(sql.Literal(7))
|
||||
s1 = s.as_string(self.conn)
|
||||
self.assertEqual(s1, "42 % 7")
|
||||
|
||||
def test_braces_escape(self):
|
||||
s = sql.SQL("{{{0}}}").format(sql.Literal(7))
|
||||
self.assertEqual(s.as_string(self.conn), "{7}")
|
||||
s = sql.SQL("{{1,{0}}}").format(sql.Literal(7))
|
||||
self.assertEqual(s.as_string(self.conn), "{1,7}")
|
||||
|
||||
def test_compose_badnargs(self):
|
||||
self.assertRaises(IndexError, sql.SQL("select {0};").format)
|
||||
|
||||
def test_compose_badnargs_auto(self):
|
||||
self.assertRaises(IndexError, sql.SQL("select {};").format)
|
||||
self.assertRaises(ValueError, sql.SQL("select {} {1};").format, 10, 20)
|
||||
self.assertRaises(ValueError, sql.SQL("select {0} {};").format, 10, 20)
|
||||
|
||||
def test_compose_bad_args_type(self):
|
||||
self.assertRaises(IndexError, sql.SQL("select {0};").format, a=10)
|
||||
self.assertRaises(KeyError, sql.SQL("select {x};").format, 10)
|
||||
|
||||
def test_must_be_composable(self):
|
||||
self.assertRaises(TypeError, sql.SQL("select {0};").format, 'foo')
|
||||
self.assertRaises(TypeError, sql.SQL("select {0};").format, 10)
|
||||
|
||||
def test_no_modifiers(self):
|
||||
self.assertRaises(ValueError, sql.SQL("select {a!r};").format, a=10)
|
||||
self.assertRaises(ValueError, sql.SQL("select {a:<};").format, a=10)
|
||||
|
||||
def test_must_be_adaptable(self):
|
||||
class Foo:
|
||||
pass
|
||||
|
||||
self.assertRaises(psycopg2.ProgrammingError,
|
||||
sql.SQL("select {0};").format(sql.Literal(Foo())).as_string, self.conn)
|
||||
|
||||
def test_execute(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("""
|
||||
create table test_compose (
|
||||
id serial primary key,
|
||||
foo text, bar text, "ba'z" text)
|
||||
""")
|
||||
cur.execute(
|
||||
sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
|
||||
sql.Identifier('test_compose'),
|
||||
sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
|
||||
(sql.Placeholder() * 3).join(', ')),
|
||||
(10, 'a', 'b', 'c'))
|
||||
|
||||
cur.execute("select * from test_compose")
|
||||
self.assertEqual(cur.fetchall(), [(10, 'a', 'b', 'c')])
|
||||
|
||||
def test_executemany(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("""
|
||||
create table test_compose (
|
||||
id serial primary key,
|
||||
foo text, bar text, "ba'z" text)
|
||||
""")
|
||||
cur.executemany(
|
||||
sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
|
||||
sql.Identifier('test_compose'),
|
||||
sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
|
||||
(sql.Placeholder() * 3).join(', ')),
|
||||
[(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')])
|
||||
|
||||
cur.execute("select * from test_compose")
|
||||
self.assertEqual(cur.fetchall(),
|
||||
[(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')])
|
||||
|
||||
@skip_if_crdb("copy")
|
||||
@skip_copy_if_green
|
||||
@skip_before_postgres(8, 2)
|
||||
def test_copy(self):
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("""
|
||||
create table test_compose (
|
||||
id serial primary key,
|
||||
foo text, bar text, "ba'z" text)
|
||||
""")
|
||||
|
||||
s = StringIO("10\ta\tb\tc\n20\td\te\tf\n")
|
||||
cur.copy_expert(
|
||||
sql.SQL("copy {t} (id, foo, bar, {f}) from stdin").format(
|
||||
t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")), s)
|
||||
|
||||
s1 = StringIO()
|
||||
cur.copy_expert(
|
||||
sql.SQL("copy (select {f} from {t} order by id) to stdout").format(
|
||||
t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")), s1)
|
||||
s1.seek(0)
|
||||
self.assertEqual(s1.read(), 'c\nf\n')
|
||||
|
||||
|
||||
class IdentifierTests(ConnectingTestCase):
|
||||
def test_class(self):
|
||||
self.assert_(issubclass(sql.Identifier, sql.Composable))
|
||||
|
||||
def test_init(self):
|
||||
self.assert_(isinstance(sql.Identifier('foo'), sql.Identifier))
|
||||
self.assert_(isinstance(sql.Identifier('foo'), sql.Identifier))
|
||||
self.assert_(isinstance(sql.Identifier('foo', 'bar', 'baz'), sql.Identifier))
|
||||
self.assertRaises(TypeError, sql.Identifier)
|
||||
self.assertRaises(TypeError, sql.Identifier, 10)
|
||||
self.assertRaises(TypeError, sql.Identifier, dt.date(2016, 12, 31))
|
||||
|
||||
def test_strings(self):
|
||||
self.assertEqual(sql.Identifier('foo').strings, ('foo',))
|
||||
self.assertEqual(sql.Identifier('foo', 'bar').strings, ('foo', 'bar'))
|
||||
|
||||
# Legacy method
|
||||
self.assertEqual(sql.Identifier('foo').string, 'foo')
|
||||
self.assertRaises(AttributeError,
|
||||
getattr, sql.Identifier('foo', 'bar'), 'string')
|
||||
|
||||
def test_repr(self):
|
||||
obj = sql.Identifier("fo'o")
|
||||
self.assertEqual(repr(obj), 'Identifier("fo\'o")')
|
||||
self.assertEqual(repr(obj), str(obj))
|
||||
|
||||
obj = sql.Identifier("fo'o", 'ba"r')
|
||||
self.assertEqual(repr(obj), 'Identifier("fo\'o", \'ba"r\')')
|
||||
self.assertEqual(repr(obj), str(obj))
|
||||
|
||||
def test_eq(self):
|
||||
self.assert_(sql.Identifier('foo') == sql.Identifier('foo'))
|
||||
self.assert_(sql.Identifier('foo', 'bar') == sql.Identifier('foo', 'bar'))
|
||||
self.assert_(sql.Identifier('foo') != sql.Identifier('bar'))
|
||||
self.assert_(sql.Identifier('foo') != 'foo')
|
||||
self.assert_(sql.Identifier('foo') != sql.SQL('foo'))
|
||||
|
||||
def test_as_str(self):
|
||||
self.assertEqual(
|
||||
sql.Identifier('foo').as_string(self.conn), '"foo"')
|
||||
self.assertEqual(
|
||||
sql.Identifier('foo', 'bar').as_string(self.conn), '"foo"."bar"')
|
||||
self.assertEqual(
|
||||
sql.Identifier("fo'o", 'ba"r').as_string(self.conn), '"fo\'o"."ba""r"')
|
||||
|
||||
def test_join(self):
|
||||
self.assert_(not hasattr(sql.Identifier('foo'), 'join'))
|
||||
|
||||
|
||||
class LiteralTests(ConnectingTestCase):
|
||||
def test_class(self):
|
||||
self.assert_(issubclass(sql.Literal, sql.Composable))
|
||||
|
||||
def test_init(self):
|
||||
self.assert_(isinstance(sql.Literal('foo'), sql.Literal))
|
||||
self.assert_(isinstance(sql.Literal('foo'), sql.Literal))
|
||||
self.assert_(isinstance(sql.Literal(b'foo'), sql.Literal))
|
||||
self.assert_(isinstance(sql.Literal(42), sql.Literal))
|
||||
self.assert_(isinstance(
|
||||
sql.Literal(dt.date(2016, 12, 31)), sql.Literal))
|
||||
|
||||
def test_wrapped(self):
|
||||
self.assertEqual(sql.Literal('foo').wrapped, 'foo')
|
||||
|
||||
def test_repr(self):
|
||||
self.assertEqual(repr(sql.Literal("foo")), "Literal('foo')")
|
||||
self.assertEqual(str(sql.Literal("foo")), "Literal('foo')")
|
||||
self.assertQuotedEqual(sql.Literal("foo").as_string(self.conn), "'foo'")
|
||||
self.assertEqual(sql.Literal(42).as_string(self.conn), "42")
|
||||
self.assertEqual(
|
||||
sql.Literal(dt.date(2017, 1, 1)).as_string(self.conn),
|
||||
"'2017-01-01'::date")
|
||||
|
||||
def test_eq(self):
|
||||
self.assert_(sql.Literal('foo') == sql.Literal('foo'))
|
||||
self.assert_(sql.Literal('foo') != sql.Literal('bar'))
|
||||
self.assert_(sql.Literal('foo') != 'foo')
|
||||
self.assert_(sql.Literal('foo') != sql.SQL('foo'))
|
||||
|
||||
def test_must_be_adaptable(self):
|
||||
class Foo:
|
||||
pass
|
||||
|
||||
self.assertRaises(psycopg2.ProgrammingError,
|
||||
sql.Literal(Foo()).as_string, self.conn)
|
||||
|
||||
|
||||
class SQLTests(ConnectingTestCase):
|
||||
def test_class(self):
|
||||
self.assert_(issubclass(sql.SQL, sql.Composable))
|
||||
|
||||
def test_init(self):
|
||||
self.assert_(isinstance(sql.SQL('foo'), sql.SQL))
|
||||
self.assert_(isinstance(sql.SQL('foo'), sql.SQL))
|
||||
self.assertRaises(TypeError, sql.SQL, 10)
|
||||
self.assertRaises(TypeError, sql.SQL, dt.date(2016, 12, 31))
|
||||
|
||||
def test_string(self):
|
||||
self.assertEqual(sql.SQL('foo').string, 'foo')
|
||||
|
||||
def test_repr(self):
|
||||
self.assertEqual(repr(sql.SQL("foo")), "SQL('foo')")
|
||||
self.assertEqual(str(sql.SQL("foo")), "SQL('foo')")
|
||||
self.assertEqual(sql.SQL("foo").as_string(self.conn), "foo")
|
||||
|
||||
def test_eq(self):
|
||||
self.assert_(sql.SQL('foo') == sql.SQL('foo'))
|
||||
self.assert_(sql.SQL('foo') != sql.SQL('bar'))
|
||||
self.assert_(sql.SQL('foo') != 'foo')
|
||||
self.assert_(sql.SQL('foo') != sql.Literal('foo'))
|
||||
|
||||
def test_sum(self):
|
||||
obj = sql.SQL("foo") + sql.SQL("bar")
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertEqual(obj.as_string(self.conn), "foobar")
|
||||
|
||||
def test_sum_inplace(self):
|
||||
obj = sql.SQL("foo")
|
||||
obj += sql.SQL("bar")
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertEqual(obj.as_string(self.conn), "foobar")
|
||||
|
||||
def test_multiply(self):
|
||||
obj = sql.SQL("foo") * 3
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertEqual(obj.as_string(self.conn), "foofoofoo")
|
||||
|
||||
def test_join(self):
|
||||
obj = sql.SQL(", ").join(
|
||||
[sql.Identifier('foo'), sql.SQL('bar'), sql.Literal(42)])
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertEqual(obj.as_string(self.conn), '"foo", bar, 42')
|
||||
|
||||
obj = sql.SQL(", ").join(
|
||||
sql.Composed([sql.Identifier('foo'), sql.SQL('bar'), sql.Literal(42)]))
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertEqual(obj.as_string(self.conn), '"foo", bar, 42')
|
||||
|
||||
obj = sql.SQL(", ").join([])
|
||||
self.assertEqual(obj, sql.Composed([]))
|
||||
|
||||
|
||||
class ComposedTest(ConnectingTestCase):
|
||||
def test_class(self):
|
||||
self.assert_(issubclass(sql.Composed, sql.Composable))
|
||||
|
||||
def test_repr(self):
|
||||
obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
|
||||
self.assertEqual(repr(obj),
|
||||
"""Composed([Literal('foo'), Identifier("b'ar")])""")
|
||||
self.assertEqual(str(obj), repr(obj))
|
||||
|
||||
def test_seq(self):
|
||||
l = [sql.SQL('foo'), sql.Literal('bar'), sql.Identifier('baz')]
|
||||
self.assertEqual(sql.Composed(l).seq, l)
|
||||
|
||||
def test_eq(self):
|
||||
l = [sql.Literal("foo"), sql.Identifier("b'ar")]
|
||||
l2 = [sql.Literal("foo"), sql.Literal("b'ar")]
|
||||
self.assert_(sql.Composed(l) == sql.Composed(list(l)))
|
||||
self.assert_(sql.Composed(l) != l)
|
||||
self.assert_(sql.Composed(l) != sql.Composed(l2))
|
||||
|
||||
def test_join(self):
|
||||
obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
|
||||
obj = obj.join(", ")
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertQuotedEqual(obj.as_string(self.conn), "'foo', \"b'ar\"")
|
||||
|
||||
def test_sum(self):
|
||||
obj = sql.Composed([sql.SQL("foo ")])
|
||||
obj = obj + sql.Literal("bar")
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertQuotedEqual(obj.as_string(self.conn), "foo 'bar'")
|
||||
|
||||
def test_sum_inplace(self):
|
||||
obj = sql.Composed([sql.SQL("foo ")])
|
||||
obj += sql.Literal("bar")
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertQuotedEqual(obj.as_string(self.conn), "foo 'bar'")
|
||||
|
||||
obj = sql.Composed([sql.SQL("foo ")])
|
||||
obj += sql.Composed([sql.Literal("bar")])
|
||||
self.assert_(isinstance(obj, sql.Composed))
|
||||
self.assertQuotedEqual(obj.as_string(self.conn), "foo 'bar'")
|
||||
|
||||
def test_iter(self):
|
||||
obj = sql.Composed([sql.SQL("foo"), sql.SQL('bar')])
|
||||
it = iter(obj)
|
||||
i = next(it)
|
||||
self.assertEqual(i, sql.SQL('foo'))
|
||||
i = next(it)
|
||||
self.assertEqual(i, sql.SQL('bar'))
|
||||
self.assertRaises(StopIteration, next, it)
|
||||
|
||||
|
||||
class PlaceholderTest(ConnectingTestCase):
|
||||
def test_class(self):
|
||||
self.assert_(issubclass(sql.Placeholder, sql.Composable))
|
||||
|
||||
def test_name(self):
|
||||
self.assertEqual(sql.Placeholder().name, None)
|
||||
self.assertEqual(sql.Placeholder('foo').name, 'foo')
|
||||
|
||||
def test_repr(self):
|
||||
self.assertEqual(str(sql.Placeholder()), 'Placeholder()')
|
||||
self.assertEqual(repr(sql.Placeholder()), 'Placeholder()')
|
||||
self.assertEqual(sql.Placeholder().as_string(self.conn), '%s')
|
||||
|
||||
def test_repr_name(self):
|
||||
self.assert_(str(sql.Placeholder('foo')), "Placeholder('foo')")
|
||||
self.assert_(repr(sql.Placeholder('foo')), "Placeholder('foo')")
|
||||
self.assert_(sql.Placeholder('foo').as_string(self.conn), '%(foo)s')
|
||||
|
||||
def test_bad_name(self):
|
||||
self.assertRaises(ValueError, sql.Placeholder, ')')
|
||||
|
||||
def test_eq(self):
|
||||
self.assert_(sql.Placeholder('foo') == sql.Placeholder('foo'))
|
||||
self.assert_(sql.Placeholder('foo') != sql.Placeholder('bar'))
|
||||
self.assert_(sql.Placeholder('foo') != 'foo')
|
||||
self.assert_(sql.Placeholder() == sql.Placeholder())
|
||||
self.assert_(sql.Placeholder('foo') != sql.Placeholder())
|
||||
self.assert_(sql.Placeholder('foo') != sql.Literal('foo'))
|
||||
|
||||
|
||||
class ValuesTest(ConnectingTestCase):
|
||||
def test_null(self):
|
||||
self.assert_(isinstance(sql.NULL, sql.SQL))
|
||||
self.assertEqual(sql.NULL.as_string(self.conn), "NULL")
|
||||
|
||||
def test_default(self):
|
||||
self.assert_(isinstance(sql.DEFAULT, sql.SQL))
|
||||
self.assertEqual(sql.DEFAULT.as_string(self.conn), "DEFAULT")
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
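Outside the test harness, the composition API exercised above is typically used as in the following sketch. The table and column names echo the ones created by test_execute, and `conn` in the comments stands for a hypothetical open connection.

from psycopg2 import sql

query = sql.SQL("insert into {table} ({fields}) values ({values})").format(
    table=sql.Identifier("test_compose"),
    fields=sql.SQL(", ").join(map(sql.Identifier, ["foo", "bar", "ba'z"])),
    values=sql.SQL(", ").join(sql.Placeholder() * 3))

# With an open connection `conn`:
#     print(query.as_string(conn))
#     # insert into "test_compose" ("foo", "bar", "ba'z") values (%s, %s, %s)
#     conn.cursor().execute(query, ("a", "b", "c"))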
258
tests/test_transaction.py
Executable file
@ -0,0 +1,258 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# test_transaction - unit test on transaction behaviour
|
||||
#
|
||||
# Copyright (C) 2007-2019 Federico Di Gregorio <fog@debian.org>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# In addition, as a special exception, the copyright holders give
|
||||
# permission to link this program with the OpenSSL library (or with
|
||||
# modified versions of OpenSSL that use the same license as OpenSSL),
|
||||
# and distribute linked combinations including the two.
|
||||
#
|
||||
# You must obey the GNU Lesser General Public License in all respects for
|
||||
# all of the code used other than OpenSSL.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
import threading
|
||||
import unittest
|
||||
from .testutils import ConnectingTestCase, skip_before_postgres, slow
|
||||
from .testutils import skip_if_crdb
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extensions import (
|
||||
ISOLATION_LEVEL_SERIALIZABLE, STATUS_BEGIN, STATUS_READY)
|
||||
|
||||
|
||||
class TransactionTests(ConnectingTestCase):
|
||||
|
||||
def setUp(self):
|
||||
ConnectingTestCase.setUp(self)
|
||||
skip_if_crdb("isolation level", self.conn)
|
||||
self.conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE)
|
||||
curs = self.conn.cursor()
|
||||
curs.execute('''
|
||||
CREATE TEMPORARY TABLE table1 (
|
||||
id int PRIMARY KEY
|
||||
)''')
|
||||
# The constraint is set to deferrable for test_failed_commit
|
||||
curs.execute('''
|
||||
CREATE TEMPORARY TABLE table2 (
|
||||
id int PRIMARY KEY,
|
||||
table1_id int,
|
||||
CONSTRAINT table2__table1_id__fk
|
||||
FOREIGN KEY (table1_id) REFERENCES table1(id) DEFERRABLE)''')
|
||||
curs.execute('INSERT INTO table1 VALUES (1)')
|
||||
curs.execute('INSERT INTO table2 VALUES (1, 1)')
|
||||
self.conn.commit()
|
||||
|
||||
def test_rollback(self):
|
||||
# Test that rollback undoes changes
|
||||
curs = self.conn.cursor()
|
||||
curs.execute('INSERT INTO table2 VALUES (2, 1)')
|
||||
# Rollback takes us from BEGIN state to READY state
|
||||
self.assertEqual(self.conn.status, STATUS_BEGIN)
|
||||
self.conn.rollback()
|
||||
self.assertEqual(self.conn.status, STATUS_READY)
|
||||
curs.execute('SELECT id, table1_id FROM table2 WHERE id = 2')
|
||||
self.assertEqual(curs.fetchall(), [])
|
||||
|
||||
def test_commit(self):
|
||||
# Test that commit stores changes
|
||||
curs = self.conn.cursor()
|
||||
curs.execute('INSERT INTO table2 VALUES (2, 1)')
|
||||
# Commit takes us from BEGIN state to READY state
|
||||
self.assertEqual(self.conn.status, STATUS_BEGIN)
|
||||
self.conn.commit()
|
||||
self.assertEqual(self.conn.status, STATUS_READY)
|
||||
# Now rollback and show that the new record is still there:
|
||||
self.conn.rollback()
|
||||
curs.execute('SELECT id, table1_id FROM table2 WHERE id = 2')
|
||||
self.assertEqual(curs.fetchall(), [(2, 1)])
|
||||
|
||||
def test_failed_commit(self):
|
||||
# Test that we can recover from a failed commit.
|
||||
# We use a deferred constraint to cause a failure on commit.
|
||||
curs = self.conn.cursor()
|
||||
curs.execute('SET CONSTRAINTS table2__table1_id__fk DEFERRED')
|
||||
curs.execute('INSERT INTO table2 VALUES (2, 42)')
|
||||
# The commit should fail, and move the connection back to READY state
|
||||
self.assertEqual(self.conn.status, STATUS_BEGIN)
|
||||
self.assertRaises(psycopg2.IntegrityError, self.conn.commit)
|
||||
self.assertEqual(self.conn.status, STATUS_READY)
|
||||
# The connection should be ready to use for the next transaction:
|
||||
curs.execute('SELECT 1')
|
||||
self.assertEqual(curs.fetchone()[0], 1)
|
||||
|
||||
|
||||
class DeadlockSerializationTests(ConnectingTestCase):
|
||||
"""Test deadlock and serialization failure errors."""
|
||||
|
||||
def connect(self):
|
||||
conn = ConnectingTestCase.connect(self)
|
||||
conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE)
|
||||
return conn
|
||||
|
||||
def setUp(self):
|
||||
ConnectingTestCase.setUp(self)
|
||||
skip_if_crdb("isolation level", self.conn)
|
||||
|
||||
curs = self.conn.cursor()
|
||||
# Drop table if it already exists
|
||||
try:
|
||||
curs.execute("DROP TABLE table1")
|
||||
self.conn.commit()
|
||||
except psycopg2.DatabaseError:
|
||||
self.conn.rollback()
|
||||
try:
|
||||
curs.execute("DROP TABLE table2")
|
||||
self.conn.commit()
|
||||
except psycopg2.DatabaseError:
|
||||
self.conn.rollback()
|
||||
# Create sample data
|
||||
curs.execute("""
|
||||
CREATE TABLE table1 (
|
||||
id int PRIMARY KEY,
|
||||
name text)
|
||||
""")
|
||||
curs.execute("INSERT INTO table1 VALUES (1, 'hello')")
|
||||
curs.execute("CREATE TABLE table2 (id int PRIMARY KEY)")
|
||||
self.conn.commit()
|
||||
|
||||
def tearDown(self):
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("DROP TABLE table1")
|
||||
curs.execute("DROP TABLE table2")
|
||||
self.conn.commit()
|
||||
|
||||
ConnectingTestCase.tearDown(self)
|
||||
|
||||
@slow
|
||||
def test_deadlock(self):
|
||||
self.thread1_error = self.thread2_error = None
|
||||
step1 = threading.Event()
|
||||
step2 = threading.Event()
|
||||
|
||||
def task1():
|
||||
try:
|
||||
conn = self.connect()
|
||||
curs = conn.cursor()
|
||||
curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE")
|
||||
step1.set()
|
||||
step2.wait()
|
||||
curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE")
|
||||
except psycopg2.DatabaseError as exc:
|
||||
self.thread1_error = exc
|
||||
step1.set()
|
||||
conn.close()
|
||||
|
||||
def task2():
|
||||
try:
|
||||
conn = self.connect()
|
||||
curs = conn.cursor()
|
||||
step1.wait()
|
||||
curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE")
|
||||
step2.set()
|
||||
curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE")
|
||||
except psycopg2.DatabaseError as exc:
|
||||
self.thread2_error = exc
|
||||
step2.set()
|
||||
conn.close()
|
||||
|
||||
# Run the threads in parallel. The "step1" and "step2" events
|
||||
# ensure that the two transactions overlap.
|
||||
thread1 = threading.Thread(target=task1)
|
||||
thread2 = threading.Thread(target=task2)
|
||||
thread1.start()
|
||||
thread2.start()
|
||||
thread1.join()
|
||||
thread2.join()
|
||||
|
||||
# Exactly one of the threads should have failed with
|
||||
# TransactionRollbackError:
|
||||
self.assertFalse(self.thread1_error and self.thread2_error)
|
||||
error = self.thread1_error or self.thread2_error
|
||||
self.assertTrue(isinstance(
|
||||
error, psycopg2.extensions.TransactionRollbackError))
|
||||
|
||||
@slow
|
||||
def test_serialisation_failure(self):
|
||||
self.thread1_error = self.thread2_error = None
|
||||
step1 = threading.Event()
|
||||
step2 = threading.Event()
|
||||
|
||||
def task1():
|
||||
try:
|
||||
conn = self.connect()
|
||||
curs = conn.cursor()
|
||||
curs.execute("SELECT name FROM table1 WHERE id = 1")
|
||||
curs.fetchall()
|
||||
step1.set()
|
||||
step2.wait()
|
||||
curs.execute("UPDATE table1 SET name='task1' WHERE id = 1")
|
||||
conn.commit()
|
||||
except psycopg2.DatabaseError as exc:
|
||||
self.thread1_error = exc
|
||||
step1.set()
|
||||
conn.close()
|
||||
|
||||
def task2():
|
||||
try:
|
||||
conn = self.connect()
|
||||
curs = conn.cursor()
|
||||
step1.wait()
|
||||
curs.execute("UPDATE table1 SET name='task2' WHERE id = 1")
|
||||
conn.commit()
|
||||
except psycopg2.DatabaseError as exc:
|
||||
self.thread2_error = exc
|
||||
step2.set()
|
||||
conn.close()
|
||||
|
||||
# Run the threads in parallel. The "step1" and "step2" events
|
||||
# ensure that the two transactions overlap.
|
||||
thread1 = threading.Thread(target=task1)
|
||||
thread2 = threading.Thread(target=task2)
|
||||
thread1.start()
|
||||
thread2.start()
|
||||
thread1.join()
|
||||
thread2.join()
|
||||
|
||||
# Exactly one of the threads should have failed with
|
||||
# TransactionRollbackError:
|
||||
self.assertFalse(self.thread1_error and self.thread2_error)
|
||||
error = self.thread1_error or self.thread2_error
|
||||
self.assertTrue(isinstance(
|
||||
error, psycopg2.extensions.TransactionRollbackError))
|
||||
|
||||
|
||||
class QueryCancellationTests(ConnectingTestCase):
|
||||
"""Tests for query cancellation."""
|
||||
|
||||
def setUp(self):
|
||||
ConnectingTestCase.setUp(self)
|
||||
self.conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE)
|
||||
|
||||
@skip_before_postgres(8, 2)
|
||||
def test_statement_timeout(self):
|
||||
curs = self.conn.cursor()
|
||||
# Set a low statement timeout, then sleep for a longer period.
|
||||
curs.execute('SET statement_timeout TO 10')
|
||||
self.assertRaises(psycopg2.extensions.QueryCanceledError,
|
||||
curs.execute, 'SELECT pg_sleep(50)')
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
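Both the deadlock and serialization tests above end in TransactionRollbackError; in application code the usual response is a bounded retry of the whole transaction, sketched below with a placeholder DSN and statement.

import psycopg2
from psycopg2.extensions import ISOLATION_LEVEL_SERIALIZABLE, TransactionRollbackError

conn = psycopg2.connect("dbname=test")  # placeholder DSN
conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE)

for attempt in range(3):
    try:
        curs = conn.cursor()
        curs.execute("UPDATE table1 SET name = %s WHERE id = %s", ("retried", 1))
        conn.commit()
        break
    except TransactionRollbackError:
        conn.rollback()  # deadlock or serialization failure: run the transaction again
else:
    raise RuntimeError("transaction still failing after 3 attempts")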
492
tests/test_types_basic.py
Executable file
@ -0,0 +1,492 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# types_basic.py - tests for basic type conversions
|
||||
#
|
||||
# Copyright (C) 2004-2019 Federico Di Gregorio <fog@debian.org>
|
||||
# Copyright (C) 2020-2021 The Psycopg Team
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# In addition, as a special exception, the copyright holders give
|
||||
# permission to link this program with the OpenSSL library (or with
|
||||
# modified versions of OpenSSL that use the same license as OpenSSL),
|
||||
# and distribute linked combinations including the two.
|
||||
#
|
||||
# You must obey the GNU Lesser General Public License in all respects for
|
||||
# all of the code used other than OpenSSL.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
import string
|
||||
import ctypes
|
||||
import decimal
|
||||
import datetime
|
||||
import platform
|
||||
|
||||
from . import testutils
|
||||
import unittest
|
||||
from .testutils import ConnectingTestCase, restore_types
|
||||
from .testutils import skip_if_crdb
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extensions import AsIs, adapt, register_adapter
|
||||
|
||||
|
||||
class TypesBasicTests(ConnectingTestCase):
|
||||
"""Test that all type conversions are working."""
|
||||
|
||||
def execute(self, *args):
|
||||
curs = self.conn.cursor()
|
||||
curs.execute(*args)
|
||||
return curs.fetchone()[0]
|
||||
|
||||
def testQuoting(self):
|
||||
s = "Quote'this\\! ''ok?''"
|
||||
self.failUnless(self.execute("SELECT %s AS foo", (s,)) == s,
|
||||
"wrong quoting: " + s)
|
||||
|
||||
def testUnicode(self):
|
||||
s = "Quote'this\\! ''ok?''"
|
||||
self.failUnless(self.execute("SELECT %s AS foo", (s,)) == s,
|
||||
"wrong unicode quoting: " + s)
|
||||
|
||||
def testNumber(self):
|
||||
s = self.execute("SELECT %s AS foo", (1971,))
|
||||
self.failUnless(s == 1971, "wrong integer quoting: " + str(s))
|
||||
|
||||
def testBoolean(self):
|
||||
x = self.execute("SELECT %s as foo", (False,))
|
||||
self.assert_(x is False)
|
||||
x = self.execute("SELECT %s as foo", (True,))
|
||||
self.assert_(x is True)
|
||||
|
||||
def testDecimal(self):
|
||||
s = self.execute("SELECT %s AS foo", (decimal.Decimal("19.10"),))
|
||||
self.failUnless(s - decimal.Decimal("19.10") == 0,
|
||||
"wrong decimal quoting: " + str(s))
|
||||
s = self.execute("SELECT %s AS foo", (decimal.Decimal("NaN"),))
|
||||
self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s))
|
||||
self.failUnless(type(s) == decimal.Decimal,
|
||||
"wrong decimal conversion: " + repr(s))
|
||||
s = self.execute("SELECT %s AS foo", (decimal.Decimal("infinity"),))
|
||||
self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s))
|
||||
self.failUnless(type(s) == decimal.Decimal,
|
||||
"wrong decimal conversion: " + repr(s))
|
||||
s = self.execute("SELECT %s AS foo", (decimal.Decimal("-infinity"),))
|
||||
self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s))
|
||||
self.failUnless(type(s) == decimal.Decimal,
|
||||
"wrong decimal conversion: " + repr(s))
|
||||
|
||||
def testFloatNan(self):
|
||||
try:
|
||||
float("nan")
|
||||
except ValueError:
|
||||
return self.skipTest("nan not available on this platform")
|
||||
|
||||
s = self.execute("SELECT %s AS foo", (float("nan"),))
|
||||
self.failUnless(str(s) == "nan", "wrong float quoting: " + str(s))
|
||||
self.failUnless(type(s) == float, "wrong float conversion: " + repr(s))
|
||||
|
||||
def testFloatInf(self):
|
||||
try:
|
||||
self.execute("select 'inf'::float")
|
||||
except psycopg2.DataError:
|
||||
return self.skipTest("inf::float not available on the server")
|
||||
except ValueError:
|
||||
return self.skipTest("inf not available on this platform")
|
||||
s = self.execute("SELECT %s AS foo", (float("inf"),))
|
||||
self.failUnless(str(s) == "inf", "wrong float quoting: " + str(s))
|
||||
self.failUnless(type(s) == float, "wrong float conversion: " + repr(s))
|
||||
|
||||
s = self.execute("SELECT %s AS foo", (float("-inf"),))
|
||||
self.failUnless(str(s) == "-inf", "wrong float quoting: " + str(s))
|
||||
|
||||
def testBinary(self):
|
||||
s = bytes(range(256))
|
||||
b = psycopg2.Binary(s)
|
||||
buf = self.execute("SELECT %s::bytea AS foo", (b,))
|
||||
self.assertEqual(s, buf.tobytes())
|
||||
|
||||
def testBinaryNone(self):
|
||||
b = psycopg2.Binary(None)
|
||||
buf = self.execute("SELECT %s::bytea AS foo", (b,))
|
||||
self.assertEqual(buf, None)
|
||||
|
||||
def testBinaryEmptyString(self):
|
||||
# test to make sure an empty Binary is converted to an empty string
|
||||
b = psycopg2.Binary(bytes([]))
|
||||
self.assertEqual(str(b), "''::bytea")
|
||||
|
||||
def testBinaryRoundTrip(self):
|
||||
# test to make sure buffers returned by psycopg2 are
|
||||
# understood by execute:
|
||||
s = bytes(range(256))
|
||||
buf = self.execute("SELECT %s::bytea AS foo", (psycopg2.Binary(s),))
|
||||
buf2 = self.execute("SELECT %s::bytea AS foo", (buf,))
|
||||
self.assertEqual(s, buf2.tobytes())
|
||||
|
||||
@skip_if_crdb("nested array")
|
||||
def testArray(self):
|
||||
s = self.execute("SELECT %s AS foo", ([[1, 2], [3, 4]],))
|
||||
self.failUnlessEqual(s, [[1, 2], [3, 4]])
|
||||
s = self.execute("SELECT %s AS foo", (['one', 'two', 'three'],))
|
||||
self.failUnlessEqual(s, ['one', 'two', 'three'])
|
||||
|
||||
@skip_if_crdb("nested array")
|
||||
def testEmptyArrayRegression(self):
|
||||
# ticket #42
|
||||
curs = self.conn.cursor()
|
||||
curs.execute(
|
||||
"create table array_test "
|
||||
"(id integer, col timestamp without time zone[])")
|
||||
|
||||
curs.execute("insert into array_test values (%s, %s)",
|
||||
(1, [datetime.date(2011, 2, 14)]))
|
||||
curs.execute("select col from array_test where id = 1")
|
||||
self.assertEqual(curs.fetchone()[0], [datetime.datetime(2011, 2, 14, 0, 0)])
|
||||
|
||||
curs.execute("insert into array_test values (%s, %s)", (2, []))
|
||||
curs.execute("select col from array_test where id = 2")
|
||||
self.assertEqual(curs.fetchone()[0], [])
|
||||
|
||||
@skip_if_crdb("nested array")
|
||||
@testutils.skip_before_postgres(8, 4)
|
||||
def testNestedEmptyArray(self):
|
||||
# issue #788
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("select 10 = any(%s::int[])", ([[]], ))
|
||||
self.assertFalse(curs.fetchone()[0])
|
||||
|
||||
def testEmptyArrayNoCast(self):
|
||||
s = self.execute("SELECT '{}' AS foo")
|
||||
self.assertEqual(s, '{}')
|
||||
s = self.execute("SELECT %s AS foo", ([],))
|
||||
self.assertEqual(s, '{}')
|
||||
|
||||
def testEmptyArray(self):
|
||||
s = self.execute("SELECT '{}'::text[] AS foo")
|
||||
self.failUnlessEqual(s, [])
|
||||
s = self.execute("SELECT 1 != ALL(%s)", ([],))
|
||||
self.failUnlessEqual(s, True)
|
||||
# but don't break the strings :)
|
||||
s = self.execute("SELECT '{}'::text AS foo")
|
||||
self.failUnlessEqual(s, "{}")
|
||||
|
||||
def testArrayEscape(self):
|
||||
ss = ['', '\\', '"', '\\\\', '\\"']
|
||||
for s in ss:
|
||||
r = self.execute("SELECT %s AS foo", (s,))
|
||||
self.failUnlessEqual(s, r)
|
||||
r = self.execute("SELECT %s AS foo", ([s],))
|
||||
self.failUnlessEqual([s], r)
|
||||
|
||||
r = self.execute("SELECT %s AS foo", (ss,))
|
||||
self.failUnlessEqual(ss, r)
|
||||
|
||||
def testArrayMalformed(self):
|
||||
curs = self.conn.cursor()
|
||||
ss = ['', '{', '{}}', '{' * 20 + '}' * 20]
|
||||
for s in ss:
|
||||
self.assertRaises(psycopg2.DataError,
|
||||
psycopg2.extensions.STRINGARRAY, s.encode('utf8'), curs)
|
||||
|
||||
def testTextArray(self):
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("select '{a,b,c}'::text[]")
|
||||
x = curs.fetchone()[0]
|
||||
self.assert_(isinstance(x[0], str))
|
||||
self.assertEqual(x, ['a', 'b', 'c'])
|
||||
|
||||
def testUnicodeArray(self):
|
||||
psycopg2.extensions.register_type(
|
||||
psycopg2.extensions.UNICODEARRAY, self.conn)
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("select '{a,b,c}'::text[]")
|
||||
x = curs.fetchone()[0]
|
||||
self.assert_(isinstance(x[0], str))
|
||||
self.assertEqual(x, ['a', 'b', 'c'])
|
||||
|
||||
def testBytesArray(self):
|
||||
psycopg2.extensions.register_type(
|
||||
psycopg2.extensions.BYTESARRAY, self.conn)
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("select '{a,b,c}'::text[]")
|
||||
x = curs.fetchone()[0]
|
||||
self.assert_(isinstance(x[0], bytes))
|
||||
self.assertEqual(x, [b'a', b'b', b'c'])
|
||||
|
||||
@skip_if_crdb("nested array")
|
||||
@testutils.skip_before_postgres(8, 2)
|
||||
def testArrayOfNulls(self):
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("""
|
||||
create table na (
|
||||
texta text[],
|
||||
inta int[],
|
||||
boola boolean[],
|
||||
|
||||
textaa text[][],
|
||||
intaa int[][],
|
||||
boolaa boolean[][]
|
||||
)""")
|
||||
|
||||
curs.execute("insert into na (texta) values (%s)", ([None],))
|
||||
curs.execute("insert into na (texta) values (%s)", (['a', None],))
|
||||
curs.execute("insert into na (texta) values (%s)", ([None, None],))
|
||||
curs.execute("insert into na (inta) values (%s)", ([None],))
|
||||
curs.execute("insert into na (inta) values (%s)", ([42, None],))
|
||||
curs.execute("insert into na (inta) values (%s)", ([None, None],))
|
||||
curs.execute("insert into na (boola) values (%s)", ([None],))
|
||||
curs.execute("insert into na (boola) values (%s)", ([True, None],))
|
||||
curs.execute("insert into na (boola) values (%s)", ([None, None],))
|
||||
|
||||
curs.execute("insert into na (textaa) values (%s)", ([[None]],))
|
||||
curs.execute("insert into na (textaa) values (%s)", ([['a', None]],))
|
||||
curs.execute("insert into na (textaa) values (%s)", ([[None, None]],))
|
||||
|
||||
curs.execute("insert into na (intaa) values (%s)", ([[None]],))
|
||||
curs.execute("insert into na (intaa) values (%s)", ([[42, None]],))
|
||||
curs.execute("insert into na (intaa) values (%s)", ([[None, None]],))
|
||||
|
||||
curs.execute("insert into na (boolaa) values (%s)", ([[None]],))
|
||||
curs.execute("insert into na (boolaa) values (%s)", ([[True, None]],))
|
||||
curs.execute("insert into na (boolaa) values (%s)", ([[None, None]],))
|
||||
|
||||
@skip_if_crdb("nested array")
|
||||
@testutils.skip_before_postgres(8, 2)
|
||||
def testNestedArrays(self):
|
||||
curs = self.conn.cursor()
|
||||
for a in [
|
||||
[[1]],
|
||||
[[None]],
|
||||
[[None, None, None]],
|
||||
[[None, None], [1, None]],
|
||||
[[None, None], [None, None]],
|
||||
[[[None, None], [None, None]]],
|
||||
]:
|
||||
curs.execute("select %s::int[]", (a,))
|
||||
self.assertEqual(curs.fetchone()[0], a)
|
||||
|
||||
def testTypeRoundtripBytes(self):
|
||||
o1 = bytes(range(256))
|
||||
o2 = self.execute("select %s;", (o1,))
|
||||
self.assertEqual(memoryview, type(o2))
|
||||
|
||||
# Test with an empty buffer
|
||||
o1 = bytes([])
|
||||
o2 = self.execute("select %s;", (o1,))
|
||||
self.assertEqual(memoryview, type(o2))
|
||||
|
||||
def testTypeRoundtripBytesArray(self):
|
||||
o1 = bytes(range(256))
|
||||
o1 = [o1]
|
||||
o2 = self.execute("select %s;", (o1,))
|
||||
self.assertEqual(memoryview, type(o2[0]))
|
||||
|
||||
def testAdaptBytearray(self):
|
||||
o1 = bytearray(range(256))
|
||||
o2 = self.execute("select %s;", (o1,))
|
||||
self.assertEqual(memoryview, type(o2))
|
||||
self.assertEqual(len(o1), len(o2))
|
||||
for c1, c2 in zip(o1, o2):
|
||||
self.assertEqual(c1, ord(c2))
|
||||
|
||||
# Test with an empty buffer
|
||||
o1 = bytearray([])
|
||||
o2 = self.execute("select %s;", (o1,))
|
||||
self.assertEqual(len(o2), 0)
|
||||
self.assertEqual(memoryview, type(o2))
|
||||
|
||||
def testAdaptMemoryview(self):
|
||||
o1 = memoryview(bytearray(range(256)))
|
||||
o2 = self.execute("select %s;", (o1,))
|
||||
self.assertEqual(memoryview, type(o2))
|
||||
|
||||
# Test with an empty buffer
|
||||
o1 = memoryview(bytearray([]))
|
||||
o2 = self.execute("select %s;", (o1,))
|
||||
self.assertEqual(memoryview, type(o2))
|
||||
|
||||
def testByteaHexCheckFalsePositive(self):
|
||||
# the check \x -> x to detect bad bytea decode
|
||||
# may be fooled if the first char is really an 'x'
|
||||
o1 = psycopg2.Binary(b'x')
|
||||
o2 = self.execute("SELECT %s::bytea AS foo", (o1,))
|
||||
self.assertEqual(b'x', o2[0])
|
||||
|
||||
def testNegNumber(self):
|
||||
d1 = self.execute("select -%s;", (decimal.Decimal('-1.0'),))
|
||||
self.assertEqual(1, d1)
|
||||
f1 = self.execute("select -%s;", (-1.0,))
|
||||
self.assertEqual(1, f1)
|
||||
i1 = self.execute("select -%s;", (-1,))
|
||||
self.assertEqual(1, i1)
|
||||
|
||||
def testGenericArray(self):
|
||||
a = self.execute("select '{1, 2, 3}'::int4[]")
|
||||
self.assertEqual(a, [1, 2, 3])
|
||||
a = self.execute("select array['a', 'b', '''']::text[]")
|
||||
self.assertEqual(a, ['a', 'b', "'"])
|
||||
|
||||
@testutils.skip_before_postgres(8, 2)
|
||||
def testGenericArrayNull(self):
|
||||
def caster(s, cur):
|
||||
if s is None:
|
||||
return "nada"
|
||||
return int(s) * 2
|
||||
base = psycopg2.extensions.new_type((23,), "INT4", caster)
|
||||
array = psycopg2.extensions.new_array_type((1007,), "INT4ARRAY", base)
|
||||
|
||||
psycopg2.extensions.register_type(array, self.conn)
|
||||
a = self.execute("select '{1, 2, 3}'::int4[]")
|
||||
self.assertEqual(a, [2, 4, 6])
|
||||
a = self.execute("select '{1, 2, NULL}'::int4[]")
|
||||
self.assertEqual(a, [2, 4, 'nada'])
|
||||
|
||||
@skip_if_crdb("cidr")
|
||||
@testutils.skip_before_postgres(8, 2)
|
||||
def testNetworkArray(self):
|
||||
# we don't know these types, but we know their arrays
|
||||
a = self.execute("select '{192.168.0.1/24}'::inet[]")
|
||||
self.assertEqual(a, ['192.168.0.1/24'])
|
||||
a = self.execute("select '{192.168.0.0/24}'::cidr[]")
|
||||
self.assertEqual(a, ['192.168.0.0/24'])
|
||||
a = self.execute("select '{10:20:30:40:50:60}'::macaddr[]")
|
||||
self.assertEqual(a, ['10:20:30:40:50:60'])
|
||||
|
||||
def testIntEnum(self):
|
||||
from enum import IntEnum
|
||||
|
||||
class Color(IntEnum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
BLUE = 4
|
||||
|
||||
a = self.execute("select %s", (Color.GREEN,))
|
||||
self.assertEqual(a, Color.GREEN)
|
||||
|
||||
|
||||
class AdaptSubclassTest(unittest.TestCase):
|
||||
def test_adapt_subtype(self):
|
||||
class Sub(str):
|
||||
pass
|
||||
s1 = "hel'lo"
|
||||
s2 = Sub(s1)
|
||||
self.assertEqual(adapt(s1).getquoted(), adapt(s2).getquoted())
|
||||
|
||||
@restore_types
|
||||
def test_adapt_most_specific(self):
|
||||
class A:
|
||||
pass
|
||||
|
||||
class B(A):
|
||||
pass
|
||||
|
||||
class C(B):
|
||||
pass
|
||||
|
||||
register_adapter(A, lambda a: AsIs("a"))
|
||||
register_adapter(B, lambda b: AsIs("b"))
|
||||
self.assertEqual(b'b', adapt(C()).getquoted())
|
||||
|
||||
@restore_types
|
||||
def test_adapt_subtype_3(self):
|
||||
class A:
|
||||
pass
|
||||
|
||||
class B(A):
|
||||
pass
|
||||
|
||||
register_adapter(A, lambda a: AsIs("a"))
|
||||
self.assertEqual(b"a", adapt(B()).getquoted())
|
||||
|
||||
def test_conform_subclass_precedence(self):
|
||||
class foo(tuple):
|
||||
def __conform__(self, proto):
|
||||
return self
|
||||
|
||||
def getquoted(self):
|
||||
return 'bar'
|
||||
|
||||
self.assertEqual(adapt(foo((1, 2, 3))).getquoted(), 'bar')
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
platform.system() == 'Windows',
|
||||
"Not testing because we are useless with ctypes on Windows")
|
||||
class ByteaParserTest(unittest.TestCase):
|
||||
"""Unit test for our bytea format parser."""
|
||||
def setUp(self):
|
||||
self._cast = self._import_cast()
|
||||
|
||||
def _import_cast(self):
|
||||
"""Use ctypes to access the C function."""
|
||||
lib = ctypes.pydll.LoadLibrary(psycopg2._psycopg.__file__)
|
||||
cast = lib.typecast_BINARY_cast
|
||||
cast.argtypes = [ctypes.c_char_p, ctypes.c_size_t, ctypes.py_object]
|
||||
cast.restype = ctypes.py_object
|
||||
return cast
|
||||
|
||||
def cast(self, buffer):
|
||||
"""Cast a buffer from the output format"""
|
||||
l = buffer and len(buffer) or 0
|
||||
rv = self._cast(buffer, l, None)
|
||||
|
||||
if rv is None:
|
||||
return None
|
||||
|
||||
return rv.tobytes()

    def test_null(self):
        rv = self.cast(None)
        self.assertEqual(rv, None)

    def test_blank(self):
        rv = self.cast(b'')
        self.assertEqual(rv, b'')

    def test_blank_hex(self):
        # Reported as problematic in ticket #48
        rv = self.cast(b'\\x')
        self.assertEqual(rv, b'')

    def test_full_hex(self, upper=False):
        buf = ''.join(("%02x" % i) for i in range(256))
        if upper:
            buf = buf.upper()
        buf = '\\x' + buf
        rv = self.cast(buf.encode('utf8'))
        self.assertEqual(rv, bytes(range(256)))

    def test_full_hex_upper(self):
        return self.test_full_hex(upper=True)

    def test_full_escaped_octal(self):
        buf = ''.join(("\\%03o" % i) for i in range(256))
        rv = self.cast(buf.encode('utf8'))
        self.assertEqual(rv, bytes(range(256)))

    def test_escaped_mixed(self):
        buf = ''.join(("\\%03o" % i) for i in range(32))
        buf += string.ascii_letters
        buf += ''.join('\\' + c for c in string.ascii_letters)
        buf += '\\\\'
        rv = self.cast(buf.encode('utf8'))
        tgt = bytes(range(32)) + \
            (string.ascii_letters * 2 + '\\').encode('ascii')

        self.assertEqual(rv, tgt)


def test_suite():
    return unittest.TestLoader().loadTestsFromName(__name__)


if __name__ == "__main__":
    unittest.main()
1597
tests/test_types_extras.py
Executable file
File diff suppressed because it is too large
319
tests/test_with.py
Executable file
@ -0,0 +1,319 @@
#!/usr/bin/env python

# test_ctxman.py - unit test for connection and cursor used as context manager
#
# Copyright (C) 2012-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.

import psycopg2
import psycopg2.extensions as ext

import unittest
from .testutils import ConnectingTestCase, skip_before_postgres, skip_if_crdb


class WithTestCase(ConnectingTestCase):
    def setUp(self):
        ConnectingTestCase.setUp(self)
        curs = self.conn.cursor()
        try:
            curs.execute("delete from test_with")
            self.conn.commit()
        except psycopg2.ProgrammingError:
            # assume table doesn't exist
            self.conn.rollback()
            curs.execute("create table test_with (id integer primary key)")
            self.conn.commit()


class WithConnectionTestCase(WithTestCase):
    def test_with_ok(self):
        with self.conn as conn:
            self.assert_(self.conn is conn)
            self.assertEqual(conn.status, ext.STATUS_READY)
            curs = conn.cursor()
            curs.execute("insert into test_with values (1)")
            self.assertEqual(conn.status, ext.STATUS_BEGIN)

        self.assertEqual(self.conn.status, ext.STATUS_READY)
        self.assert_(not self.conn.closed)

        curs = self.conn.cursor()
        curs.execute("select * from test_with")
        self.assertEqual(curs.fetchall(), [(1,)])

    def test_with_connect_idiom(self):
        with self.connect() as conn:
            self.assertEqual(conn.status, ext.STATUS_READY)
            curs = conn.cursor()
            curs.execute("insert into test_with values (2)")
            self.assertEqual(conn.status, ext.STATUS_BEGIN)

        self.assertEqual(self.conn.status, ext.STATUS_READY)
        self.assert_(not self.conn.closed)

        curs = self.conn.cursor()
        curs.execute("select * from test_with")
        self.assertEqual(curs.fetchall(), [(2,)])

    def test_with_error_db(self):
        def f():
            with self.conn as conn:
                curs = conn.cursor()
                curs.execute("insert into test_with values ('a')")

        self.assertRaises(psycopg2.DataError, f)
        self.assertEqual(self.conn.status, ext.STATUS_READY)
        self.assert_(not self.conn.closed)

        curs = self.conn.cursor()
        curs.execute("select * from test_with")
        self.assertEqual(curs.fetchall(), [])

    def test_with_error_python(self):
        def f():
            with self.conn as conn:
                curs = conn.cursor()
                curs.execute("insert into test_with values (3)")
                1 / 0

        self.assertRaises(ZeroDivisionError, f)
        self.assertEqual(self.conn.status, ext.STATUS_READY)
        self.assert_(not self.conn.closed)

        curs = self.conn.cursor()
        curs.execute("select * from test_with")
        self.assertEqual(curs.fetchall(), [])

    def test_with_closed(self):
        def f():
            with self.conn:
                pass

        self.conn.close()
        self.assertRaises(psycopg2.InterfaceError, f)

    def test_subclass_commit(self):
        commits = []

        class MyConn(ext.connection):
            def commit(self):
                commits.append(None)
                super().commit()

        with self.connect(connection_factory=MyConn) as conn:
            curs = conn.cursor()
            curs.execute("insert into test_with values (10)")

        self.assertEqual(conn.status, ext.STATUS_READY)
        self.assert_(commits)

        curs = self.conn.cursor()
        curs.execute("select * from test_with")
        self.assertEqual(curs.fetchall(), [(10,)])

    def test_subclass_rollback(self):
        rollbacks = []

        class MyConn(ext.connection):
            def rollback(self):
                rollbacks.append(None)
                super().rollback()

        try:
            with self.connect(connection_factory=MyConn) as conn:
                curs = conn.cursor()
                curs.execute("insert into test_with values (11)")
                1 / 0
        except ZeroDivisionError:
            pass
        else:
            self.fail("exception not raised")

        self.assertEqual(conn.status, ext.STATUS_READY)
        self.assert_(rollbacks)

        curs = conn.cursor()
        curs.execute("select * from test_with")
        self.assertEqual(curs.fetchall(), [])

    def test_cant_reenter(self):
        raised_ok = False
        with self.conn:
            try:
                with self.conn:
                    pass
            except psycopg2.ProgrammingError:
                raised_ok = True

        self.assert_(raised_ok)

        # Still good
        with self.conn:
            pass

    def test_with_autocommit(self):
        self.conn.autocommit = True
        self.assertEqual(
            self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE
        )
        with self.conn:
            curs = self.conn.cursor()
            curs.execute("insert into test_with values (1)")
            self.assertEqual(
                self.conn.info.transaction_status,
                ext.TRANSACTION_STATUS_INTRANS,
            )

        self.assertEqual(
            self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE
        )
        curs.execute("select count(*) from test_with")
        self.assertEqual(curs.fetchone()[0], 1)
        self.assertEqual(
            self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE
        )

    def test_with_autocommit_pyerror(self):
        self.conn.autocommit = True
        raised_ok = False
        try:
            with self.conn:
                curs = self.conn.cursor()
                curs.execute("insert into test_with values (2)")
                self.assertEqual(
                    self.conn.info.transaction_status,
                    ext.TRANSACTION_STATUS_INTRANS,
                )
                1 / 0
        except ZeroDivisionError:
            raised_ok = True

        self.assert_(raised_ok)
        self.assertEqual(
            self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE
        )
        curs.execute("select count(*) from test_with")
        self.assertEqual(curs.fetchone()[0], 0)
        self.assertEqual(
            self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE
        )

    def test_with_autocommit_pgerror(self):
        self.conn.autocommit = True
        raised_ok = False
        try:
            with self.conn:
                curs = self.conn.cursor()
                curs.execute("insert into test_with values (2)")
                self.assertEqual(
                    self.conn.info.transaction_status,
                    ext.TRANSACTION_STATUS_INTRANS,
                )
                curs.execute("insert into test_with values ('x')")
        except psycopg2.errors.InvalidTextRepresentation:
            raised_ok = True

        self.assert_(raised_ok)
        self.assertEqual(
            self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE
        )
        curs.execute("select count(*) from test_with")
        self.assertEqual(curs.fetchone()[0], 0)
        self.assertEqual(
            self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE
        )


class WithCursorTestCase(WithTestCase):
    def test_with_ok(self):
        with self.conn as conn:
            with conn.cursor() as curs:
                curs.execute("insert into test_with values (4)")
                self.assert_(not curs.closed)
                self.assertEqual(self.conn.status, ext.STATUS_BEGIN)
            self.assert_(curs.closed)

        self.assertEqual(self.conn.status, ext.STATUS_READY)
        self.assert_(not self.conn.closed)

        curs = self.conn.cursor()
        curs.execute("select * from test_with")
        self.assertEqual(curs.fetchall(), [(4,)])

    def test_with_error(self):
        try:
            with self.conn as conn:
                with conn.cursor() as curs:
                    curs.execute("insert into test_with values (5)")
                    1 / 0
        except ZeroDivisionError:
            pass

        self.assertEqual(self.conn.status, ext.STATUS_READY)
        self.assert_(not self.conn.closed)
        self.assert_(curs.closed)

        curs = self.conn.cursor()
        curs.execute("select * from test_with")
        self.assertEqual(curs.fetchall(), [])

    def test_subclass(self):
        closes = []

        class MyCurs(ext.cursor):
            def close(self):
                closes.append(None)
                super().close()

        with self.conn.cursor(cursor_factory=MyCurs) as curs:
            self.assert_(isinstance(curs, MyCurs))

        self.assert_(curs.closed)
        self.assert_(closes)

    @skip_if_crdb("named cursor")
    def test_exception_swallow(self):
        # bug #262: __exit__ calls cur.close() that hides the exception
        # with another error.
        try:
            with self.conn as conn:
                with conn.cursor('named') as cur:
                    cur.execute("select 1/0")
                    cur.fetchone()
        except psycopg2.DataError as e:
            self.assertEqual(e.pgcode, '22012')
        else:
            self.fail("where is my exception?")

    @skip_if_crdb("named cursor")
    @skip_before_postgres(8, 2)
    def test_named_with_noop(self):
        with self.conn.cursor('named'):
            pass


def test_suite():
    return unittest.TestLoader().loadTestsFromName(__name__)


if __name__ == "__main__":
    unittest.main()
42
tests/testconfig.py
Normal file
@ -0,0 +1,42 @@
# Configure the test suite from the env variables.

import os

dbname = os.environ.get('PSYCOPG2_TESTDB', 'psycopg2_test')
dbhost = os.environ.get('PSYCOPG2_TESTDB_HOST', os.environ.get('PGHOST'))
dbport = os.environ.get('PSYCOPG2_TESTDB_PORT', os.environ.get('PGPORT'))
dbuser = os.environ.get('PSYCOPG2_TESTDB_USER', os.environ.get('PGUSER'))
dbpass = os.environ.get('PSYCOPG2_TESTDB_PASSWORD', os.environ.get('PGPASSWORD'))

# Check if we want to test psycopg's green path.
green = os.environ.get('PSYCOPG2_TEST_GREEN', None)
if green:
    if green == '1':
        from psycopg2.extras import wait_select as wait_callback
    elif green == 'eventlet':
        from eventlet.support.psycopg2_patcher import eventlet_wait_callback \
            as wait_callback
    else:
        raise ValueError("please set 'PSYCOPG2_TEST_GREEN' to a valid value")

    import psycopg2.extensions
    psycopg2.extensions.set_wait_callback(wait_callback)
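
# For example, running the suite as "PSYCOPG2_TEST_GREEN=1 make check" would
# install psycopg2.extras.wait_select as the wait callback, while
# "PSYCOPG2_TEST_GREEN=eventlet make check" would use the eventlet-based one
# (the invocations are illustrative; any way of running the suite works).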

# Construct a DSN to connect to the test database:
dsn = f'dbname={dbname}'
if dbhost is not None:
    dsn += f' host={dbhost}'
if dbport is not None:
    dsn += f' port={dbport}'
if dbuser is not None:
    dsn += f' user={dbuser}'
if dbpass is not None:
    dsn += f' password={dbpass}'
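
# As a rough illustration (the values below are hypothetical), with
# PSYCOPG2_TESTDB_HOST=localhost and PSYCOPG2_TESTDB_USER=postgres the block
# above would yield:
#
#   dsn = 'dbname=psycopg2_test host=localhost user=postgres'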

# Don't run replication tests if REPL_DSN is not set, default to normal DSN if
# set to empty string.
repl_dsn = os.environ.get('PSYCOPG2_TEST_REPL_DSN', None)
if repl_dsn == '':
    repl_dsn = dsn

repl_slot = os.environ.get('PSYCOPG2_TEST_REPL_SLOT', 'psycopg2_test_slot')
544
tests/testutils.py
Normal file
@ -0,0 +1,544 @@
# testutils.py - utility module for psycopg2 testing.

#
# Copyright (C) 2010-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.


import re
import os
import sys
import types
import ctypes
import select
import operator
import platform
import unittest
from functools import wraps
from ctypes.util import find_library
from io import StringIO  # noqa
from io import TextIOBase  # noqa
from importlib import reload  # noqa

import psycopg2
import psycopg2.errors
import psycopg2.extensions

from .testconfig import green, dsn, repl_dsn


# Silence warnings caused by the stubbornness of the Python unittest
# maintainers
# https://bugs.python.org/issue9424
if (not hasattr(unittest.TestCase, 'assert_')
        or unittest.TestCase.assert_ is not unittest.TestCase.assertTrue):
    # mavaff...
    unittest.TestCase.assert_ = unittest.TestCase.assertTrue
    unittest.TestCase.failUnless = unittest.TestCase.assertTrue
    unittest.TestCase.assertEquals = unittest.TestCase.assertEqual
    unittest.TestCase.failUnlessEqual = unittest.TestCase.assertEqual


def assertDsnEqual(self, dsn1, dsn2, msg=None):
    """Check that two conninfo strings have the same content"""
    self.assertEqual(set(dsn1.split()), set(dsn2.split()), msg)


unittest.TestCase.assertDsnEqual = assertDsnEqual


class ConnectingTestCase(unittest.TestCase):
    """A test case providing connections for tests.

    A connection for the test is always available as `self.conn`. Others can be
    created with `self.connect()`. All are closed on tearDown.

    Subclasses needing to customize setUp and tearDown should remember to call
    the base class implementations.
    """
    def setUp(self):
        self._conns = []

    def tearDown(self):
        # close the connections used in the test
        for conn in self._conns:
            if not conn.closed:
                conn.close()

    def assertQuotedEqual(self, first, second, msg=None):
        """Compare two quoted strings disregarding eventual E'' quotes"""
        def f(s):
            if isinstance(s, str):
                return re.sub(r"\bE'", "'", s)
            elif isinstance(s, bytes):
                return re.sub(br"\bE'", b"'", s)
            else:
                return s

        return self.assertEqual(f(first), f(second), msg)

    def connect(self, **kwargs):
        try:
            self._conns
        except AttributeError as e:
            raise AttributeError(
                f"{e} (did you forget to call ConnectingTestCase.setUp()?)")

        if 'dsn' in kwargs:
            conninfo = kwargs.pop('dsn')
        else:
            conninfo = dsn
        conn = psycopg2.connect(conninfo, **kwargs)
        self._conns.append(conn)
        return conn

    def repl_connect(self, **kwargs):
        """Return a connection set up for replication

        The connection is on "PSYCOPG2_TEST_REPL_DSN" unless overridden by
        a *dsn* kwarg.

        Should raise a skip test if not available, but guard for None on
        old Python versions.
        """
        if repl_dsn is None:
            return self.skipTest("replication tests disabled by default")

        if 'dsn' not in kwargs:
            kwargs['dsn'] = repl_dsn
        try:
            conn = self.connect(**kwargs)
            if conn.async_ == 1:
                self.wait(conn)
        except psycopg2.OperationalError as e:
            # If pgcode is not set it is a genuine connection error
            # Otherwise we tried to run some bad operation in the connection
            # (e.g. bug #482) and we'd rather know that.
            if e.pgcode is None:
                return self.skipTest(f"replication db not configured: {e}")
            else:
                raise

        return conn

    def _get_conn(self):
        if not hasattr(self, '_the_conn'):
            self._the_conn = self.connect()

        return self._the_conn

    def _set_conn(self, conn):
        self._the_conn = conn

    conn = property(_get_conn, _set_conn)

    # for use with async connections only
    def wait(self, cur_or_conn):
        pollable = cur_or_conn
        if not hasattr(pollable, 'poll'):
            pollable = cur_or_conn.connection
        while True:
            state = pollable.poll()
            if state == psycopg2.extensions.POLL_OK:
                break
            elif state == psycopg2.extensions.POLL_READ:
                select.select([pollable], [], [], 1)
            elif state == psycopg2.extensions.POLL_WRITE:
                select.select([], [pollable], [], 1)
            else:
                raise Exception("Unexpected result from poll: %r" % state)
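
    # A minimal sketch of how wait() is meant to be used with an
    # asynchronous connection (illustrative only):
    #
    #   aconn = self.connect(async_=1)
    #   self.wait(aconn)          # poll until the connection is ready
    #   cur = aconn.cursor()
    #   cur.execute("select 1")
    #   self.wait(cur)            # poll until the query has completed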

    _libpq = None

    @property
    def libpq(self):
        """Return a ctypes wrapper for the libpq library"""
        if ConnectingTestCase._libpq is not None:
            return ConnectingTestCase._libpq

        libname = find_library('pq')
        if libname is None and platform.system() == 'Windows':
            raise self.skipTest("can't import libpq on windows")

        try:
            rv = ConnectingTestCase._libpq = ctypes.pydll.LoadLibrary(libname)
        except OSError as e:
            raise self.skipTest("couldn't open libpq for testing: %s" % e)
        return rv


def decorate_all_tests(obj, *decorators):
    """
    Apply all the *decorators* to all the tests defined in the TestCase *obj*.

    The decorator can also be applied to a decorator: if *obj* is a function,
    return a new decorator which can be applied either to a method or to a
    class, in which case it will decorate all the tests.
    """
    if isinstance(obj, types.FunctionType):
        def decorator(func_or_cls):
            if isinstance(func_or_cls, types.FunctionType):
                return obj(func_or_cls)
            else:
                decorate_all_tests(func_or_cls, obj)
                return func_or_cls

        return decorator

    for n in dir(obj):
        if n.startswith('test'):
            for d in decorators:
                setattr(obj, n, d(getattr(obj, n)))
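
# A sketch of the two ways decorate_all_tests() can be used (the names below
# are illustrative only):
#
#   # 1. applied directly, decorating every test* method of a TestCase:
#   decorate_all_tests(SomeTestCase, skip_if_green("not supported"))
#
#   # 2. applied to a decorator, making it usable both on a single test
#   #    method and on a whole TestCase class:
#   @decorate_all_tests
#   def skip_if_something(f):
#       ...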


@decorate_all_tests
def skip_if_no_uuid(f):
    """Decorator to skip a test if uuid is not supported by PG."""
    @wraps(f)
    def skip_if_no_uuid_(self):
        try:
            cur = self.conn.cursor()
            cur.execute("select typname from pg_type where typname = 'uuid'")
            has = cur.fetchone()
        finally:
            self.conn.rollback()

        if has:
            return f(self)
        else:
            return self.skipTest("uuid type not available on the server")

    return skip_if_no_uuid_


@decorate_all_tests
def skip_if_tpc_disabled(f):
    """Skip a test if the server has tpc support disabled."""
    @wraps(f)
    def skip_if_tpc_disabled_(self):
        cnn = self.connect()
        skip_if_crdb("2-phase commit", cnn)

        cur = cnn.cursor()
        try:
            cur.execute("SHOW max_prepared_transactions;")
        except psycopg2.ProgrammingError:
            return self.skipTest(
                "server too old: two phase transactions not supported.")
        else:
            mtp = int(cur.fetchone()[0])
        cnn.close()

        if not mtp:
            return self.skipTest(
                "server not configured for two phase transactions. "
                "set max_prepared_transactions to > 0 to run the test")
        return f(self)

    return skip_if_tpc_disabled_


def skip_before_postgres(*ver):
    """Skip a test on PostgreSQL before a certain version."""
    reason = None
    if isinstance(ver[-1], str):
        ver, reason = ver[:-1], ver[-1]

    ver = ver + (0,) * (3 - len(ver))

    @decorate_all_tests
    def skip_before_postgres_(f):
        @wraps(f)
        def skip_before_postgres__(self):
            if self.conn.info.server_version < int("%d%02d%02d" % ver):
                return self.skipTest(
                    reason or "skipped because PostgreSQL %s"
                    % self.conn.info.server_version)
            else:
                return f(self)

        return skip_before_postgres__
    return skip_before_postgres_
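
# For instance, mirroring its use in test_with.py above, the sketch below
# skips a test when conn.info.server_version reports anything below 80200:
#
#   @skip_before_postgres(8, 2)
#   def test_named_with_noop(self):
#       ...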


def skip_after_postgres(*ver):
    """Skip a test on PostgreSQL after (including) a certain version."""
    ver = ver + (0,) * (3 - len(ver))

    @decorate_all_tests
    def skip_after_postgres_(f):
        @wraps(f)
        def skip_after_postgres__(self):
            if self.conn.info.server_version >= int("%d%02d%02d" % ver):
                return self.skipTest("skipped because PostgreSQL %s"
                                     % self.conn.info.server_version)
            else:
                return f(self)

        return skip_after_postgres__
    return skip_after_postgres_


def libpq_version():
    v = psycopg2.__libpq_version__
    if v >= 90100:
        v = min(v, psycopg2.extensions.libpq_version())
    return v


def skip_before_libpq(*ver):
    """Skip a test if libpq we're linked to is older than a certain version."""
    ver = ver + (0,) * (3 - len(ver))

    def skip_before_libpq_(cls):
        v = libpq_version()
        decorator = unittest.skipIf(
            v < int("%d%02d%02d" % ver),
            f"skipped because libpq {v}",
        )
        return decorator(cls)
    return skip_before_libpq_


def skip_after_libpq(*ver):
    """Skip a test if libpq we're linked to is newer than a certain version."""
    ver = ver + (0,) * (3 - len(ver))

    def skip_after_libpq_(cls):
        v = libpq_version()
        decorator = unittest.skipIf(
            v >= int("%d%02d%02d" % ver),
            f"skipped because libpq {v}",
        )
        return decorator(cls)
    return skip_after_libpq_


def skip_before_python(*ver):
    """Skip a test on Python before a certain version."""
    def skip_before_python_(cls):
        decorator = unittest.skipIf(
            sys.version_info[:len(ver)] < ver,
            f"skipped because Python {'.'.join(map(str, sys.version_info[:len(ver)]))}",
        )
        return decorator(cls)
    return skip_before_python_


def skip_from_python(*ver):
    """Skip a test on Python after (including) a certain version."""
    def skip_from_python_(cls):
        decorator = unittest.skipIf(
            sys.version_info[:len(ver)] >= ver,
            f"skipped because Python {'.'.join(map(str, sys.version_info[:len(ver)]))}",
        )
        return decorator(cls)
    return skip_from_python_


@decorate_all_tests
def skip_if_no_superuser(f):
    """Skip a test if the database user running the test is not a superuser"""
    @wraps(f)
    def skip_if_no_superuser_(self):
        try:
            return f(self)
        except psycopg2.errors.InsufficientPrivilege:
            self.skipTest("skipped because not superuser")

    return skip_if_no_superuser_


def skip_if_green(reason):
    def skip_if_green_(cls):
        decorator = unittest.skipIf(green, reason)
        return decorator(cls)
    return skip_if_green_


skip_copy_if_green = skip_if_green("copy in async mode currently not supported")


def skip_if_no_getrefcount(cls):
    decorator = unittest.skipUnless(
        hasattr(sys, 'getrefcount'),
        'no sys.getrefcount()',
    )
    return decorator(cls)


def skip_if_windows(cls):
    """Skip a test if run on windows"""
    decorator = unittest.skipIf(
        platform.system() == 'Windows',
        "Not supported on Windows",
    )
    return decorator(cls)


def crdb_version(conn, __crdb_version=[]):
    """
    Return the CockroachDB version if that's the db being tested, else None.

    Return the number as an integer similar to PQserverVersion: return
    v20.1.3 as 200103.

    Assume all the connections are on the same db: return a cached result on
    following calls.

    """
    if __crdb_version:
        return __crdb_version[0]

    sver = conn.info.parameter_status("crdb_version")
    if sver is None:
        __crdb_version.append(None)
    else:
        m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
        if not m:
            raise ValueError(
                f"can't parse CockroachDB version from {sver}")

        ver = int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))
        __crdb_version.append(ver)

    return __crdb_version[0]


def skip_if_crdb(reason, conn=None, version=None):
    """Skip a test or test class if we are testing against CockroachDB.

    Can be used as a decorator for tests function or classes:

        @skip_if_crdb("my reason")
        class SomeUnitTest(UnitTest):
            # ...

    Or as a normal function if the *conn* argument is passed.

    If *version* is specified it should be a string such as ">= 20.1", "< 20",
    "== 20.1.3": the test will be skipped only if the version matches.

    """
    if not isinstance(reason, str):
        raise TypeError(f"reason should be a string, got {reason!r} instead")

    if conn is not None:
        ver = crdb_version(conn)
        if ver is not None and _crdb_match_version(ver, version):
            if reason in crdb_reasons:
                reason = (
                    "%s (https://github.com/cockroachdb/cockroach/issues/%s)"
                    % (reason, crdb_reasons[reason]))
            raise unittest.SkipTest(
                f"not supported on CockroachDB {ver}: {reason}")

    @decorate_all_tests
    def skip_if_crdb_(f):
        @wraps(f)
        def skip_if_crdb__(self, *args, **kwargs):
            skip_if_crdb(reason, conn=self.connect(), version=version)
            return f(self, *args, **kwargs)

        return skip_if_crdb__

    return skip_if_crdb_


# mapping from reason description to ticket number
crdb_reasons = {
    "2-phase commit": 22329,
    "backend pid": 35897,
    "cancel": 41335,
    "cast adds tz": 51692,
    "cidr": 18846,
    "composite": 27792,
    "copy": 41608,
    "deferrable": 48307,
    "encoding": 35882,
    "hstore": 41284,
    "infinity date": 41564,
    "interval style": 35807,
    "large objects": 243,
    "named cursor": 41412,
    "nested array": 32552,
    "notify": 41522,
    "password_encryption": 42519,
    "range": 41282,
    "stored procedure": 1751,
}


def _crdb_match_version(version, pattern):
    if pattern is None:
        return True

    m = re.match(r'^(>|>=|<|<=|==|!=)\s*(\d+)(?:\.(\d+))?(?:\.(\d+))?$', pattern)
    if m is None:
        raise ValueError(
            "bad crdb version pattern %r: should be 'OP MAJOR[.MINOR[.BUGFIX]]'"
            % pattern)

    ops = {'>': 'gt', '>=': 'ge', '<': 'lt', '<=': 'le', '==': 'eq', '!=': 'ne'}
    op = getattr(operator, ops[m.group(1)])
    ref = int(m.group(2)) * 10000 + int(m.group(3) or 0) * 100 + int(m.group(4) or 0)
    return op(version, ref)
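
# An illustrative combination of the helpers above (the class name is
# hypothetical): skip a whole test class on CockroachDB, but only for
# versions matching a pattern understood by _crdb_match_version():
#
#   @skip_if_crdb("copy", version="< 20.1")
#   class CopyTests(ConnectingTestCase):
#       ...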


class raises_typeerror:
    def __enter__(self):
        pass

    def __exit__(self, type, exc, tb):
        assert type is TypeError
        return True


def slow(f):
    """Decorator to mark slow tests we may want to skip

    Note: in order to find slow tests you can run:

    make check 2>&1 | ts -i "%.s" | sort -n
    """
    @wraps(f)
    def slow_(self):
        if os.environ.get('PSYCOPG2_TEST_FAST', '0') != '0':
            return self.skipTest("slow test")
        return f(self)
    return slow_


def restore_types(f):
    """Decorator to restore the adaptation system after running a test"""
    @wraps(f)
    def restore_types_(self):
        types = psycopg2.extensions.string_types.copy()
        adapters = psycopg2.extensions.adapters.copy()
        try:
            return f(self)
        finally:
            psycopg2.extensions.string_types.clear()
            psycopg2.extensions.string_types.update(types)
            psycopg2.extensions.adapters.clear()
            psycopg2.extensions.adapters.update(adapters)

    return restore_types_
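
# A typical use of restore_types(), sketched after
# AdaptSubclassTest.test_adapt_most_specific shown earlier in this commit:
# tests that call register_adapter() wrap themselves so the global adapter
# registry is restored afterwards:
#
#   @restore_types
#   def test_adapt_most_specific(self):
#       register_adapter(A, lambda a: AsIs("a"))
#       ...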