first commit based on psycopg2 2.9 version

This commit is contained in:
lishifu_db
2021-07-05 21:34:17 +08:00
parent d126b6ec53
commit 3553ed0e30
178 changed files with 51253 additions and 0 deletions

104
tests/__init__.py Executable file
View File

@ -0,0 +1,104 @@
#!/usr/bin/env python
# psycopg2 test suite
#
# Copyright (C) 2007-2019 Federico Di Gregorio <fog@debian.org>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
# Convert warnings into errors here. We can't do it with -W because on
# Travis importing site raises a warning.
import warnings
warnings.simplefilter('error') # noqa
import sys
from .testconfig import dsn
import unittest
from . import test_async
from . import test_bugX000
from . import test_bug_gc
from . import test_cancel
from . import test_connection
from . import test_copy
from . import test_cursor
from . import test_dates
from . import test_errcodes
from . import test_errors
from . import test_extras_dictcursor
from . import test_fast_executemany
from . import test_green
from . import test_ipaddress
from . import test_lobject
from . import test_module
from . import test_notify
from . import test_psycopg2_dbapi20
from . import test_quote
from . import test_replication
from . import test_sql
from . import test_transaction
from . import test_types_basic
from . import test_types_extras
from . import test_with
def test_suite():
# If connection to test db fails, bail out early.
import psycopg2
try:
cnn = psycopg2.connect(dsn)
except Exception as e:
print("Failed connection to test db:", e.__class__.__name__, e)
print("Please set env vars 'PSYCOPG2_TESTDB*' to valid values.")
sys.exit(1)
else:
cnn.close()
suite = unittest.TestSuite()
suite.addTest(test_async.test_suite())
suite.addTest(test_bugX000.test_suite())
suite.addTest(test_bug_gc.test_suite())
suite.addTest(test_cancel.test_suite())
suite.addTest(test_connection.test_suite())
suite.addTest(test_copy.test_suite())
suite.addTest(test_cursor.test_suite())
suite.addTest(test_dates.test_suite())
suite.addTest(test_errcodes.test_suite())
suite.addTest(test_errors.test_suite())
suite.addTest(test_extras_dictcursor.test_suite())
suite.addTest(test_fast_executemany.test_suite())
suite.addTest(test_green.test_suite())
suite.addTest(test_ipaddress.test_suite())
suite.addTest(test_lobject.test_suite())
suite.addTest(test_module.test_suite())
suite.addTest(test_notify.test_suite())
suite.addTest(test_psycopg2_dbapi20.test_suite())
suite.addTest(test_quote.test_suite())
suite.addTest(test_replication.test_suite())
suite.addTest(test_sql.test_suite())
suite.addTest(test_transaction.test_suite())
suite.addTest(test_types_basic.test_suite())
suite.addTest(test_types_extras.test_suite())
suite.addTest(test_with.test_suite())
return suite
if __name__ == '__main__':
unittest.main(defaultTest='test_suite')

862
tests/dbapi20.py Normal file
View File

@ -0,0 +1,862 @@
#!/usr/bin/env python
''' Python DB API 2.0 driver compliance unit test suite.
This software is Public Domain and may be used without restrictions.
"Now we have booze and barflies entering the discussion, plus rumours of
DBAs on drugs... and I won't tell you what flashes through my mind each
time I read the subject line with 'Anal Compliance' in it. All around
this is turning out to be a thoroughly unwholesome unit test."
-- Ian Bicking
'''
__rcs_id__ = '$Id: dbapi20.py,v 1.11 2005/01/02 02:41:01 zenzen Exp $'
__version__ = '$Revision: 1.12 $'[11:-2]
__author__ = 'Stuart Bishop <stuart@stuartbishop.net>'
import unittest
import time
import sys
# Revision 1.12 2009/02/06 03:35:11 kf7xm
# Tested okay with Python 3.0, includes last minute patches from Mark H.
#
# Revision 1.1.1.1.2.1 2008/09/20 19:54:59 rupole
# Include latest changes from main branch
# Updates for py3k
#
# Revision 1.11 2005/01/02 02:41:01 zenzen
# Update author email address
#
# Revision 1.10 2003/10/09 03:14:14 zenzen
# Add test for DB API 2.0 optional extension, where database exceptions
# are exposed as attributes on the Connection object.
#
# Revision 1.9 2003/08/13 01:16:36 zenzen
# Minor tweak from Stefan Fleiter
#
# Revision 1.8 2003/04/10 00:13:25 zenzen
# Changes, as per suggestions by M.-A. Lemburg
# - Add a table prefix, to ensure namespace collisions can always be avoided
#
# Revision 1.7 2003/02/26 23:33:37 zenzen
# Break out DDL into helper functions, as per request by David Rushby
#
# Revision 1.6 2003/02/21 03:04:33 zenzen
# Stuff from Henrik Ekelund:
# added test_None
# added test_nextset & hooks
#
# Revision 1.5 2003/02/17 22:08:43 zenzen
# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize
# defaults to 1 & generic cursor.callproc test added
#
# Revision 1.4 2003/02/15 00:16:33 zenzen
# Changes, as per suggestions and bug reports by M.-A. Lemburg,
# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar
# - Class renamed
# - Now a subclass of TestCase, to avoid requiring the driver stub
# to use multiple inheritance
# - Reversed the polarity of buggy test in test_description
# - Test exception hierarchy correctly
# - self.populate is now self._populate(), so if a driver stub
# overrides self.ddl1 this change propagates
# - VARCHAR columns now have a width, which will hopefully make the
# DDL even more portable (this will be reversed if it causes more problems)
# - cursor.rowcount being checked after various execute and fetchXXX methods
# - Check for fetchall and fetchmany returning empty lists after results
# are exhausted (already checking for empty lists if select retrieved
# nothing
# - Fix bugs in test_setoutputsize_basic and test_setinputsizes
#
class DatabaseAPI20Test(unittest.TestCase):
''' Test a database self.driver for DB API 2.0 compatibility.
This implementation tests Gadfly, but the TestCase
is structured so that other self.drivers can subclass this
test case to ensure compliance with the DB-API. It is
expected that this TestCase may be expanded in the future
if ambiguities or edge conditions are discovered.
The 'Optional Extensions' are not yet being tested.
self.drivers should subclass this test, overriding setUp, tearDown,
self.driver, connect_args and connect_kw_args. Class specification
should be as follows:
from . import dbapi20
class mytest(dbapi20.DatabaseAPI20Test):
[...]
Don't 'from .dbapi20 import DatabaseAPI20Test', or you will
confuse the unit tester - just 'from . import dbapi20'.
'''
# The self.driver module. This should be the module where the 'connect'
# method is to be found
driver = None
connect_args = () # List of arguments to pass to connect
connect_kw_args = {} # Keyword arguments for connect
table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables
ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix
ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix
xddl1 = 'drop table %sbooze' % table_prefix
xddl2 = 'drop table %sbarflys' % table_prefix
lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase
# Some drivers may need to override these helpers, for example adding
# a 'commit' after the execute.
def executeDDL1(self,cursor):
cursor.execute(self.ddl1)
def executeDDL2(self,cursor):
cursor.execute(self.ddl2)
def setUp(self):
''' self.drivers should override this method to perform required setup
if any is necessary, such as creating the database.
'''
pass
def tearDown(self):
''' self.drivers should override this method to perform required cleanup
if any is necessary, such as deleting the test database.
The default drops the tables that may be created.
'''
con = self._connect()
try:
cur = con.cursor()
for ddl in (self.xddl1,self.xddl2):
try:
cur.execute(ddl)
con.commit()
except self.driver.Error:
# Assume table didn't exist. Other tests will check if
# execute is busted.
pass
finally:
con.close()
def _connect(self):
try:
return self.driver.connect(
*self.connect_args,**self.connect_kw_args
)
except AttributeError:
self.fail("No connect method found in self.driver module")
def test_connect(self):
con = self._connect()
con.close()
def test_apilevel(self):
try:
# Must exist
apilevel = self.driver.apilevel
# Must equal 2.0
self.assertEqual(apilevel,'2.0')
except AttributeError:
self.fail("Driver doesn't define apilevel")
def test_threadsafety(self):
try:
# Must exist
threadsafety = self.driver.threadsafety
# Must be a valid value
self.failUnless(threadsafety in (0,1,2,3))
except AttributeError:
self.fail("Driver doesn't define threadsafety")
def test_paramstyle(self):
try:
# Must exist
paramstyle = self.driver.paramstyle
# Must be a valid value
self.failUnless(paramstyle in (
'qmark','numeric','named','format','pyformat'
))
except AttributeError:
self.fail("Driver doesn't define paramstyle")
def test_Exceptions(self):
# Make sure required exceptions exist, and are in the
# defined hierarchy.
self.failUnless(issubclass(self.driver.Warning,Exception))
self.failUnless(issubclass(self.driver.Error,Exception))
self.failUnless(
issubclass(self.driver.InterfaceError,self.driver.Error)
)
self.failUnless(
issubclass(self.driver.DatabaseError,self.driver.Error)
)
self.failUnless(
issubclass(self.driver.OperationalError,self.driver.Error)
)
self.failUnless(
issubclass(self.driver.IntegrityError,self.driver.Error)
)
self.failUnless(
issubclass(self.driver.InternalError,self.driver.Error)
)
self.failUnless(
issubclass(self.driver.ProgrammingError,self.driver.Error)
)
self.failUnless(
issubclass(self.driver.NotSupportedError,self.driver.Error)
)
def test_ExceptionsAsConnectionAttributes(self):
# OPTIONAL EXTENSION
# Test for the optional DB API 2.0 extension, where the exceptions
# are exposed as attributes on the Connection object
# I figure this optional extension will be implemented by any
# driver author who is using this test suite, so it is enabled
# by default.
con = self._connect()
drv = self.driver
self.failUnless(con.Warning is drv.Warning)
self.failUnless(con.Error is drv.Error)
self.failUnless(con.InterfaceError is drv.InterfaceError)
self.failUnless(con.DatabaseError is drv.DatabaseError)
self.failUnless(con.OperationalError is drv.OperationalError)
self.failUnless(con.IntegrityError is drv.IntegrityError)
self.failUnless(con.InternalError is drv.InternalError)
self.failUnless(con.ProgrammingError is drv.ProgrammingError)
self.failUnless(con.NotSupportedError is drv.NotSupportedError)
def test_commit(self):
con = self._connect()
try:
# Commit must work, even if it doesn't do anything
con.commit()
finally:
con.close()
def test_rollback(self):
con = self._connect()
# If rollback is defined, it should either work or throw
# the documented exception
if hasattr(con,'rollback'):
try:
con.rollback()
except self.driver.NotSupportedError:
pass
def test_cursor(self):
con = self._connect()
try:
cur = con.cursor()
finally:
con.close()
def test_cursor_isolation(self):
con = self._connect()
try:
# Make sure cursors created from the same connection have
# the documented transaction isolation level
cur1 = con.cursor()
cur2 = con.cursor()
self.executeDDL1(cur1)
cur1.execute("insert into %sbooze values ('Victoria Bitter')" % (
self.table_prefix
))
cur2.execute("select name from %sbooze" % self.table_prefix)
booze = cur2.fetchall()
self.assertEqual(len(booze),1)
self.assertEqual(len(booze[0]),1)
self.assertEqual(booze[0][0],'Victoria Bitter')
finally:
con.close()
def test_description(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
self.assertEqual(cur.description,None,
'cursor.description should be none after executing a '
'statement that can return no rows (such as DDL)'
)
cur.execute('select name from %sbooze' % self.table_prefix)
self.assertEqual(len(cur.description),1,
'cursor.description describes too many columns'
)
self.assertEqual(len(cur.description[0]),7,
'cursor.description[x] tuples must have 7 elements'
)
self.assertEqual(cur.description[0][0].lower(),'name',
'cursor.description[x][0] must return column name'
)
self.assertEqual(cur.description[0][1],self.driver.STRING,
'cursor.description[x][1] must return column type. Got %r'
% cur.description[0][1]
)
# Make sure self.description gets reset
self.executeDDL2(cur)
self.assertEqual(cur.description,None,
'cursor.description not being set to None when executing '
'no-result statements (eg. DDL)'
)
finally:
con.close()
def test_rowcount(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
self.assertEqual(cur.rowcount,-1,
'cursor.rowcount should be -1 after executing no-result '
'statements'
)
cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
self.table_prefix
))
self.failUnless(cur.rowcount in (-1,1),
'cursor.rowcount should == number or rows inserted, or '
'set to -1 after executing an insert statement'
)
cur.execute("select name from %sbooze" % self.table_prefix)
self.failUnless(cur.rowcount in (-1,1),
'cursor.rowcount should == number of rows returned, or '
'set to -1 after executing a select statement'
)
self.executeDDL2(cur)
self.assertEqual(cur.rowcount,-1,
'cursor.rowcount not being reset to -1 after executing '
'no-result statements'
)
finally:
con.close()
lower_func = 'lower'
def test_callproc(self):
con = self._connect()
try:
cur = con.cursor()
if self.lower_func and hasattr(cur,'callproc'):
r = cur.callproc(self.lower_func,('FOO',))
self.assertEqual(len(r),1)
self.assertEqual(r[0],'FOO')
r = cur.fetchall()
self.assertEqual(len(r),1,'callproc produced no result set')
self.assertEqual(len(r[0]),1,
'callproc produced invalid result set'
)
self.assertEqual(r[0][0],'foo',
'callproc produced invalid results'
)
finally:
con.close()
def test_close(self):
con = self._connect()
try:
cur = con.cursor()
finally:
con.close()
# cursor.execute should raise an Error if called after connection
# closed
self.assertRaises(self.driver.Error,self.executeDDL1,cur)
# connection.commit should raise an Error if called after connection'
# closed.'
self.assertRaises(self.driver.Error,con.commit)
# connection.close should raise an Error if called more than once
# Issue discussed on DB-SIG: consensus seem that close() should not
# raised if called on closed objects. Issue reported back to Stuart.
# self.assertRaises(self.driver.Error,con.close)
def test_execute(self):
con = self._connect()
try:
cur = con.cursor()
self._paraminsert(cur)
finally:
con.close()
def _paraminsert(self,cur):
self.executeDDL1(cur)
cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
self.table_prefix
))
self.failUnless(cur.rowcount in (-1,1))
if self.driver.paramstyle == 'qmark':
cur.execute(
'insert into %sbooze values (?)' % self.table_prefix,
("Cooper's",)
)
elif self.driver.paramstyle == 'numeric':
cur.execute(
'insert into %sbooze values (:1)' % self.table_prefix,
("Cooper's",)
)
elif self.driver.paramstyle == 'named':
cur.execute(
'insert into %sbooze values (:beer)' % self.table_prefix,
{'beer':"Cooper's"}
)
elif self.driver.paramstyle == 'format':
cur.execute(
'insert into %sbooze values (%%s)' % self.table_prefix,
("Cooper's",)
)
elif self.driver.paramstyle == 'pyformat':
cur.execute(
'insert into %sbooze values (%%(beer)s)' % self.table_prefix,
{'beer':"Cooper's"}
)
else:
self.fail('Invalid paramstyle')
self.failUnless(cur.rowcount in (-1,1))
cur.execute('select name from %sbooze' % self.table_prefix)
res = cur.fetchall()
self.assertEqual(len(res),2,'cursor.fetchall returned too few rows')
beers = [res[0][0],res[1][0]]
beers.sort()
self.assertEqual(beers[0],"Cooper's",
'cursor.fetchall retrieved incorrect data, or data inserted '
'incorrectly'
)
self.assertEqual(beers[1],"Victoria Bitter",
'cursor.fetchall retrieved incorrect data, or data inserted '
'incorrectly'
)
def test_executemany(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
largs = [ ("Cooper's",) , ("Boag's",) ]
margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ]
if self.driver.paramstyle == 'qmark':
cur.executemany(
'insert into %sbooze values (?)' % self.table_prefix,
largs
)
elif self.driver.paramstyle == 'numeric':
cur.executemany(
'insert into %sbooze values (:1)' % self.table_prefix,
largs
)
elif self.driver.paramstyle == 'named':
cur.executemany(
'insert into %sbooze values (:beer)' % self.table_prefix,
margs
)
elif self.driver.paramstyle == 'format':
cur.executemany(
'insert into %sbooze values (%%s)' % self.table_prefix,
largs
)
elif self.driver.paramstyle == 'pyformat':
cur.executemany(
'insert into %sbooze values (%%(beer)s)' % (
self.table_prefix
),
margs
)
else:
self.fail('Unknown paramstyle')
self.failUnless(cur.rowcount in (-1,2),
'insert using cursor.executemany set cursor.rowcount to '
'incorrect value %r' % cur.rowcount
)
cur.execute('select name from %sbooze' % self.table_prefix)
res = cur.fetchall()
self.assertEqual(len(res),2,
'cursor.fetchall retrieved incorrect number of rows'
)
beers = [res[0][0],res[1][0]]
beers.sort()
self.assertEqual(beers[0],"Boag's",'incorrect data retrieved')
self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved')
finally:
con.close()
def test_fetchone(self):
con = self._connect()
try:
cur = con.cursor()
# cursor.fetchone should raise an Error if called before
# executing a select-type query
self.assertRaises(self.driver.Error,cur.fetchone)
# cursor.fetchone should raise an Error if called after
# executing a query that cannot return rows
self.executeDDL1(cur)
self.assertRaises(self.driver.Error,cur.fetchone)
cur.execute('select name from %sbooze' % self.table_prefix)
self.assertEqual(cur.fetchone(),None,
'cursor.fetchone should return None if a query retrieves '
'no rows'
)
self.failUnless(cur.rowcount in (-1,0))
# cursor.fetchone should raise an Error if called after
# executing a query that cannot return rows
cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
self.table_prefix
))
self.assertRaises(self.driver.Error,cur.fetchone)
cur.execute('select name from %sbooze' % self.table_prefix)
r = cur.fetchone()
self.assertEqual(len(r),1,
'cursor.fetchone should have retrieved a single row'
)
self.assertEqual(r[0],'Victoria Bitter',
'cursor.fetchone retrieved incorrect data'
)
self.assertEqual(cur.fetchone(),None,
'cursor.fetchone should return None if no more rows available'
)
self.failUnless(cur.rowcount in (-1,1))
finally:
con.close()
samples = [
'Carlton Cold',
'Carlton Draft',
'Mountain Goat',
'Redback',
'Victoria Bitter',
'XXXX'
]
def _populate(self):
''' Return a list of sql commands to setup the DB for the fetch
tests.
'''
populate = [
f"insert into {self.table_prefix}booze values ('{s}')"
for s in self.samples
]
return populate
def test_fetchmany(self):
con = self._connect()
try:
cur = con.cursor()
# cursor.fetchmany should raise an Error if called without
#issuing a query
self.assertRaises(self.driver.Error,cur.fetchmany,4)
self.executeDDL1(cur)
for sql in self._populate():
cur.execute(sql)
cur.execute('select name from %sbooze' % self.table_prefix)
r = cur.fetchmany()
self.assertEqual(len(r),1,
'cursor.fetchmany retrieved incorrect number of rows, '
'default of arraysize is one.'
)
cur.arraysize=10
r = cur.fetchmany(3) # Should get 3 rows
self.assertEqual(len(r),3,
'cursor.fetchmany retrieved incorrect number of rows'
)
r = cur.fetchmany(4) # Should get 2 more
self.assertEqual(len(r),2,
'cursor.fetchmany retrieved incorrect number of rows'
)
r = cur.fetchmany(4) # Should be an empty sequence
self.assertEqual(len(r),0,
'cursor.fetchmany should return an empty sequence after '
'results are exhausted'
)
self.failUnless(cur.rowcount in (-1,6))
# Same as above, using cursor.arraysize
cur.arraysize=4
cur.execute('select name from %sbooze' % self.table_prefix)
r = cur.fetchmany() # Should get 4 rows
self.assertEqual(len(r),4,
'cursor.arraysize not being honoured by fetchmany'
)
r = cur.fetchmany() # Should get 2 more
self.assertEqual(len(r),2)
r = cur.fetchmany() # Should be an empty sequence
self.assertEqual(len(r),0)
self.failUnless(cur.rowcount in (-1,6))
cur.arraysize=6
cur.execute('select name from %sbooze' % self.table_prefix)
rows = cur.fetchmany() # Should get all rows
self.failUnless(cur.rowcount in (-1,6))
self.assertEqual(len(rows),6)
self.assertEqual(len(rows),6)
rows = [r[0] for r in rows]
rows.sort()
# Make sure we get the right data back out
for i in range(0,6):
self.assertEqual(rows[i],self.samples[i],
'incorrect data retrieved by cursor.fetchmany'
)
rows = cur.fetchmany() # Should return an empty list
self.assertEqual(len(rows),0,
'cursor.fetchmany should return an empty sequence if '
'called after the whole result set has been fetched'
)
self.failUnless(cur.rowcount in (-1,6))
self.executeDDL2(cur)
cur.execute('select name from %sbarflys' % self.table_prefix)
r = cur.fetchmany() # Should get empty sequence
self.assertEqual(len(r),0,
'cursor.fetchmany should return an empty sequence if '
'query retrieved no rows'
)
self.failUnless(cur.rowcount in (-1,0))
finally:
con.close()
def test_fetchall(self):
con = self._connect()
try:
cur = con.cursor()
# cursor.fetchall should raise an Error if called
# without executing a query that may return rows (such
# as a select)
self.assertRaises(self.driver.Error, cur.fetchall)
self.executeDDL1(cur)
for sql in self._populate():
cur.execute(sql)
# cursor.fetchall should raise an Error if called
# after executing a a statement that cannot return rows
self.assertRaises(self.driver.Error,cur.fetchall)
cur.execute('select name from %sbooze' % self.table_prefix)
rows = cur.fetchall()
self.failUnless(cur.rowcount in (-1,len(self.samples)))
self.assertEqual(len(rows),len(self.samples),
'cursor.fetchall did not retrieve all rows'
)
rows = [r[0] for r in rows]
rows.sort()
for i in range(0,len(self.samples)):
self.assertEqual(rows[i],self.samples[i],
'cursor.fetchall retrieved incorrect rows'
)
rows = cur.fetchall()
self.assertEqual(
len(rows),0,
'cursor.fetchall should return an empty list if called '
'after the whole result set has been fetched'
)
self.failUnless(cur.rowcount in (-1,len(self.samples)))
self.executeDDL2(cur)
cur.execute('select name from %sbarflys' % self.table_prefix)
rows = cur.fetchall()
self.failUnless(cur.rowcount in (-1,0))
self.assertEqual(len(rows),0,
'cursor.fetchall should return an empty list if '
'a select query returns no rows'
)
finally:
con.close()
def test_mixedfetch(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
for sql in self._populate():
cur.execute(sql)
cur.execute('select name from %sbooze' % self.table_prefix)
rows1 = cur.fetchone()
rows23 = cur.fetchmany(2)
rows4 = cur.fetchone()
rows56 = cur.fetchall()
self.failUnless(cur.rowcount in (-1,6))
self.assertEqual(len(rows23),2,
'fetchmany returned incorrect number of rows'
)
self.assertEqual(len(rows56),2,
'fetchall returned incorrect number of rows'
)
rows = [rows1[0]]
rows.extend([rows23[0][0],rows23[1][0]])
rows.append(rows4[0])
rows.extend([rows56[0][0],rows56[1][0]])
rows.sort()
for i in range(0,len(self.samples)):
self.assertEqual(rows[i],self.samples[i],
'incorrect data retrieved or inserted'
)
finally:
con.close()
def help_nextset_setUp(self,cur):
''' Should create a procedure called deleteme
that returns two result sets, first the
number of rows in booze then "name from booze"
'''
raise NotImplementedError('Helper not implemented')
#sql="""
# create procedure deleteme as
# begin
# select count(*) from booze
# select name from booze
# end
#"""
#cur.execute(sql)
def help_nextset_tearDown(self,cur):
'If cleaning up is needed after nextSetTest'
raise NotImplementedError('Helper not implemented')
#cur.execute("drop procedure deleteme")
def test_nextset(self):
con = self._connect()
try:
cur = con.cursor()
if not hasattr(cur,'nextset'):
return
try:
self.executeDDL1(cur)
sql=self._populate()
for sql in self._populate():
cur.execute(sql)
self.help_nextset_setUp(cur)
cur.callproc('deleteme')
numberofrows=cur.fetchone()
assert numberofrows[0]== len(self.samples)
assert cur.nextset()
names=cur.fetchall()
assert len(names) == len(self.samples)
s=cur.nextset()
assert s is None, 'No more return sets, should return None'
finally:
self.help_nextset_tearDown(cur)
finally:
con.close()
def test_nextset(self):
raise NotImplementedError('Drivers need to override this test')
def test_arraysize(self):
# Not much here - rest of the tests for this are in test_fetchmany
con = self._connect()
try:
cur = con.cursor()
self.failUnless(hasattr(cur,'arraysize'),
'cursor.arraysize must be defined'
)
finally:
con.close()
def test_setinputsizes(self):
con = self._connect()
try:
cur = con.cursor()
cur.setinputsizes( (25,) )
self._paraminsert(cur) # Make sure cursor still works
finally:
con.close()
def test_setoutputsize_basic(self):
# Basic test is to make sure setoutputsize doesn't blow up
con = self._connect()
try:
cur = con.cursor()
cur.setoutputsize(1000)
cur.setoutputsize(2000,0)
self._paraminsert(cur) # Make sure the cursor still works
finally:
con.close()
def test_setoutputsize(self):
# Real test for setoutputsize is driver dependent
raise NotImplementedError('Driver needed to override this test')
def test_None(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
cur.execute('insert into %sbooze values (NULL)' % self.table_prefix)
cur.execute('select name from %sbooze' % self.table_prefix)
r = cur.fetchall()
self.assertEqual(len(r),1)
self.assertEqual(len(r[0]),1)
self.assertEqual(r[0][0],None,'NULL value not returned as None')
finally:
con.close()
def test_Date(self):
d1 = self.driver.Date(2002,12,25)
d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0)))
# Can we assume this? API doesn't specify, but it seems implied
# self.assertEqual(str(d1),str(d2))
def test_Time(self):
t1 = self.driver.Time(13,45,30)
t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0)))
# Can we assume this? API doesn't specify, but it seems implied
# self.assertEqual(str(t1),str(t2))
def test_Timestamp(self):
t1 = self.driver.Timestamp(2002,12,25,13,45,30)
t2 = self.driver.TimestampFromTicks(
time.mktime((2002,12,25,13,45,30,0,0,0))
)
# Can we assume this? API doesn't specify, but it seems implied
# self.assertEqual(str(t1),str(t2))
def test_Binary(self):
b = self.driver.Binary(b'Something')
b = self.driver.Binary(b'')
def test_STRING(self):
self.failUnless(hasattr(self.driver,'STRING'),
'module.STRING must be defined'
)
def test_BINARY(self):
self.failUnless(hasattr(self.driver,'BINARY'),
'module.BINARY must be defined.'
)
def test_NUMBER(self):
self.failUnless(hasattr(self.driver,'NUMBER'),
'module.NUMBER must be defined.'
)
def test_DATETIME(self):
self.failUnless(hasattr(self.driver,'DATETIME'),
'module.DATETIME must be defined.'
)
def test_ROWID(self):
self.failUnless(hasattr(self.driver,'ROWID'),
'module.ROWID must be defined.'
)

144
tests/dbapi20_tpc.py Normal file
View File

@ -0,0 +1,144 @@
""" Python DB API 2.0 driver Two Phase Commit compliance test suite.
"""
import unittest
class TwoPhaseCommitTests(unittest.TestCase):
driver = None
def connect(self):
"""Make a database connection."""
raise NotImplementedError
_last_id = 0
_global_id_prefix = "dbapi20_tpc:"
def make_xid(self, con):
id = TwoPhaseCommitTests._last_id
TwoPhaseCommitTests._last_id += 1
return con.xid(42, f"{self._global_id_prefix}{id}", "qualifier")
def test_xid(self):
con = self.connect()
try:
xid = con.xid(42, "global", "bqual")
except self.driver.NotSupportedError:
self.fail("Driver does not support transaction IDs.")
self.assertEquals(xid[0], 42)
self.assertEquals(xid[1], "global")
self.assertEquals(xid[2], "bqual")
# Try some extremes for the transaction ID:
xid = con.xid(0, "", "")
self.assertEquals(tuple(xid), (0, "", ""))
xid = con.xid(0x7fffffff, "a" * 64, "b" * 64)
self.assertEquals(tuple(xid), (0x7fffffff, "a" * 64, "b" * 64))
def test_tpc_begin(self):
con = self.connect()
try:
xid = self.make_xid(con)
try:
con.tpc_begin(xid)
except self.driver.NotSupportedError:
self.fail("Driver does not support tpc_begin()")
finally:
con.close()
def test_tpc_commit_without_prepare(self):
con = self.connect()
try:
xid = self.make_xid(con)
con.tpc_begin(xid)
cursor = con.cursor()
cursor.execute("SELECT 1")
con.tpc_commit()
finally:
con.close()
def test_tpc_rollback_without_prepare(self):
con = self.connect()
try:
xid = self.make_xid(con)
con.tpc_begin(xid)
cursor = con.cursor()
cursor.execute("SELECT 1")
con.tpc_rollback()
finally:
con.close()
def test_tpc_commit_with_prepare(self):
con = self.connect()
try:
xid = self.make_xid(con)
con.tpc_begin(xid)
cursor = con.cursor()
cursor.execute("SELECT 1")
con.tpc_prepare()
con.tpc_commit()
finally:
con.close()
def test_tpc_rollback_with_prepare(self):
con = self.connect()
try:
xid = self.make_xid(con)
con.tpc_begin(xid)
cursor = con.cursor()
cursor.execute("SELECT 1")
con.tpc_prepare()
con.tpc_rollback()
finally:
con.close()
def test_tpc_begin_in_transaction_fails(self):
con = self.connect()
try:
xid = self.make_xid(con)
cursor = con.cursor()
cursor.execute("SELECT 1")
self.assertRaises(self.driver.ProgrammingError,
con.tpc_begin, xid)
finally:
con.close()
def test_tpc_begin_in_tpc_transaction_fails(self):
con = self.connect()
try:
xid = self.make_xid(con)
cursor = con.cursor()
cursor.execute("SELECT 1")
self.assertRaises(self.driver.ProgrammingError,
con.tpc_begin, xid)
finally:
con.close()
def test_commit_in_tpc_fails(self):
# calling commit() within a TPC transaction fails with
# ProgrammingError.
con = self.connect()
try:
xid = self.make_xid(con)
con.tpc_begin(xid)
self.assertRaises(self.driver.ProgrammingError, con.commit)
finally:
con.close()
def test_rollback_in_tpc_fails(self):
# calling rollback() within a TPC transaction fails with
# ProgrammingError.
con = self.connect()
try:
xid = self.make_xid(con)
con.tpc_begin(xid)
self.assertRaises(self.driver.ProgrammingError, con.rollback)
finally:
con.close()

546
tests/test_async.py Executable file
View File

@ -0,0 +1,546 @@
#!/usr/bin/env python
# test_async.py - unit test for asynchronous API
#
# Copyright (C) 2010-2019 Jan Urbański <wulczer@wulczer.org>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import gc
import time
import unittest
import warnings
import psycopg2
import psycopg2.errors
from psycopg2 import extensions as ext
from .testutils import (ConnectingTestCase, StringIO, skip_before_postgres,
skip_if_crdb, crdb_version, slow)
class PollableStub:
"""A 'pollable' wrapper allowing analysis of the `poll()` calls."""
def __init__(self, pollable):
self.pollable = pollable
self.polls = []
def fileno(self):
return self.pollable.fileno()
def poll(self):
rv = self.pollable.poll()
self.polls.append(rv)
return rv
class AsyncTests(ConnectingTestCase):
def setUp(self):
ConnectingTestCase.setUp(self)
self.sync_conn = self.conn
self.conn = self.connect(async_=True)
self.wait(self.conn)
curs = self.conn.cursor()
if crdb_version(self.sync_conn) is not None:
curs.execute("set experimental_enable_temp_tables = 'on'")
self.wait(curs)
curs.execute('''
CREATE TEMPORARY TABLE table1 (
id int PRIMARY KEY
)''')
self.wait(curs)
def test_connection_setup(self):
cur = self.conn.cursor()
sync_cur = self.sync_conn.cursor()
del cur, sync_cur
self.assert_(self.conn.async_)
self.assert_(not self.sync_conn.async_)
# the async connection should be autocommit
self.assert_(self.conn.autocommit)
self.assertEquals(self.conn.isolation_level, ext.ISOLATION_LEVEL_DEFAULT)
# check other properties to be found on the connection
self.assert_(self.conn.server_version)
self.assert_(self.conn.protocol_version in (2, 3))
self.assert_(self.conn.encoding in ext.encodings)
def test_async_named_cursor(self):
self.assertRaises(psycopg2.ProgrammingError,
self.conn.cursor, "name")
def test_async_select(self):
cur = self.conn.cursor()
self.assertFalse(self.conn.isexecuting())
cur.execute("select 'a'")
self.assertTrue(self.conn.isexecuting())
self.wait(cur)
self.assertFalse(self.conn.isexecuting())
self.assertEquals(cur.fetchone()[0], "a")
@slow
@skip_before_postgres(8, 2)
def test_async_callproc(self):
cur = self.conn.cursor()
cur.callproc("pg_sleep", (0.1, ))
self.assertTrue(self.conn.isexecuting())
self.wait(cur)
self.assertFalse(self.conn.isexecuting())
@slow
def test_async_after_async(self):
cur = self.conn.cursor()
cur2 = self.conn.cursor()
del cur2
cur.execute("insert into table1 values (1)")
# an async execute after an async one raises an exception
self.assertRaises(psycopg2.ProgrammingError,
cur.execute, "select * from table1")
# same for callproc
self.assertRaises(psycopg2.ProgrammingError,
cur.callproc, "version")
# but after you've waited it should be good
self.wait(cur)
cur.execute("select * from table1")
self.wait(cur)
self.assertEquals(cur.fetchall()[0][0], 1)
cur.execute("delete from table1")
self.wait(cur)
cur.execute("select * from table1")
self.wait(cur)
self.assertEquals(cur.fetchone(), None)
def test_fetch_after_async(self):
cur = self.conn.cursor()
cur.execute("select 'a'")
# a fetch after an asynchronous query should raise an error
self.assertRaises(psycopg2.ProgrammingError,
cur.fetchall)
# but after waiting it should work
self.wait(cur)
self.assertEquals(cur.fetchall()[0][0], "a")
def test_rollback_while_async(self):
cur = self.conn.cursor()
cur.execute("select 'a'")
# a rollback should not work in asynchronous mode
self.assertRaises(psycopg2.ProgrammingError, self.conn.rollback)
def test_commit_while_async(self):
cur = self.conn.cursor()
cur.execute("begin")
self.wait(cur)
cur.execute("insert into table1 values (1)")
# a commit should not work in asynchronous mode
self.assertRaises(psycopg2.ProgrammingError, self.conn.commit)
self.assertTrue(self.conn.isexecuting())
# but a manual commit should
self.wait(cur)
cur.execute("commit")
self.wait(cur)
cur.execute("select * from table1")
self.wait(cur)
self.assertEquals(cur.fetchall()[0][0], 1)
cur.execute("delete from table1")
self.wait(cur)
cur.execute("select * from table1")
self.wait(cur)
self.assertEquals(cur.fetchone(), None)
def test_set_parameters_while_async(self):
cur = self.conn.cursor()
cur.execute("select 'c'")
self.assertTrue(self.conn.isexecuting())
# getting transaction status works
self.assertEquals(self.conn.info.transaction_status,
ext.TRANSACTION_STATUS_ACTIVE)
self.assertTrue(self.conn.isexecuting())
# setting connection encoding should fail
self.assertRaises(psycopg2.ProgrammingError,
self.conn.set_client_encoding, "LATIN1")
# same for transaction isolation
self.assertRaises(psycopg2.ProgrammingError,
self.conn.set_isolation_level, 1)
def test_reset_while_async(self):
cur = self.conn.cursor()
cur.execute("select 'c'")
self.assertTrue(self.conn.isexecuting())
# a reset should fail
self.assertRaises(psycopg2.ProgrammingError, self.conn.reset)
def test_async_iter(self):
cur = self.conn.cursor()
cur.execute("begin")
self.wait(cur)
cur.execute("""
insert into table1 values (1);
insert into table1 values (2);
insert into table1 values (3);
""")
self.wait(cur)
cur.execute("select id from table1 order by id")
# iteration fails if a query is underway
self.assertRaises(psycopg2.ProgrammingError, list, cur)
# but after it's done it should work
self.wait(cur)
self.assertEquals(list(cur), [(1, ), (2, ), (3, )])
self.assertFalse(self.conn.isexecuting())
def test_copy_while_async(self):
cur = self.conn.cursor()
cur.execute("select 'a'")
# copy should fail
self.assertRaises(psycopg2.ProgrammingError,
cur.copy_from,
StringIO("1\n3\n5\n\\.\n"), "table1")
def test_lobject_while_async(self):
# large objects should be prohibited
self.assertRaises(psycopg2.ProgrammingError,
self.conn.lobject)
def test_async_executemany(self):
cur = self.conn.cursor()
self.assertRaises(
psycopg2.ProgrammingError,
cur.executemany, "insert into table1 values (%s)", [1, 2, 3])
def test_async_scroll(self):
cur = self.conn.cursor()
cur.execute("""
insert into table1 values (1);
insert into table1 values (2);
insert into table1 values (3);
""")
self.wait(cur)
cur.execute("select id from table1 order by id")
# scroll should fail if a query is underway
self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 1)
self.assertTrue(self.conn.isexecuting())
# but after it's done it should work
self.wait(cur)
cur.scroll(1)
self.assertEquals(cur.fetchall(), [(2, ), (3, )])
cur = self.conn.cursor()
cur.execute("select id from table1 order by id")
self.wait(cur)
cur2 = self.conn.cursor()
self.assertRaises(psycopg2.ProgrammingError, cur2.scroll, 1)
self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 4)
cur = self.conn.cursor()
cur.execute("select id from table1 order by id")
self.wait(cur)
cur.scroll(2)
cur.scroll(-1)
self.assertEquals(cur.fetchall(), [(2, ), (3, )])
def test_scroll(self):
cur = self.sync_conn.cursor()
cur.execute("create table table1 (id int)")
cur.execute("""
insert into table1 values (1);
insert into table1 values (2);
insert into table1 values (3);
""")
cur.execute("select id from table1 order by id")
cur.scroll(2)
cur.scroll(-1)
self.assertEquals(cur.fetchall(), [(2, ), (3, )])
def test_async_dont_read_all(self):
cur = self.conn.cursor()
cur.execute("select repeat('a', 10000); select repeat('b', 10000)")
# fetch the result
self.wait(cur)
# it should be the result of the second query
self.assertEquals(cur.fetchone()[0], "b" * 10000)
def test_async_subclass(self):
class MyConn(ext.connection):
def __init__(self, dsn, async_=0):
ext.connection.__init__(self, dsn, async_=async_)
conn = self.connect(connection_factory=MyConn, async_=True)
self.assert_(isinstance(conn, MyConn))
self.assert_(conn.async_)
conn.close()
@slow
@skip_if_crdb("flush on write flakey")
def test_flush_on_write(self):
# a very large query requires a flush loop to be sent to the backend
curs = self.conn.cursor()
for mb in 1, 5, 10, 20, 50:
size = mb * 1024 * 1024
stub = PollableStub(self.conn)
curs.execute("select %s;", ('x' * size,))
self.wait(stub)
self.assertEqual(size, len(curs.fetchone()[0]))
if stub.polls.count(ext.POLL_WRITE) > 1:
return
# This is more a testing glitch than an error: it happens
# on high load on linux: probably because the kernel has more
# buffers ready. A warning may be useful during development,
# but an error is bad during regression testing.
warnings.warn("sending a large query didn't trigger block on write.")
def test_sync_poll(self):
cur = self.sync_conn.cursor()
cur.execute("select 1")
# polling with a sync query works
cur.connection.poll()
self.assertEquals(cur.fetchone()[0], 1)
@slow
@skip_if_crdb("notify")
def test_notify(self):
cur = self.conn.cursor()
sync_cur = self.sync_conn.cursor()
sync_cur.execute("listen test_notify")
self.sync_conn.commit()
cur.execute("notify test_notify")
self.wait(cur)
self.assertEquals(self.sync_conn.notifies, [])
pid = self.conn.info.backend_pid
for _ in range(5):
self.wait(self.sync_conn)
if not self.sync_conn.notifies:
time.sleep(0.5)
continue
self.assertEquals(len(self.sync_conn.notifies), 1)
self.assertEquals(self.sync_conn.notifies.pop(),
(pid, "test_notify"))
return
self.fail("No NOTIFY in 2.5 seconds")
def test_async_fetch_wrong_cursor(self):
cur1 = self.conn.cursor()
cur2 = self.conn.cursor()
cur1.execute("select 1")
self.wait(cur1)
self.assertFalse(self.conn.isexecuting())
# fetching from a cursor with no results is an error
self.assertRaises(psycopg2.ProgrammingError, cur2.fetchone)
# fetching from the correct cursor works
self.assertEquals(cur1.fetchone()[0], 1)
def test_error(self):
cur = self.conn.cursor()
cur.execute("insert into table1 values (%s)", (1, ))
self.wait(cur)
cur.execute("insert into table1 values (%s)", (1, ))
# this should fail
self.assertRaises(psycopg2.IntegrityError, self.wait, cur)
cur.execute("insert into table1 values (%s); "
"insert into table1 values (%s)", (2, 2))
# this should fail as well (Postgres behaviour)
self.assertRaises(psycopg2.IntegrityError, self.wait, cur)
# but this should work
if crdb_version(self.sync_conn) is None:
cur.execute("insert into table1 values (%s)", (2, ))
self.wait(cur)
# and the cursor should be usable afterwards
cur.execute("insert into table1 values (%s)", (3, ))
self.wait(cur)
cur.execute("select * from table1 order by id")
self.wait(cur)
self.assertEquals(cur.fetchall(), [(1, ), (2, ), (3, )])
cur.execute("delete from table1")
self.wait(cur)
def test_stop_on_first_error(self):
cur = self.conn.cursor()
cur.execute("select 1; select x; select 1/0; select 2")
self.assertRaises(psycopg2.errors.UndefinedColumn, self.wait, cur)
cur.execute("select 1")
self.wait(cur)
self.assertEqual(cur.fetchone(), (1,))
def test_error_two_cursors(self):
cur = self.conn.cursor()
cur2 = self.conn.cursor()
cur.execute("select * from no_such_table")
self.assertRaises(psycopg2.ProgrammingError, self.wait, cur)
cur2.execute("select 1")
self.wait(cur2)
self.assertEquals(cur2.fetchone()[0], 1)
@skip_if_crdb("notice")
def test_notices(self):
del self.conn.notices[:]
cur = self.conn.cursor()
if self.conn.info.server_version >= 90300:
cur.execute("set client_min_messages=debug1")
self.wait(cur)
cur.execute("create temp table chatty (id serial primary key);")
self.wait(cur)
self.assertEqual("CREATE TABLE", cur.statusmessage)
self.assert_(self.conn.notices)
def test_async_cursor_gone(self):
cur = self.conn.cursor()
cur.execute("select 42;")
del cur
gc.collect()
self.assertRaises(psycopg2.InterfaceError, self.wait, self.conn)
# The connection is still usable
cur = self.conn.cursor()
cur.execute("select 42;")
self.wait(self.conn)
self.assertEqual(cur.fetchone(), (42,))
@skip_if_crdb("copy")
def test_async_connection_error_message(self):
try:
cnn = psycopg2.connect('dbname=thisdatabasedoesntexist', async_=True)
self.wait(cnn)
except psycopg2.Error as e:
self.assertNotEqual(str(e), "asynchronous connection failed",
"connection error reason lost")
else:
self.fail("no exception raised")
@skip_before_postgres(8, 2)
def test_copy_no_hang(self):
cur = self.conn.cursor()
cur.execute("copy (select 1) to stdout")
self.assertRaises(psycopg2.ProgrammingError, self.wait, self.conn)
@slow
@skip_if_crdb("notice")
@skip_before_postgres(9, 0)
def test_non_block_after_notification(self):
from select import select
cur = self.conn.cursor()
cur.execute("""
select 1;
do $$
begin
raise notice 'hello';
end
$$ language plpgsql;
select pg_sleep(1);
""")
polls = 0
while True:
state = self.conn.poll()
if state == psycopg2.extensions.POLL_OK:
break
elif state == psycopg2.extensions.POLL_READ:
select([self.conn], [], [], 0.1)
elif state == psycopg2.extensions.POLL_WRITE:
select([], [self.conn], [], 0.1)
else:
raise Exception("Unexpected result from poll: %r", state)
polls += 1
self.assert_(polls >= 8, polls)
def test_poll_noop(self):
self.conn.poll()
@skip_if_crdb("notify")
@skip_before_postgres(9, 0)
def test_poll_conn_for_notification(self):
with self.conn.cursor() as cur:
cur.execute("listen test")
self.wait(cur)
with self.sync_conn.cursor() as cur:
cur.execute("notify test, 'hello'")
self.sync_conn.commit()
for i in range(10):
self.conn.poll()
if self.conn.notifies:
n = self.conn.notifies.pop()
self.assertEqual(n.channel, 'test')
self.assertEqual(n.payload, 'hello')
break
time.sleep(0.1)
else:
self.fail("No notification received")
def test_close(self):
self.conn.close()
self.assertTrue(self.conn.closed)
self.assertTrue(self.conn.async_)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

48
tests/test_bugX000.py Executable file
View File

@ -0,0 +1,48 @@
#!/usr/bin/env python
# bugX000.py - test for DateTime object allocation bug
#
# Copyright (C) 2007-2019 Federico Di Gregorio <fog@debian.org>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import psycopg2
import time
import unittest
class DateTimeAllocationBugTestCase(unittest.TestCase):
def test_date_time_allocation_bug(self):
d1 = psycopg2.Date(2002, 12, 25)
d2 = psycopg2.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0)))
t1 = psycopg2.Time(13, 45, 30)
t2 = psycopg2.TimeFromTicks(time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0)))
t1 = psycopg2.Timestamp(2002, 12, 25, 13, 45, 30)
t2 = psycopg2.TimestampFromTicks(
time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)))
del d1, d2, t1, t2
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

52
tests/test_bug_gc.py Executable file
View File

@ -0,0 +1,52 @@
#!/usr/bin/env python
# bug_gc.py - test for refcounting/GC bug
#
# Copyright (C) 2010-2019 Federico Di Gregorio <fog@debian.org>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import psycopg2
import psycopg2.extensions
import unittest
import gc
from .testutils import ConnectingTestCase, skip_if_no_uuid
class StolenReferenceTestCase(ConnectingTestCase):
@skip_if_no_uuid
def test_stolen_reference_bug(self):
def fish(val, cur):
gc.collect()
return 42
UUID = psycopg2.extensions.new_type((2950,), "UUID", fish)
psycopg2.extensions.register_type(UUID, self.conn)
curs = self.conn.cursor()
curs.execute("select 'b5219e01-19ab-4994-b71e-149225dc51e4'::uuid")
curs.fetchone()
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

117
tests/test_cancel.py Executable file
View File

@ -0,0 +1,117 @@
#!/usr/bin/env python
# test_cancel.py - unit test for query cancellation
#
# Copyright (C) 2010-2019 Jan Urbański <wulczer@wulczer.org>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import time
import threading
import psycopg2
import psycopg2.extensions
from psycopg2 import extras
from .testconfig import dsn
import unittest
from .testutils import ConnectingTestCase, skip_before_postgres, slow
from .testutils import skip_if_crdb
class CancelTests(ConnectingTestCase):
def setUp(self):
ConnectingTestCase.setUp(self)
skip_if_crdb("cancel", self.conn)
cur = self.conn.cursor()
cur.execute('''
CREATE TEMPORARY TABLE table1 (
id int PRIMARY KEY
)''')
self.conn.commit()
def test_empty_cancel(self):
self.conn.cancel()
@slow
@skip_before_postgres(8, 2)
def test_cancel(self):
errors = []
def neverending(conn):
cur = conn.cursor()
try:
self.assertRaises(psycopg2.extensions.QueryCanceledError,
cur.execute, "select pg_sleep(60)")
# make sure the connection still works
conn.rollback()
cur.execute("select 1")
self.assertEqual(cur.fetchall(), [(1, )])
except Exception as e:
errors.append(e)
raise
def canceller(conn):
cur = conn.cursor()
try:
conn.cancel()
except Exception as e:
errors.append(e)
raise
del cur
thread1 = threading.Thread(target=neverending, args=(self.conn, ))
# wait a bit to make sure that the other thread is already in
# pg_sleep -- ugly and racy, but the chances are ridiculously low
thread2 = threading.Timer(0.3, canceller, args=(self.conn, ))
thread1.start()
thread2.start()
thread1.join()
thread2.join()
self.assertEqual(errors, [])
@slow
@skip_before_postgres(8, 2)
def test_async_cancel(self):
async_conn = psycopg2.connect(dsn, async_=True)
self.assertRaises(psycopg2.OperationalError, async_conn.cancel)
extras.wait_select(async_conn)
cur = async_conn.cursor()
cur.execute("select pg_sleep(10)")
time.sleep(1)
self.assertTrue(async_conn.isexecuting())
async_conn.cancel()
self.assertRaises(psycopg2.extensions.QueryCanceledError,
extras.wait_select, async_conn)
cur.execute("select 1")
extras.wait_select(async_conn)
self.assertEqual(cur.fetchall(), [(1, )])
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

1944
tests/test_connection.py Executable file

File diff suppressed because it is too large Load Diff

404
tests/test_copy.py Executable file
View File

@ -0,0 +1,404 @@
#!/usr/bin/env python
# test_copy.py - unit test for COPY support
#
# Copyright (C) 2010-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import io
import sys
import string
import unittest
from .testutils import ConnectingTestCase, skip_before_postgres, slow, StringIO
from .testutils import skip_if_crdb
from itertools import cycle
from subprocess import Popen, PIPE
import psycopg2
import psycopg2.extensions
from .testutils import skip_copy_if_green, TextIOBase
from .testconfig import dsn
class MinimalRead(TextIOBase):
"""A file wrapper exposing the minimal interface to copy from."""
def __init__(self, f):
self.f = f
def read(self, size):
return self.f.read(size)
def readline(self):
return self.f.readline()
class MinimalWrite(TextIOBase):
"""A file wrapper exposing the minimal interface to copy to."""
def __init__(self, f):
self.f = f
def write(self, data):
return self.f.write(data)
@skip_copy_if_green
class CopyTests(ConnectingTestCase):
def setUp(self):
ConnectingTestCase.setUp(self)
self._create_temp_table()
def _create_temp_table(self):
skip_if_crdb("copy", self.conn)
curs = self.conn.cursor()
curs.execute('''
CREATE TEMPORARY TABLE tcopy (
id serial PRIMARY KEY,
data text
)''')
@slow
def test_copy_from(self):
curs = self.conn.cursor()
try:
self._copy_from(curs, nrecs=1024, srec=10 * 1024, copykw={})
finally:
curs.close()
@slow
def test_copy_from_insane_size(self):
# Trying to trigger a "would block" error
curs = self.conn.cursor()
try:
self._copy_from(curs, nrecs=10 * 1024, srec=10 * 1024,
copykw={'size': 20 * 1024 * 1024})
finally:
curs.close()
def test_copy_from_cols(self):
curs = self.conn.cursor()
f = StringIO()
for i in range(10):
f.write(f"{i}\n")
f.seek(0)
curs.copy_from(MinimalRead(f), "tcopy", columns=['id'])
curs.execute("select * from tcopy order by id")
self.assertEqual([(i, None) for i in range(10)], curs.fetchall())
def test_copy_from_cols_err(self):
curs = self.conn.cursor()
f = StringIO()
for i in range(10):
f.write(f"{i}\n")
f.seek(0)
def cols():
raise ZeroDivisionError()
yield 'id'
self.assertRaises(ZeroDivisionError,
curs.copy_from, MinimalRead(f), "tcopy", columns=cols())
@slow
def test_copy_to(self):
curs = self.conn.cursor()
try:
self._copy_from(curs, nrecs=1024, srec=10 * 1024, copykw={})
self._copy_to(curs, srec=10 * 1024)
finally:
curs.close()
def test_copy_text(self):
self.conn.set_client_encoding('latin1')
self._create_temp_table() # the above call closed the xn
abin = bytes(list(range(32, 127))
+ list(range(160, 256))).decode('latin1')
about = abin.replace('\\', '\\\\')
curs = self.conn.cursor()
curs.execute('insert into tcopy values (%s, %s)',
(42, abin))
f = io.StringIO()
curs.copy_to(f, 'tcopy', columns=('data',))
f.seek(0)
self.assertEqual(f.readline().rstrip(), about)
def test_copy_bytes(self):
self.conn.set_client_encoding('latin1')
self._create_temp_table() # the above call closed the xn
abin = bytes(list(range(32, 127))
+ list(range(160, 255))).decode('latin1')
about = abin.replace('\\', '\\\\').encode('latin1')
curs = self.conn.cursor()
curs.execute('insert into tcopy values (%s, %s)',
(42, abin))
f = io.BytesIO()
curs.copy_to(f, 'tcopy', columns=('data',))
f.seek(0)
self.assertEqual(f.readline().rstrip(), about)
def test_copy_expert_textiobase(self):
self.conn.set_client_encoding('latin1')
self._create_temp_table() # the above call closed the xn
abin = bytes(list(range(32, 127))
+ list(range(160, 256))).decode('latin1')
about = abin.replace('\\', '\\\\')
f = io.StringIO()
f.write(about)
f.seek(0)
curs = self.conn.cursor()
psycopg2.extensions.register_type(
psycopg2.extensions.UNICODE, curs)
curs.copy_expert('COPY tcopy (data) FROM STDIN', f)
curs.execute("select data from tcopy;")
self.assertEqual(curs.fetchone()[0], abin)
f = io.StringIO()
curs.copy_expert('COPY tcopy (data) TO STDOUT', f)
f.seek(0)
self.assertEqual(f.readline().rstrip(), about)
# same tests with setting size
f = io.StringIO()
f.write(about)
f.seek(0)
exp_size = 123
# hack here to leave file as is, only check size when reading
real_read = f.read
def read(_size, f=f, exp_size=exp_size):
self.assertEqual(_size, exp_size)
return real_read(_size)
f.read = read
curs.copy_expert('COPY tcopy (data) FROM STDIN', f, size=exp_size)
curs.execute("select data from tcopy;")
self.assertEqual(curs.fetchone()[0], abin)
def _copy_from(self, curs, nrecs, srec, copykw):
f = StringIO()
for i, c in zip(range(nrecs), cycle(string.ascii_letters)):
l = c * srec
f.write(f"{i}\t{l}\n")
f.seek(0)
curs.copy_from(MinimalRead(f), "tcopy", **copykw)
curs.execute("select count(*) from tcopy")
self.assertEqual(nrecs, curs.fetchone()[0])
curs.execute("select data from tcopy where id < %s order by id",
(len(string.ascii_letters),))
for i, (l,) in enumerate(curs):
self.assertEqual(l, string.ascii_letters[i] * srec)
def _copy_to(self, curs, srec):
f = StringIO()
curs.copy_to(MinimalWrite(f), "tcopy")
f.seek(0)
ntests = 0
for line in f:
n, s = line.split()
if int(n) < len(string.ascii_letters):
self.assertEqual(s, string.ascii_letters[int(n)] * srec)
ntests += 1
self.assertEqual(ntests, len(string.ascii_letters))
def test_copy_expert_file_refcount(self):
class Whatever:
pass
f = Whatever()
curs = self.conn.cursor()
self.assertRaises(TypeError,
curs.copy_expert, 'COPY tcopy (data) FROM STDIN', f)
def test_copy_no_column_limit(self):
cols = [f"c{i:050}" for i in range(200)]
curs = self.conn.cursor()
curs.execute('CREATE TEMPORARY TABLE manycols (%s)' % ',\n'.join(
["%s int" % c for c in cols]))
curs.execute("INSERT INTO manycols DEFAULT VALUES")
f = StringIO()
curs.copy_to(f, "manycols", columns=cols)
f.seek(0)
self.assertEqual(f.read().split(), ['\\N'] * len(cols))
f.seek(0)
curs.copy_from(f, "manycols", columns=cols)
curs.execute("select count(*) from manycols;")
self.assertEqual(curs.fetchone()[0], 2)
def test_copy_funny_names(self):
cols = ["select", "insert", "group"]
curs = self.conn.cursor()
curs.execute('CREATE TEMPORARY TABLE "select" (%s)' % ',\n'.join(
['"%s" int' % c for c in cols]))
curs.execute('INSERT INTO "select" DEFAULT VALUES')
f = StringIO()
curs.copy_to(f, "select", columns=cols)
f.seek(0)
self.assertEqual(f.read().split(), ['\\N'] * len(cols))
f.seek(0)
curs.copy_from(f, "select", columns=cols)
curs.execute('select count(*) from "select";')
self.assertEqual(curs.fetchone()[0], 2)
@skip_before_postgres(8, 2) # they don't send the count
def test_copy_rowcount(self):
curs = self.conn.cursor()
curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data'])
self.assertEqual(curs.rowcount, 3)
curs.copy_expert(
"copy tcopy (data) from stdin",
StringIO('ddd\neee\n'))
self.assertEqual(curs.rowcount, 2)
curs.copy_to(StringIO(), "tcopy")
self.assertEqual(curs.rowcount, 5)
curs.execute("insert into tcopy (data) values ('fff')")
curs.copy_expert("copy tcopy to stdout", StringIO())
self.assertEqual(curs.rowcount, 6)
def test_copy_rowcount_error(self):
curs = self.conn.cursor()
curs.execute("insert into tcopy (data) values ('fff')")
self.assertEqual(curs.rowcount, 1)
self.assertRaises(psycopg2.DataError,
curs.copy_from, StringIO('aaa\nbbb\nccc\n'), 'tcopy')
self.assertEqual(curs.rowcount, -1)
def test_copy_query(self):
curs = self.conn.cursor()
curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data'])
self.assert_(b"copy " in curs.query.lower())
self.assert_(b" from stdin" in curs.query.lower())
curs.copy_expert(
"copy tcopy (data) from stdin",
StringIO('ddd\neee\n'))
self.assert_(b"copy " in curs.query.lower())
self.assert_(b" from stdin" in curs.query.lower())
curs.copy_to(StringIO(), "tcopy")
self.assert_(b"copy " in curs.query.lower())
self.assert_(b" to stdout" in curs.query.lower())
curs.execute("insert into tcopy (data) values ('fff')")
curs.copy_expert("copy tcopy to stdout", StringIO())
self.assert_(b"copy " in curs.query.lower())
self.assert_(b" to stdout" in curs.query.lower())
@slow
def test_copy_from_segfault(self):
# issue #219
script = f"""import psycopg2
conn = psycopg2.connect({dsn!r})
curs = conn.cursor()
curs.execute("create table copy_segf (id int)")
try:
curs.execute("copy copy_segf from stdin")
except psycopg2.ProgrammingError:
pass
conn.close()
"""
proc = Popen([sys.executable, '-c', script])
proc.communicate()
self.assertEqual(0, proc.returncode)
@slow
def test_copy_to_segfault(self):
# issue #219
script = f"""import psycopg2
conn = psycopg2.connect({dsn!r})
curs = conn.cursor()
curs.execute("create table copy_segf (id int)")
try:
curs.execute("copy copy_segf to stdout")
except psycopg2.ProgrammingError:
pass
conn.close()
"""
proc = Popen([sys.executable, '-c', script], stdout=PIPE)
proc.communicate()
self.assertEqual(0, proc.returncode)
def test_copy_from_propagate_error(self):
class BrokenRead(TextIOBase):
def read(self, size):
return 1 / 0
def readline(self):
return 1 / 0
curs = self.conn.cursor()
# It seems we cannot do this, but now at least we propagate the error
# self.assertRaises(ZeroDivisionError,
# curs.copy_from, BrokenRead(), "tcopy")
try:
curs.copy_from(BrokenRead(), "tcopy")
except Exception as e:
self.assert_('ZeroDivisionError' in str(e))
def test_copy_to_propagate_error(self):
class BrokenWrite(TextIOBase):
def write(self, data):
return 1 / 0
curs = self.conn.cursor()
curs.execute("insert into tcopy values (10, 'hi')")
self.assertRaises(ZeroDivisionError,
curs.copy_to, BrokenWrite(), "tcopy")
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

701
tests/test_cursor.py Executable file
View File

@ -0,0 +1,701 @@
#!/usr/bin/env python
# test_cursor.py - unit test for cursor attributes
#
# Copyright (C) 2010-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import gc
import sys
import time
import ctypes
import pickle
import psycopg2
import psycopg2.extensions
import unittest
from datetime import date
from decimal import Decimal
from weakref import ref
from .testutils import (ConnectingTestCase, skip_before_postgres,
skip_if_no_getrefcount, slow, skip_if_no_superuser,
skip_if_windows, skip_if_crdb, crdb_version)
import psycopg2.extras
class CursorTests(ConnectingTestCase):
def test_close_idempotent(self):
cur = self.conn.cursor()
cur.close()
cur.close()
self.assert_(cur.closed)
def test_empty_query(self):
cur = self.conn.cursor()
self.assertRaises(psycopg2.ProgrammingError, cur.execute, "")
self.assertRaises(psycopg2.ProgrammingError, cur.execute, " ")
self.assertRaises(psycopg2.ProgrammingError, cur.execute, ";")
def test_executemany_propagate_exceptions(self):
conn = self.conn
cur = conn.cursor()
cur.execute("create table test_exc (data int);")
def buggygen():
yield 1 // 0
self.assertRaises(ZeroDivisionError,
cur.executemany, "insert into test_exc values (%s)", buggygen())
cur.close()
def test_mogrify_unicode(self):
conn = self.conn
cur = conn.cursor()
# test consistency between execute and mogrify.
# unicode query containing only ascii data
cur.execute("SELECT 'foo';")
self.assertEqual('foo', cur.fetchone()[0])
self.assertEqual(b"SELECT 'foo';", cur.mogrify("SELECT 'foo';"))
conn.set_client_encoding('UTF8')
snowman = "\u2603"
def b(s):
if isinstance(s, str):
return s.encode('utf8')
else:
return s
# unicode query with non-ascii data
cur.execute(f"SELECT '{snowman}';")
self.assertEqual(snowman.encode('utf8'), b(cur.fetchone()[0]))
self.assertQuotedEqual(f"SELECT '{snowman}';".encode('utf8'),
cur.mogrify(f"SELECT '{snowman}';"))
# unicode args
cur.execute("SELECT %s;", (snowman,))
self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0]))
self.assertQuotedEqual(f"SELECT '{snowman}';".encode('utf8'),
cur.mogrify("SELECT %s;", (snowman,)))
# unicode query and args
cur.execute("SELECT %s;", (snowman,))
self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0]))
self.assertQuotedEqual(f"SELECT '{snowman}';".encode('utf8'),
cur.mogrify("SELECT %s;", (snowman,)))
def test_mogrify_decimal_explodes(self):
conn = self.conn
cur = conn.cursor()
self.assertEqual(b'SELECT 10.3;',
cur.mogrify("SELECT %s;", (Decimal("10.3"),)))
@skip_if_no_getrefcount
def test_mogrify_leak_on_multiple_reference(self):
# issue #81: reference leak when a parameter value is referenced
# more than once from a dict.
cur = self.conn.cursor()
foo = (lambda x: x)('foo') * 10
nref1 = sys.getrefcount(foo)
cur.mogrify("select %(foo)s, %(foo)s, %(foo)s", {'foo': foo})
nref2 = sys.getrefcount(foo)
self.assertEqual(nref1, nref2)
def test_modify_closed(self):
cur = self.conn.cursor()
cur.close()
sql = cur.mogrify("select %s", (10,))
self.assertEqual(sql, b"select 10")
def test_bad_placeholder(self):
cur = self.conn.cursor()
self.assertRaises(psycopg2.ProgrammingError,
cur.mogrify, "select %(foo", {})
self.assertRaises(psycopg2.ProgrammingError,
cur.mogrify, "select %(foo", {'foo': 1})
self.assertRaises(psycopg2.ProgrammingError,
cur.mogrify, "select %(foo, %(bar)", {'foo': 1})
self.assertRaises(psycopg2.ProgrammingError,
cur.mogrify, "select %(foo, %(bar)", {'foo': 1, 'bar': 2})
def test_cast(self):
curs = self.conn.cursor()
self.assertEqual(42, curs.cast(20, '42'))
self.assertAlmostEqual(3.14, curs.cast(700, '3.14'))
self.assertEqual(Decimal('123.45'), curs.cast(1700, '123.45'))
self.assertEqual(date(2011, 1, 2), curs.cast(1082, '2011-01-02'))
self.assertEqual("who am i?", curs.cast(705, 'who am i?')) # unknown
def test_cast_specificity(self):
curs = self.conn.cursor()
self.assertEqual("foo", curs.cast(705, 'foo'))
D = psycopg2.extensions.new_type((705,), "DOUBLING", lambda v, c: v * 2)
psycopg2.extensions.register_type(D, self.conn)
self.assertEqual("foofoo", curs.cast(705, 'foo'))
T = psycopg2.extensions.new_type((705,), "TREBLING", lambda v, c: v * 3)
psycopg2.extensions.register_type(T, curs)
self.assertEqual("foofoofoo", curs.cast(705, 'foo'))
curs2 = self.conn.cursor()
self.assertEqual("foofoo", curs2.cast(705, 'foo'))
def test_weakref(self):
curs = self.conn.cursor()
w = ref(curs)
del curs
gc.collect()
self.assert_(w() is None)
def test_null_name(self):
curs = self.conn.cursor(None)
self.assertEqual(curs.name, None)
def test_description_attribs(self):
curs = self.conn.cursor()
curs.execute("""select
3.14::decimal(10,2) as pi,
'hello'::text as hi,
'2010-02-18'::date as now;
""")
self.assertEqual(len(curs.description), 3)
for c in curs.description:
self.assertEqual(len(c), 7) # DBAPI happy
for a in ('name', 'type_code', 'display_size', 'internal_size',
'precision', 'scale', 'null_ok'):
self.assert_(hasattr(c, a), a)
c = curs.description[0]
self.assertEqual(c.name, 'pi')
self.assert_(c.type_code in psycopg2.extensions.DECIMAL.values)
if crdb_version(self.conn) is None:
self.assert_(c.internal_size > 0)
self.assertEqual(c.precision, 10)
self.assertEqual(c.scale, 2)
c = curs.description[1]
self.assertEqual(c.name, 'hi')
self.assert_(c.type_code in psycopg2.STRING.values)
self.assert_(c.internal_size < 0)
self.assertEqual(c.precision, None)
self.assertEqual(c.scale, None)
c = curs.description[2]
self.assertEqual(c.name, 'now')
self.assert_(c.type_code in psycopg2.extensions.DATE.values)
self.assert_(c.internal_size > 0)
self.assertEqual(c.precision, None)
self.assertEqual(c.scale, None)
@skip_if_crdb("table oid")
def test_description_extra_attribs(self):
curs = self.conn.cursor()
curs.execute("""
create table testcol (
pi decimal(10,2),
hi text)
""")
curs.execute("select oid from pg_class where relname = %s", ('testcol',))
oid = curs.fetchone()[0]
curs.execute("insert into testcol values (3.14, 'hello')")
curs.execute("select hi, pi, 42 from testcol")
self.assertEqual(curs.description[0].table_oid, oid)
self.assertEqual(curs.description[0].table_column, 2)
self.assertEqual(curs.description[1].table_oid, oid)
self.assertEqual(curs.description[1].table_column, 1)
self.assertEqual(curs.description[2].table_oid, None)
self.assertEqual(curs.description[2].table_column, None)
def test_description_slice(self):
curs = self.conn.cursor()
curs.execute("select 1::int4 as a")
self.assertEqual(curs.description[0][0:2], ('a', 23))
def test_pickle_description(self):
curs = self.conn.cursor()
curs.execute('SELECT 1 AS foo')
description = curs.description
pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL)
unpickled = pickle.loads(pickled)
self.assertEqual(description, unpickled)
def test_column_refcount(self):
# Reproduce crash describe in ticket #1252
from psycopg2.extensions import Column
def do_stuff():
_ = Column(name='my_column')
for _ in range(1000):
do_stuff()
def test_bad_subclass(self):
# check that we get an error message instead of a segfault
# for badly written subclasses.
# see https://stackoverflow.com/questions/22019341/
class StupidCursor(psycopg2.extensions.cursor):
def __init__(self, *args, **kwargs):
# I am stupid so not calling superclass init
pass
cur = StupidCursor()
self.assertRaises(psycopg2.InterfaceError, cur.execute, 'select 1')
self.assertRaises(psycopg2.InterfaceError, cur.executemany,
'select 1', [])
def test_callproc_badparam(self):
cur = self.conn.cursor()
self.assertRaises(TypeError, cur.callproc, 'lower', 42)
# It would be inappropriate to test callproc's named parameters in the
# DBAPI2.0 test section because they are a psycopg2 extension.
@skip_before_postgres(9, 0)
@skip_if_crdb("stored procedure")
def test_callproc_dict(self):
# This parameter name tests for injection and quote escaping
paramname = '''
Robert'); drop table "students" --
'''.strip()
escaped_paramname = '"%s"' % paramname.replace('"', '""')
procname = 'pg_temp.randall'
cur = self.conn.cursor()
# Set up the temporary function
cur.execute(f'''
CREATE FUNCTION {procname}({escaped_paramname} INT)
RETURNS INT AS
'SELECT $1 * $1'
LANGUAGE SQL
''')
# Make sure callproc works right
cur.callproc(procname, {paramname: 2})
self.assertEquals(cur.fetchone()[0], 4)
# Make sure callproc fails right
failing_cases = [
({paramname: 2, 'foo': 'bar'}, psycopg2.ProgrammingError),
({paramname: '2'}, psycopg2.ProgrammingError),
({paramname: 'two'}, psycopg2.ProgrammingError),
({'bj\xc3rn': 2}, psycopg2.ProgrammingError),
({3: 2}, TypeError),
({self: 2}, TypeError),
]
for parameter_sequence, exception in failing_cases:
self.assertRaises(exception, cur.callproc, procname, parameter_sequence)
self.conn.rollback()
@skip_if_no_superuser
@skip_if_windows
@skip_if_crdb("backend pid")
@skip_before_postgres(8, 4)
def test_external_close_sync(self):
# If a "victim" connection is closed by a "control" connection
# behind psycopg2's back, psycopg2 always handles it correctly:
# raise OperationalError, set conn.closed to 2. This reproduces
# issue #443, a race between control_conn closing victim_conn and
# psycopg2 noticing.
control_conn = self.conn
connect_func = self.connect
def wait_func(conn):
pass
self._test_external_close(control_conn, connect_func, wait_func)
@skip_if_no_superuser
@skip_if_windows
@skip_if_crdb("backend pid")
@skip_before_postgres(8, 4)
def test_external_close_async(self):
# Issue #443 is in the async code too. Since the fix is duplicated,
# so is the test.
control_conn = self.conn
def connect_func():
return self.connect(async_=True)
wait_func = psycopg2.extras.wait_select
self._test_external_close(control_conn, connect_func, wait_func)
def _test_external_close(self, control_conn, connect_func, wait_func):
# The short sleep before using victim_conn the second time makes it
# much more likely to lose the race and see the bug. Repeating the
# test several times makes it even more likely.
for i in range(10):
victim_conn = connect_func()
wait_func(victim_conn)
with victim_conn.cursor() as cur:
cur.execute('select pg_backend_pid()')
wait_func(victim_conn)
pid1 = cur.fetchall()[0][0]
with control_conn.cursor() as cur:
cur.execute('select pg_terminate_backend(%s)', (pid1,))
time.sleep(0.001)
def f():
with victim_conn.cursor() as cur:
cur.execute('select 1')
wait_func(victim_conn)
self.assertRaises(psycopg2.OperationalError, f)
self.assertEqual(victim_conn.closed, 2)
@skip_before_postgres(8, 2)
def test_rowcount_on_executemany_returning(self):
cur = self.conn.cursor()
cur.execute("create table execmany(id serial primary key, data int)")
cur.executemany(
"insert into execmany (data) values (%s)",
[(i,) for i in range(4)])
self.assertEqual(cur.rowcount, 4)
cur.executemany(
"insert into execmany (data) values (%s) returning data",
[(i,) for i in range(5)])
self.assertEqual(cur.rowcount, 5)
@skip_before_postgres(9)
def test_pgresult_ptr(self):
curs = self.conn.cursor()
self.assert_(curs.pgresult_ptr is None)
curs.execute("select 'x'")
self.assert_(curs.pgresult_ptr is not None)
try:
f = self.libpq.PQcmdStatus
except AttributeError:
pass
else:
f.argtypes = [ctypes.c_void_p]
f.restype = ctypes.c_char_p
status = f(curs.pgresult_ptr)
self.assertEqual(status, b'SELECT 1')
curs.close()
self.assert_(curs.pgresult_ptr is None)
@skip_if_crdb("named cursor")
class NamedCursorTests(ConnectingTestCase):
def test_invalid_name(self):
curs = self.conn.cursor()
curs.execute("create table invname (data int);")
for i in (10, 20, 30):
curs.execute("insert into invname values (%s)", (i,))
curs.close()
curs = self.conn.cursor(r'1-2-3 \ "test"')
curs.execute("select data from invname order by data")
self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)])
def _create_withhold_table(self):
curs = self.conn.cursor()
try:
curs.execute("drop table withhold")
except psycopg2.ProgrammingError:
self.conn.rollback()
curs.execute("create table withhold (data int)")
for i in (10, 20, 30):
curs.execute("insert into withhold values (%s)", (i,))
curs.close()
def test_withhold(self):
self.assertRaises(psycopg2.ProgrammingError, self.conn.cursor,
withhold=True)
self._create_withhold_table()
curs = self.conn.cursor("W")
self.assertEqual(curs.withhold, False)
curs.withhold = True
self.assertEqual(curs.withhold, True)
curs.execute("select data from withhold order by data")
self.conn.commit()
self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)])
curs.close()
curs = self.conn.cursor("W", withhold=True)
self.assertEqual(curs.withhold, True)
curs.execute("select data from withhold order by data")
self.conn.commit()
self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)])
curs = self.conn.cursor()
curs.execute("drop table withhold")
self.conn.commit()
def test_withhold_no_begin(self):
self._create_withhold_table()
curs = self.conn.cursor("w", withhold=True)
curs.execute("select data from withhold order by data")
self.assertEqual(curs.fetchone(), (10,))
self.assertEqual(self.conn.status, psycopg2.extensions.STATUS_BEGIN)
self.assertEqual(self.conn.info.transaction_status,
psycopg2.extensions.TRANSACTION_STATUS_INTRANS)
self.conn.commit()
self.assertEqual(self.conn.status, psycopg2.extensions.STATUS_READY)
self.assertEqual(self.conn.info.transaction_status,
psycopg2.extensions.TRANSACTION_STATUS_IDLE)
self.assertEqual(curs.fetchone(), (20,))
self.assertEqual(self.conn.status, psycopg2.extensions.STATUS_READY)
self.assertEqual(self.conn.info.transaction_status,
psycopg2.extensions.TRANSACTION_STATUS_IDLE)
curs.close()
self.assertEqual(self.conn.status, psycopg2.extensions.STATUS_READY)
self.assertEqual(self.conn.info.transaction_status,
psycopg2.extensions.TRANSACTION_STATUS_IDLE)
def test_withhold_autocommit(self):
self._create_withhold_table()
self.conn.commit()
self.conn.autocommit = True
curs = self.conn.cursor("w", withhold=True)
curs.execute("select data from withhold order by data")
self.assertEqual(curs.fetchone(), (10,))
self.assertEqual(self.conn.status, psycopg2.extensions.STATUS_READY)
self.assertEqual(self.conn.info.transaction_status,
psycopg2.extensions.TRANSACTION_STATUS_IDLE)
self.conn.commit()
self.assertEqual(self.conn.status, psycopg2.extensions.STATUS_READY)
self.assertEqual(self.conn.info.transaction_status,
psycopg2.extensions.TRANSACTION_STATUS_IDLE)
curs.close()
self.assertEqual(self.conn.status, psycopg2.extensions.STATUS_READY)
self.assertEqual(self.conn.info.transaction_status,
psycopg2.extensions.TRANSACTION_STATUS_IDLE)
def test_scrollable(self):
self.assertRaises(psycopg2.ProgrammingError, self.conn.cursor,
scrollable=True)
curs = self.conn.cursor()
curs.execute("create table scrollable (data int)")
curs.executemany("insert into scrollable values (%s)",
[(i,) for i in range(100)])
curs.close()
for t in range(2):
if not t:
curs = self.conn.cursor("S")
self.assertEqual(curs.scrollable, None)
curs.scrollable = True
else:
curs = self.conn.cursor("S", scrollable=True)
self.assertEqual(curs.scrollable, True)
curs.itersize = 10
# complex enough to make postgres cursors declare without
# scroll/no scroll to fail
curs.execute("""
select x.data
from scrollable x
join scrollable y on x.data = y.data
order by y.data""")
for i, (n,) in enumerate(curs):
self.assertEqual(i, n)
curs.scroll(-1)
for i in range(99, -1, -1):
curs.scroll(-1)
self.assertEqual(i, curs.fetchone()[0])
curs.scroll(-1)
curs.close()
def test_not_scrollable(self):
self.assertRaises(psycopg2.ProgrammingError, self.conn.cursor,
scrollable=False)
curs = self.conn.cursor()
curs.execute("create table scrollable (data int)")
curs.executemany("insert into scrollable values (%s)",
[(i,) for i in range(100)])
curs.close()
curs = self.conn.cursor("S") # default scrollability
curs.execute("select * from scrollable")
self.assertEqual(curs.scrollable, None)
curs.scroll(2)
try:
curs.scroll(-1)
except psycopg2.OperationalError:
return self.skipTest("can't evaluate non-scrollable cursor")
curs.close()
curs = self.conn.cursor("S", scrollable=False)
self.assertEqual(curs.scrollable, False)
curs.execute("select * from scrollable")
curs.scroll(2)
self.assertRaises(psycopg2.OperationalError, curs.scroll, -1)
@slow
@skip_before_postgres(8, 2)
def test_iter_named_cursor_efficient(self):
curs = self.conn.cursor('tmp')
# if these records are fetched in the same roundtrip their
# timestamp will not be influenced by the pause in Python world.
curs.execute("""select clock_timestamp() from generate_series(1,2)""")
i = iter(curs)
t1 = next(i)[0]
time.sleep(0.2)
t2 = next(i)[0]
self.assert_((t2 - t1).microseconds * 1e-6 < 0.1,
f"named cursor records fetched in 2 roundtrips (delta: {t2 - t1})")
@skip_before_postgres(8, 0)
def test_iter_named_cursor_default_itersize(self):
curs = self.conn.cursor('tmp')
curs.execute('select generate_series(1,50)')
rv = [(r[0], curs.rownumber) for r in curs]
# everything swallowed in one gulp
self.assertEqual(rv, [(i, i) for i in range(1, 51)])
@skip_before_postgres(8, 0)
def test_iter_named_cursor_itersize(self):
curs = self.conn.cursor('tmp')
curs.itersize = 30
curs.execute('select generate_series(1,50)')
rv = [(r[0], curs.rownumber) for r in curs]
# everything swallowed in two gulps
self.assertEqual(rv, [(i, ((i - 1) % 30) + 1) for i in range(1, 51)])
@skip_before_postgres(8, 0)
def test_iter_named_cursor_rownumber(self):
curs = self.conn.cursor('tmp')
# note: this fails if itersize < dataset: internally we check
# rownumber == rowcount to detect when to read anoter page, so we
# would need an extra attribute to have a monotonic rownumber.
curs.itersize = 20
curs.execute('select generate_series(1,10)')
for i, rec in enumerate(curs):
self.assertEqual(i + 1, curs.rownumber)
@skip_before_postgres(8, 0)
def test_named_cursor_stealing(self):
# you can use a named cursor to iterate on a refcursor created
# somewhere else
cur1 = self.conn.cursor()
cur1.execute("DECLARE test CURSOR WITHOUT HOLD "
" FOR SELECT generate_series(1,7)")
cur2 = self.conn.cursor('test')
# can call fetch without execute
self.assertEqual((1,), cur2.fetchone())
self.assertEqual([(2,), (3,), (4,)], cur2.fetchmany(3))
self.assertEqual([(5,), (6,), (7,)], cur2.fetchall())
@skip_before_postgres(8, 2)
def test_named_noop_close(self):
cur = self.conn.cursor('test')
cur.close()
@skip_before_postgres(8, 2)
def test_stolen_named_cursor_close(self):
cur1 = self.conn.cursor()
cur1.execute("DECLARE test CURSOR WITHOUT HOLD "
" FOR SELECT generate_series(1,7)")
cur2 = self.conn.cursor('test')
cur2.close()
cur1.execute("DECLARE test CURSOR WITHOUT HOLD "
" FOR SELECT generate_series(1,7)")
cur2 = self.conn.cursor('test')
cur2.close()
@skip_before_postgres(8, 0)
def test_scroll(self):
cur = self.conn.cursor()
cur.execute("select generate_series(0,9)")
cur.scroll(2)
self.assertEqual(cur.fetchone(), (2,))
cur.scroll(2)
self.assertEqual(cur.fetchone(), (5,))
cur.scroll(2, mode='relative')
self.assertEqual(cur.fetchone(), (8,))
cur.scroll(-1)
self.assertEqual(cur.fetchone(), (8,))
cur.scroll(-2)
self.assertEqual(cur.fetchone(), (7,))
cur.scroll(2, mode='absolute')
self.assertEqual(cur.fetchone(), (2,))
# on the boundary
cur.scroll(0, mode='absolute')
self.assertEqual(cur.fetchone(), (0,))
self.assertRaises((IndexError, psycopg2.ProgrammingError),
cur.scroll, -1, mode='absolute')
cur.scroll(0, mode='absolute')
self.assertRaises((IndexError, psycopg2.ProgrammingError),
cur.scroll, -1)
cur.scroll(9, mode='absolute')
self.assertEqual(cur.fetchone(), (9,))
self.assertRaises((IndexError, psycopg2.ProgrammingError),
cur.scroll, 10, mode='absolute')
cur.scroll(9, mode='absolute')
self.assertRaises((IndexError, psycopg2.ProgrammingError),
cur.scroll, 1)
@skip_before_postgres(8, 0)
def test_scroll_named(self):
cur = self.conn.cursor('tmp', scrollable=True)
cur.execute("select generate_series(0,9)")
cur.scroll(2)
self.assertEqual(cur.fetchone(), (2,))
cur.scroll(2)
self.assertEqual(cur.fetchone(), (5,))
cur.scroll(2, mode='relative')
self.assertEqual(cur.fetchone(), (8,))
cur.scroll(9, mode='absolute')
self.assertEqual(cur.fetchone(), (9,))
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

555
tests/test_dates.py Executable file
View File

@ -0,0 +1,555 @@
#!/usr/bin/env python
# test_dates.py - unit test for dates handling
#
# Copyright (C) 2008-2019 James Henstridge <james@jamesh.id.au>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import sys
import math
import pickle
from datetime import date, datetime, time, timedelta, timezone
import psycopg2
from psycopg2.tz import FixedOffsetTimezone, ZERO
import unittest
from .testutils import ConnectingTestCase, skip_before_postgres, skip_if_crdb
def total_seconds(d):
"""Return total number of seconds of a timedelta as a float."""
return d.days * 24 * 60 * 60 + d.seconds + d.microseconds / 1000000.0
class CommonDatetimeTestsMixin:
def execute(self, *args):
self.curs.execute(*args)
return self.curs.fetchone()[0]
def test_parse_date(self):
value = self.DATE('2007-01-01', self.curs)
self.assert_(value is not None)
self.assertEqual(value.year, 2007)
self.assertEqual(value.month, 1)
self.assertEqual(value.day, 1)
def test_parse_null_date(self):
value = self.DATE(None, self.curs)
self.assertEqual(value, None)
def test_parse_incomplete_date(self):
self.assertRaises(psycopg2.DataError, self.DATE, '2007', self.curs)
self.assertRaises(psycopg2.DataError, self.DATE, '2007-01', self.curs)
def test_parse_time(self):
value = self.TIME('13:30:29', self.curs)
self.assert_(value is not None)
self.assertEqual(value.hour, 13)
self.assertEqual(value.minute, 30)
self.assertEqual(value.second, 29)
def test_parse_null_time(self):
value = self.TIME(None, self.curs)
self.assertEqual(value, None)
def test_parse_incomplete_time(self):
self.assertRaises(psycopg2.DataError, self.TIME, '13', self.curs)
self.assertRaises(psycopg2.DataError, self.TIME, '13:30', self.curs)
def test_parse_datetime(self):
value = self.DATETIME('2007-01-01 13:30:29', self.curs)
self.assert_(value is not None)
self.assertEqual(value.year, 2007)
self.assertEqual(value.month, 1)
self.assertEqual(value.day, 1)
self.assertEqual(value.hour, 13)
self.assertEqual(value.minute, 30)
self.assertEqual(value.second, 29)
def test_parse_null_datetime(self):
value = self.DATETIME(None, self.curs)
self.assertEqual(value, None)
def test_parse_incomplete_datetime(self):
self.assertRaises(psycopg2.DataError,
self.DATETIME, '2007', self.curs)
self.assertRaises(psycopg2.DataError,
self.DATETIME, '2007-01', self.curs)
self.assertRaises(psycopg2.DataError,
self.DATETIME, '2007-01-01 13', self.curs)
self.assertRaises(psycopg2.DataError,
self.DATETIME, '2007-01-01 13:30', self.curs)
def test_parse_null_interval(self):
value = self.INTERVAL(None, self.curs)
self.assertEqual(value, None)
class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
"""Tests for the datetime based date handling in psycopg2."""
def setUp(self):
ConnectingTestCase.setUp(self)
self.curs = self.conn.cursor()
self.DATE = psycopg2.extensions.PYDATE
self.TIME = psycopg2.extensions.PYTIME
self.DATETIME = psycopg2.extensions.PYDATETIME
self.INTERVAL = psycopg2.extensions.PYINTERVAL
def test_parse_bc_date(self):
# datetime does not support BC dates
self.assertRaises(ValueError, self.DATE, '00042-01-01 BC', self.curs)
def test_parse_bc_datetime(self):
# datetime does not support BC dates
self.assertRaises(ValueError, self.DATETIME,
'00042-01-01 13:30:29 BC', self.curs)
def test_parse_time_microseconds(self):
value = self.TIME('13:30:29.123456', self.curs)
self.assertEqual(value.second, 29)
self.assertEqual(value.microsecond, 123456)
def test_parse_datetime_microseconds(self):
value = self.DATETIME('2007-01-01 13:30:29.123456', self.curs)
self.assertEqual(value.second, 29)
self.assertEqual(value.microsecond, 123456)
def check_time_tz(self, str_offset, offset):
base = time(13, 30, 29)
base_str = '13:30:29'
value = self.TIME(base_str + str_offset, self.curs)
# Value has time zone info and correct UTC offset.
self.assertNotEqual(value.tzinfo, None),
self.assertEqual(value.utcoffset(), timedelta(seconds=offset))
# Time portion is correct.
self.assertEqual(value.replace(tzinfo=None), base)
def test_parse_time_timezone(self):
self.check_time_tz("+01", 3600)
self.check_time_tz("-01", -3600)
self.check_time_tz("+01:15", 4500)
self.check_time_tz("-01:15", -4500)
if sys.version_info < (3, 7):
# The Python < 3.7 datetime module does not support time zone
# offsets that are not a whole number of minutes.
# We round the offset to the nearest minute.
self.check_time_tz("+01:15:00", 60 * (60 + 15))
self.check_time_tz("+01:15:29", 60 * (60 + 15))
self.check_time_tz("+01:15:30", 60 * (60 + 16))
self.check_time_tz("+01:15:59", 60 * (60 + 16))
self.check_time_tz("-01:15:00", -60 * (60 + 15))
self.check_time_tz("-01:15:29", -60 * (60 + 15))
self.check_time_tz("-01:15:30", -60 * (60 + 16))
self.check_time_tz("-01:15:59", -60 * (60 + 16))
else:
self.check_time_tz("+01:15:00", 60 * (60 + 15))
self.check_time_tz("+01:15:29", 60 * (60 + 15) + 29)
self.check_time_tz("+01:15:30", 60 * (60 + 15) + 30)
self.check_time_tz("+01:15:59", 60 * (60 + 15) + 59)
self.check_time_tz("-01:15:00", -(60 * (60 + 15)))
self.check_time_tz("-01:15:29", -(60 * (60 + 15) + 29))
self.check_time_tz("-01:15:30", -(60 * (60 + 15) + 30))
self.check_time_tz("-01:15:59", -(60 * (60 + 15) + 59))
def check_datetime_tz(self, str_offset, offset):
base = datetime(2007, 1, 1, 13, 30, 29)
base_str = '2007-01-01 13:30:29'
value = self.DATETIME(base_str + str_offset, self.curs)
# Value has time zone info and correct UTC offset.
self.assertNotEqual(value.tzinfo, None),
self.assertEqual(value.utcoffset(), timedelta(seconds=offset))
# Datetime is correct.
self.assertEqual(value.replace(tzinfo=None), base)
# Conversion to UTC produces the expected offset.
UTC = timezone(timedelta(0))
value_utc = value.astimezone(UTC).replace(tzinfo=None)
self.assertEqual(base - value_utc, timedelta(seconds=offset))
def test_default_tzinfo(self):
self.curs.execute("select '2000-01-01 00:00+02:00'::timestamptz")
dt = self.curs.fetchone()[0]
self.assert_(isinstance(dt.tzinfo, timezone))
self.assertEqual(dt,
datetime(2000, 1, 1, tzinfo=timezone(timedelta(minutes=120))))
def test_fotz_tzinfo(self):
self.curs.tzinfo_factory = FixedOffsetTimezone
self.curs.execute("select '2000-01-01 00:00+02:00'::timestamptz")
dt = self.curs.fetchone()[0]
self.assert_(not isinstance(dt.tzinfo, timezone))
self.assert_(isinstance(dt.tzinfo, FixedOffsetTimezone))
self.assertEqual(dt,
datetime(2000, 1, 1, tzinfo=timezone(timedelta(minutes=120))))
def test_parse_datetime_timezone(self):
self.check_datetime_tz("+01", 3600)
self.check_datetime_tz("-01", -3600)
self.check_datetime_tz("+01:15", 4500)
self.check_datetime_tz("-01:15", -4500)
if sys.version_info < (3, 7):
# The Python < 3.7 datetime module does not support time zone
# offsets that are not a whole number of minutes.
# We round the offset to the nearest minute.
self.check_datetime_tz("+01:15:00", 60 * (60 + 15))
self.check_datetime_tz("+01:15:29", 60 * (60 + 15))
self.check_datetime_tz("+01:15:30", 60 * (60 + 16))
self.check_datetime_tz("+01:15:59", 60 * (60 + 16))
self.check_datetime_tz("-01:15:00", -60 * (60 + 15))
self.check_datetime_tz("-01:15:29", -60 * (60 + 15))
self.check_datetime_tz("-01:15:30", -60 * (60 + 16))
self.check_datetime_tz("-01:15:59", -60 * (60 + 16))
else:
self.check_datetime_tz("+01:15:00", 60 * (60 + 15))
self.check_datetime_tz("+01:15:29", 60 * (60 + 15) + 29)
self.check_datetime_tz("+01:15:30", 60 * (60 + 15) + 30)
self.check_datetime_tz("+01:15:59", 60 * (60 + 15) + 59)
self.check_datetime_tz("-01:15:00", -(60 * (60 + 15)))
self.check_datetime_tz("-01:15:29", -(60 * (60 + 15) + 29))
self.check_datetime_tz("-01:15:30", -(60 * (60 + 15) + 30))
self.check_datetime_tz("-01:15:59", -(60 * (60 + 15) + 59))
def test_parse_time_no_timezone(self):
self.assertEqual(self.TIME("13:30:29", self.curs).tzinfo, None)
self.assertEqual(self.TIME("13:30:29.123456", self.curs).tzinfo, None)
def test_parse_datetime_no_timezone(self):
self.assertEqual(
self.DATETIME("2007-01-01 13:30:29", self.curs).tzinfo, None)
self.assertEqual(
self.DATETIME("2007-01-01 13:30:29.123456", self.curs).tzinfo, None)
def test_parse_interval(self):
value = self.INTERVAL('42 days 12:34:56.123456', self.curs)
self.assertNotEqual(value, None)
self.assertEqual(value.days, 42)
self.assertEqual(value.seconds, 45296)
self.assertEqual(value.microseconds, 123456)
def test_parse_negative_interval(self):
value = self.INTERVAL('-42 days -12:34:56.123456', self.curs)
self.assertNotEqual(value, None)
self.assertEqual(value.days, -43)
self.assertEqual(value.seconds, 41103)
self.assertEqual(value.microseconds, 876544)
def test_parse_infinity(self):
value = self.DATETIME('-infinity', self.curs)
self.assertEqual(str(value), '0001-01-01 00:00:00')
value = self.DATETIME('infinity', self.curs)
self.assertEqual(str(value), '9999-12-31 23:59:59.999999')
value = self.DATE('infinity', self.curs)
self.assertEqual(str(value), '9999-12-31')
def test_adapt_date(self):
value = self.execute('select (%s)::date::text',
[date(2007, 1, 1)])
self.assertEqual(value, '2007-01-01')
def test_adapt_time(self):
value = self.execute('select (%s)::time::text',
[time(13, 30, 29)])
self.assertEqual(value, '13:30:29')
@skip_if_crdb("cast adds tz")
def test_adapt_datetime(self):
value = self.execute('select (%s)::timestamp::text',
[datetime(2007, 1, 1, 13, 30, 29)])
self.assertEqual(value, '2007-01-01 13:30:29')
def test_adapt_timedelta(self):
value = self.execute('select extract(epoch from (%s)::interval)',
[timedelta(days=42, seconds=45296,
microseconds=123456)])
seconds = math.floor(value)
self.assertEqual(seconds, 3674096)
self.assertEqual(int(round((value - seconds) * 1000000)), 123456)
def test_adapt_negative_timedelta(self):
value = self.execute('select extract(epoch from (%s)::interval)',
[timedelta(days=-42, seconds=45296,
microseconds=123456)])
seconds = math.floor(value)
self.assertEqual(seconds, -3583504)
self.assertEqual(int(round((value - seconds) * 1000000)), 123456)
def _test_type_roundtrip(self, o1):
o2 = self.execute("select %s;", (o1,))
self.assertEqual(type(o1), type(o2))
return o2
def _test_type_roundtrip_array(self, o1):
o1 = [o1]
o2 = self.execute("select %s;", (o1,))
self.assertEqual(type(o1[0]), type(o2[0]))
def test_type_roundtrip_date(self):
self._test_type_roundtrip(date(2010, 5, 3))
def test_type_roundtrip_datetime(self):
dt = self._test_type_roundtrip(datetime(2010, 5, 3, 10, 20, 30))
self.assertEqual(None, dt.tzinfo)
def test_type_roundtrip_datetimetz(self):
tz = timezone(timedelta(minutes=8 * 60))
dt1 = datetime(2010, 5, 3, 10, 20, 30, tzinfo=tz)
dt2 = self._test_type_roundtrip(dt1)
self.assertNotEqual(None, dt2.tzinfo)
self.assertEqual(dt1, dt2)
def test_type_roundtrip_time(self):
tm = self._test_type_roundtrip(time(10, 20, 30))
self.assertEqual(None, tm.tzinfo)
def test_type_roundtrip_timetz(self):
tz = timezone(timedelta(minutes=8 * 60))
tm1 = time(10, 20, 30, tzinfo=tz)
tm2 = self._test_type_roundtrip(tm1)
self.assertNotEqual(None, tm2.tzinfo)
self.assertEqual(tm1, tm2)
def test_type_roundtrip_interval(self):
self._test_type_roundtrip(timedelta(seconds=30))
def test_type_roundtrip_date_array(self):
self._test_type_roundtrip_array(date(2010, 5, 3))
def test_type_roundtrip_datetime_array(self):
self._test_type_roundtrip_array(datetime(2010, 5, 3, 10, 20, 30))
def test_type_roundtrip_datetimetz_array(self):
self._test_type_roundtrip_array(
datetime(2010, 5, 3, 10, 20, 30, tzinfo=timezone(timedelta(0))))
def test_type_roundtrip_time_array(self):
self._test_type_roundtrip_array(time(10, 20, 30))
def test_type_roundtrip_interval_array(self):
self._test_type_roundtrip_array(timedelta(seconds=30))
@skip_before_postgres(8, 1)
def test_time_24(self):
t = self.execute("select '24:00'::time;")
self.assertEqual(t, time(0, 0))
t = self.execute("select '24:00+05'::timetz;")
self.assertEqual(t, time(0, 0, tzinfo=timezone(timedelta(minutes=300))))
t = self.execute("select '24:00+05:30'::timetz;")
self.assertEqual(t, time(0, 0, tzinfo=timezone(timedelta(minutes=330))))
@skip_before_postgres(8, 1)
def test_large_interval(self):
t = self.execute("select '999999:00:00'::interval")
self.assertEqual(total_seconds(t), 999999 * 60 * 60)
t = self.execute("select '-999999:00:00'::interval")
self.assertEqual(total_seconds(t), -999999 * 60 * 60)
t = self.execute("select '999999:00:00.1'::interval")
self.assertEqual(total_seconds(t), 999999 * 60 * 60 + 0.1)
t = self.execute("select '999999:00:00.9'::interval")
self.assertEqual(total_seconds(t), 999999 * 60 * 60 + 0.9)
t = self.execute("select '-999999:00:00.1'::interval")
self.assertEqual(total_seconds(t), -999999 * 60 * 60 - 0.1)
t = self.execute("select '-999999:00:00.9'::interval")
self.assertEqual(total_seconds(t), -999999 * 60 * 60 - 0.9)
def test_micros_rounding(self):
t = self.execute("select '0.1'::interval")
self.assertEqual(total_seconds(t), 0.1)
t = self.execute("select '0.01'::interval")
self.assertEqual(total_seconds(t), 0.01)
t = self.execute("select '0.000001'::interval")
self.assertEqual(total_seconds(t), 1e-6)
t = self.execute("select '0.0000004'::interval")
self.assertEqual(total_seconds(t), 0)
t = self.execute("select '0.0000006'::interval")
self.assertEqual(total_seconds(t), 1e-6)
def test_interval_overflow(self):
cur = self.conn.cursor()
# hack a cursor to receive values too extreme to be represented
# but still I want an error, not a random number
psycopg2.extensions.register_type(
psycopg2.extensions.new_type(
psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL),
cur)
def f(val):
cur.execute(f"select '{val}'::text")
return cur.fetchone()[0]
self.assertRaises(OverflowError, f, '100000000000000000:00:00')
self.assertRaises(OverflowError, f, '00:100000000000000000:00:00')
self.assertRaises(OverflowError, f, '00:00:100000000000000000:00')
self.assertRaises(OverflowError, f, '00:00:00.100000000000000000')
@skip_if_crdb("infinity date")
def test_adapt_infinity_tz(self):
t = self.execute("select 'infinity'::timestamp")
self.assert_(t.tzinfo is None)
self.assert_(t > datetime(4000, 1, 1))
t = self.execute("select '-infinity'::timestamp")
self.assert_(t.tzinfo is None)
self.assert_(t < datetime(1000, 1, 1))
t = self.execute("select 'infinity'::timestamptz")
self.assert_(t.tzinfo is not None)
self.assert_(t > datetime(4000, 1, 1, tzinfo=timezone(timedelta(0))))
t = self.execute("select '-infinity'::timestamptz")
self.assert_(t.tzinfo is not None)
self.assert_(t < datetime(1000, 1, 1, tzinfo=timezone(timedelta(0))))
def test_redshift_day(self):
# Redshift is reported returning 1 day interval as microsec (bug #558)
cur = self.conn.cursor()
psycopg2.extensions.register_type(
psycopg2.extensions.new_type(
psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL),
cur)
for s, v in [
('0', timedelta(0)),
('1', timedelta(microseconds=1)),
('-1', timedelta(microseconds=-1)),
('1000000', timedelta(seconds=1)),
('86400000000', timedelta(days=1)),
('-86400000000', timedelta(days=-1)),
]:
cur.execute("select %s::text", (s,))
r = cur.fetchone()[0]
self.assertEqual(r, v, f"{s} -> {r} != {v}")
@skip_if_crdb("interval style")
@skip_before_postgres(8, 4)
def test_interval_iso_8601_not_supported(self):
# We may end up supporting, but no pressure for it
cur = self.conn.cursor()
cur.execute("set local intervalstyle to iso_8601")
cur.execute("select '1 day 2 hours'::interval")
self.assertRaises(psycopg2.NotSupportedError, cur.fetchone)
class FromTicksTestCase(unittest.TestCase):
# bug "TimestampFromTicks() throws ValueError (2-2.0.14)"
# reported by Jozsef Szalay on 2010-05-06
def test_timestamp_value_error_sec_59_99(self):
s = psycopg2.TimestampFromTicks(1273173119.99992)
self.assertEqual(s.adapted,
datetime(2010, 5, 6, 14, 11, 59, 999920,
tzinfo=timezone(timedelta(minutes=-5 * 60))))
def test_date_value_error_sec_59_99(self):
s = psycopg2.DateFromTicks(1273173119.99992)
# The returned date is local
self.assert_(s.adapted in [date(2010, 5, 6), date(2010, 5, 7)])
def test_time_value_error_sec_59_99(self):
s = psycopg2.TimeFromTicks(1273173119.99992)
self.assertEqual(s.adapted.replace(hour=0),
time(0, 11, 59, 999920))
class FixedOffsetTimezoneTests(unittest.TestCase):
def test_init_with_no_args(self):
tzinfo = FixedOffsetTimezone()
self.assert_(tzinfo._offset is ZERO)
self.assert_(tzinfo._name is None)
def test_repr_with_positive_offset(self):
tzinfo = FixedOffsetTimezone(5 * 60)
self.assertEqual(repr(tzinfo),
"psycopg2.tz.FixedOffsetTimezone(offset=%r, name=None)"
% timedelta(minutes=5 * 60))
def test_repr_with_negative_offset(self):
tzinfo = FixedOffsetTimezone(-5 * 60)
self.assertEqual(repr(tzinfo),
"psycopg2.tz.FixedOffsetTimezone(offset=%r, name=None)"
% timedelta(minutes=-5 * 60))
def test_init_with_timedelta(self):
td = timedelta(minutes=5 * 60)
tzinfo = FixedOffsetTimezone(td)
self.assertEqual(tzinfo, FixedOffsetTimezone(5 * 60))
self.assertEqual(repr(tzinfo),
"psycopg2.tz.FixedOffsetTimezone(offset=%r, name=None)" % td)
def test_repr_with_name(self):
tzinfo = FixedOffsetTimezone(name="FOO")
self.assertEqual(repr(tzinfo),
"psycopg2.tz.FixedOffsetTimezone(offset=%r, name='FOO')"
% timedelta(0))
def test_instance_caching(self):
self.assert_(FixedOffsetTimezone(name="FOO")
is FixedOffsetTimezone(name="FOO"))
self.assert_(FixedOffsetTimezone(7 * 60)
is FixedOffsetTimezone(7 * 60))
self.assert_(FixedOffsetTimezone(-9 * 60, 'FOO')
is FixedOffsetTimezone(-9 * 60, 'FOO'))
self.assert_(FixedOffsetTimezone(9 * 60)
is not FixedOffsetTimezone(9 * 60, 'FOO'))
self.assert_(FixedOffsetTimezone(name='FOO')
is not FixedOffsetTimezone(9 * 60, 'FOO'))
def test_pickle(self):
# ticket #135
tz11 = FixedOffsetTimezone(60)
tz12 = FixedOffsetTimezone(120)
for proto in [-1, 0, 1, 2]:
tz21, tz22 = pickle.loads(pickle.dumps([tz11, tz12], proto))
self.assertEqual(tz11, tz21)
self.assertEqual(tz12, tz22)
tz11 = FixedOffsetTimezone(60, name='foo')
tz12 = FixedOffsetTimezone(120, name='bar')
for proto in [-1, 0, 1, 2]:
tz21, tz22 = pickle.loads(pickle.dumps([tz11, tz12], proto))
self.assertEqual(tz11, tz21)
self.assertEqual(tz12, tz22)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

74
tests/test_errcodes.py Executable file
View File

@ -0,0 +1,74 @@
#!/usr/bin/env python
# test_errcodes.py - unit test for psycopg2.errcodes module
#
# Copyright (C) 2015-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import unittest
from .testutils import ConnectingTestCase, slow, reload
from threading import Thread
from psycopg2 import errorcodes
class ErrocodeTests(ConnectingTestCase):
@slow
def test_lookup_threadsafe(self):
# Increase if it does not fail with KeyError
MAX_CYCLES = 2000
errs = []
def f(pg_code='40001'):
try:
errorcodes.lookup(pg_code)
except Exception as e:
errs.append(e)
for __ in range(MAX_CYCLES):
reload(errorcodes)
(t1, t2) = (Thread(target=f), Thread(target=f))
(t1.start(), t2.start())
(t1.join(), t2.join())
if errs:
self.fail(
"raised {} errors in {} cycles (first is {} {})".format(
len(errs), MAX_CYCLES,
errs[0].__class__.__name__, errs[0]))
def test_ambiguous_names(self):
self.assertEqual(
errorcodes.lookup('2F004'), "READING_SQL_DATA_NOT_PERMITTED")
self.assertEqual(
errorcodes.lookup('38004'), "READING_SQL_DATA_NOT_PERMITTED")
self.assertEqual(errorcodes.READING_SQL_DATA_NOT_PERMITTED, '38004')
self.assertEqual(errorcodes.READING_SQL_DATA_NOT_PERMITTED_, '2F004')
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

94
tests/test_errors.py Executable file
View File

@ -0,0 +1,94 @@
#!/usr/bin/env python
# test_errors.py - unit test for psycopg2.errors module
#
# Copyright (C) 2018-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import unittest
from .testutils import ConnectingTestCase
import psycopg2
from psycopg2 import errors
from psycopg2._psycopg import sqlstate_errors
from psycopg2.errors import UndefinedTable
class ErrorsTests(ConnectingTestCase):
def test_exception_class(self):
cur = self.conn.cursor()
try:
cur.execute("select * from nonexist")
except psycopg2.Error as exc:
e = exc
self.assert_(isinstance(e, UndefinedTable), type(e))
self.assert_(isinstance(e, self.conn.ProgrammingError))
def test_exception_class_fallback(self):
cur = self.conn.cursor()
x = sqlstate_errors.pop('42P01')
try:
cur.execute("select * from nonexist")
except psycopg2.Error as exc:
e = exc
finally:
sqlstate_errors['42P01'] = x
self.assertEqual(type(e), self.conn.ProgrammingError)
def test_lookup(self):
self.assertIs(errors.lookup('42P01'), errors.UndefinedTable)
with self.assertRaises(KeyError):
errors.lookup('XXXXX')
def test_connection_exceptions_backwards_compatibility(self):
err = errors.lookup('08000')
# connection exceptions are classified as operational errors
self.assert_(issubclass(err, errors.OperationalError))
# previously these errors were classified only as DatabaseError
self.assert_(issubclass(err, errors.DatabaseError))
def test_has_base_exceptions(self):
excs = []
for n in dir(psycopg2):
obj = getattr(psycopg2, n)
if isinstance(obj, type) and issubclass(obj, Exception):
excs.append(obj)
self.assert_(len(excs) > 8, str(excs))
excs.append(psycopg2.extensions.QueryCanceledError)
excs.append(psycopg2.extensions.TransactionRollbackError)
for exc in excs:
self.assert_(hasattr(errors, exc.__name__))
self.assert_(getattr(errors, exc.__name__) is exc)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

646
tests/test_extras_dictcursor.py Executable file
View File

@ -0,0 +1,646 @@
#!/usr/bin/env python
#
# extras_dictcursor - test if DictCursor extension class works
#
# Copyright (C) 2004-2019 Federico Di Gregorio <fog@debian.org>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import copy
import time
import pickle
import unittest
from datetime import timedelta
from functools import lru_cache
import psycopg2
import psycopg2.extras
from psycopg2.extras import NamedTupleConnection, NamedTupleCursor
from .testutils import ConnectingTestCase, skip_before_postgres, \
crdb_version, skip_if_crdb
class _DictCursorBase(ConnectingTestCase):
def setUp(self):
ConnectingTestCase.setUp(self)
curs = self.conn.cursor()
if crdb_version(self.conn) is not None:
curs.execute("SET experimental_enable_temp_tables = 'on'")
curs.execute("CREATE TEMPORARY TABLE ExtrasDictCursorTests (foo text)")
curs.execute("INSERT INTO ExtrasDictCursorTests VALUES ('bar')")
self.conn.commit()
def _testIterRowNumber(self, curs):
# Only checking for dataset < itersize:
# see CursorTests.test_iter_named_cursor_rownumber
curs.itersize = 20
curs.execute("""select * from generate_series(1,10)""")
for i, r in enumerate(curs):
self.assertEqual(i + 1, curs.rownumber)
def _testNamedCursorNotGreedy(self, curs):
curs.itersize = 2
curs.execute("""select clock_timestamp() as ts from generate_series(1,3)""")
recs = []
for t in curs:
time.sleep(0.01)
recs.append(t)
# check that the dataset was not fetched in a single gulp
self.assert_(recs[1]['ts'] - recs[0]['ts'] < timedelta(seconds=0.005))
self.assert_(recs[2]['ts'] - recs[1]['ts'] > timedelta(seconds=0.0099))
class ExtrasDictCursorTests(_DictCursorBase):
"""Test if DictCursor extension class works."""
@skip_if_crdb("named cursor")
def testDictConnCursorArgs(self):
self.conn.close()
self.conn = self.connect(connection_factory=psycopg2.extras.DictConnection)
cur = self.conn.cursor()
self.assert_(isinstance(cur, psycopg2.extras.DictCursor))
self.assertEqual(cur.name, None)
# overridable
cur = self.conn.cursor('foo',
cursor_factory=psycopg2.extras.NamedTupleCursor)
self.assertEqual(cur.name, 'foo')
self.assert_(isinstance(cur, psycopg2.extras.NamedTupleCursor))
def testDictCursorWithPlainCursorFetchOne(self):
self._testWithPlainCursor(lambda curs: curs.fetchone())
def testDictCursorWithPlainCursorFetchMany(self):
self._testWithPlainCursor(lambda curs: curs.fetchmany(100)[0])
def testDictCursorWithPlainCursorFetchManyNoarg(self):
self._testWithPlainCursor(lambda curs: curs.fetchmany()[0])
def testDictCursorWithPlainCursorFetchAll(self):
self._testWithPlainCursor(lambda curs: curs.fetchall()[0])
def testDictCursorWithPlainCursorIter(self):
def getter(curs):
for row in curs:
return row
self._testWithPlainCursor(getter)
def testUpdateRow(self):
row = self._testWithPlainCursor(lambda curs: curs.fetchone())
row['foo'] = 'qux'
self.failUnless(row['foo'] == 'qux')
self.failUnless(row[0] == 'qux')
@skip_before_postgres(8, 0)
def testDictCursorWithPlainCursorIterRowNumber(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
self._testIterRowNumber(curs)
def _testWithPlainCursor(self, getter):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs)
self.failUnless(row['foo'] == 'bar')
self.failUnless(row[0] == 'bar')
return row
def testDictCursorWithNamedCursorFetchOne(self):
self._testWithNamedCursor(lambda curs: curs.fetchone())
def testDictCursorWithNamedCursorFetchMany(self):
self._testWithNamedCursor(lambda curs: curs.fetchmany(100)[0])
def testDictCursorWithNamedCursorFetchManyNoarg(self):
self._testWithNamedCursor(lambda curs: curs.fetchmany()[0])
def testDictCursorWithNamedCursorFetchAll(self):
self._testWithNamedCursor(lambda curs: curs.fetchall()[0])
def testDictCursorWithNamedCursorIter(self):
def getter(curs):
for row in curs:
return row
self._testWithNamedCursor(getter)
@skip_if_crdb("named cursor")
@skip_before_postgres(8, 2)
def testDictCursorWithNamedCursorNotGreedy(self):
curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.DictCursor)
self._testNamedCursorNotGreedy(curs)
@skip_if_crdb("named cursor")
@skip_before_postgres(8, 0)
def testDictCursorWithNamedCursorIterRowNumber(self):
curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.DictCursor)
self._testIterRowNumber(curs)
@skip_if_crdb("named cursor")
def _testWithNamedCursor(self, getter):
curs = self.conn.cursor('aname', cursor_factory=psycopg2.extras.DictCursor)
curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs)
self.failUnless(row['foo'] == 'bar')
self.failUnless(row[0] == 'bar')
def testPickleDictRow(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
curs.execute("select 10 as a, 20 as b")
r = curs.fetchone()
d = pickle.dumps(r)
r1 = pickle.loads(d)
self.assertEqual(r, r1)
self.assertEqual(r[0], r1[0])
self.assertEqual(r[1], r1[1])
self.assertEqual(r['a'], r1['a'])
self.assertEqual(r['b'], r1['b'])
self.assertEqual(r._index, r1._index)
def test_copy(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
curs.execute("select 10 as foo, 'hi' as bar")
rv = curs.fetchone()
self.assertEqual(len(rv), 2)
rv2 = copy.copy(rv)
self.assertEqual(len(rv2), 2)
self.assertEqual(len(rv), 2)
rv3 = copy.deepcopy(rv)
self.assertEqual(len(rv3), 2)
self.assertEqual(len(rv), 2)
def test_iter_methods(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
curs.execute("select 10 as a, 20 as b")
r = curs.fetchone()
self.assert_(not isinstance(r.keys(), list))
self.assertEqual(len(list(r.keys())), 2)
self.assert_(not isinstance(r.values(), list))
self.assertEqual(len(list(r.values())), 2)
self.assert_(not isinstance(r.items(), list))
self.assertEqual(len(list(r.items())), 2)
def test_order(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux")
r = curs.fetchone()
self.assertEqual(list(r), [5, 4, 33, 2])
self.assertEqual(list(r.keys()), ['foo', 'bar', 'baz', 'qux'])
self.assertEqual(list(r.values()), [5, 4, 33, 2])
self.assertEqual(list(r.items()),
[('foo', 5), ('bar', 4), ('baz', 33), ('qux', 2)])
r1 = pickle.loads(pickle.dumps(r))
self.assertEqual(list(r1), list(r))
self.assertEqual(list(r1.keys()), list(r.keys()))
self.assertEqual(list(r1.values()), list(r.values()))
self.assertEqual(list(r1.items()), list(r.items()))
class ExtrasDictCursorRealTests(_DictCursorBase):
def testRealMeansReal(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = curs.fetchone()
self.assert_(isinstance(row, dict))
def testDictCursorWithPlainCursorRealFetchOne(self):
self._testWithPlainCursorReal(lambda curs: curs.fetchone())
def testDictCursorWithPlainCursorRealFetchMany(self):
self._testWithPlainCursorReal(lambda curs: curs.fetchmany(100)[0])
def testDictCursorWithPlainCursorRealFetchManyNoarg(self):
self._testWithPlainCursorReal(lambda curs: curs.fetchmany()[0])
def testDictCursorWithPlainCursorRealFetchAll(self):
self._testWithPlainCursorReal(lambda curs: curs.fetchall()[0])
def testDictCursorWithPlainCursorRealIter(self):
def getter(curs):
for row in curs:
return row
self._testWithPlainCursorReal(getter)
@skip_before_postgres(8, 0)
def testDictCursorWithPlainCursorRealIterRowNumber(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
self._testIterRowNumber(curs)
def _testWithPlainCursorReal(self, getter):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs)
self.failUnless(row['foo'] == 'bar')
def testPickleRealDictRow(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("select 10 as a, 20 as b")
r = curs.fetchone()
d = pickle.dumps(r)
r1 = pickle.loads(d)
self.assertEqual(r, r1)
self.assertEqual(r['a'], r1['a'])
self.assertEqual(r['b'], r1['b'])
def test_copy(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("select 10 as foo, 'hi' as bar")
rv = curs.fetchone()
self.assertEqual(len(rv), 2)
rv2 = copy.copy(rv)
self.assertEqual(len(rv2), 2)
self.assertEqual(len(rv), 2)
rv3 = copy.deepcopy(rv)
self.assertEqual(len(rv3), 2)
self.assertEqual(len(rv), 2)
def testDictCursorRealWithNamedCursorFetchOne(self):
self._testWithNamedCursorReal(lambda curs: curs.fetchone())
def testDictCursorRealWithNamedCursorFetchMany(self):
self._testWithNamedCursorReal(lambda curs: curs.fetchmany(100)[0])
def testDictCursorRealWithNamedCursorFetchManyNoarg(self):
self._testWithNamedCursorReal(lambda curs: curs.fetchmany()[0])
def testDictCursorRealWithNamedCursorFetchAll(self):
self._testWithNamedCursorReal(lambda curs: curs.fetchall()[0])
def testDictCursorRealWithNamedCursorIter(self):
def getter(curs):
for row in curs:
return row
self._testWithNamedCursorReal(getter)
@skip_if_crdb("named cursor")
@skip_before_postgres(8, 2)
def testDictCursorRealWithNamedCursorNotGreedy(self):
curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.RealDictCursor)
self._testNamedCursorNotGreedy(curs)
@skip_if_crdb("named cursor")
@skip_before_postgres(8, 0)
def testDictCursorRealWithNamedCursorIterRowNumber(self):
curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.RealDictCursor)
self._testIterRowNumber(curs)
@skip_if_crdb("named cursor")
def _testWithNamedCursorReal(self, getter):
curs = self.conn.cursor('aname',
cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs)
self.failUnless(row['foo'] == 'bar')
def test_iter_methods(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("select 10 as a, 20 as b")
r = curs.fetchone()
self.assert_(not isinstance(r.keys(), list))
self.assertEqual(len(list(r.keys())), 2)
self.assert_(not isinstance(r.values(), list))
self.assertEqual(len(list(r.values())), 2)
self.assert_(not isinstance(r.items(), list))
self.assertEqual(len(list(r.items())), 2)
def test_order(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux")
r = curs.fetchone()
self.assertEqual(list(r), ['foo', 'bar', 'baz', 'qux'])
self.assertEqual(list(r.keys()), ['foo', 'bar', 'baz', 'qux'])
self.assertEqual(list(r.values()), [5, 4, 33, 2])
self.assertEqual(list(r.items()),
[('foo', 5), ('bar', 4), ('baz', 33), ('qux', 2)])
r1 = pickle.loads(pickle.dumps(r))
self.assertEqual(list(r1), list(r))
self.assertEqual(list(r1.keys()), list(r.keys()))
self.assertEqual(list(r1.values()), list(r.values()))
self.assertEqual(list(r1.items()), list(r.items()))
def test_pop(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("select 1 as a, 2 as b, 3 as c")
r = curs.fetchone()
self.assertEqual(r.pop('b'), 2)
self.assertEqual(list(r), ['a', 'c'])
self.assertEqual(list(r.keys()), ['a', 'c'])
self.assertEqual(list(r.values()), [1, 3])
self.assertEqual(list(r.items()), [('a', 1), ('c', 3)])
self.assertEqual(r.pop('b', None), None)
self.assertRaises(KeyError, r.pop, 'b')
def test_mod(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("select 1 as a, 2 as b, 3 as c")
r = curs.fetchone()
r['d'] = 4
self.assertEqual(list(r), ['a', 'b', 'c', 'd'])
self.assertEqual(list(r.keys()), ['a', 'b', 'c', 'd'])
self.assertEqual(list(r.values()), [1, 2, 3, 4])
self.assertEqual(list(
r.items()), [('a', 1), ('b', 2), ('c', 3), ('d', 4)])
assert r['a'] == 1
assert r['b'] == 2
assert r['c'] == 3
assert r['d'] == 4
class NamedTupleCursorTest(ConnectingTestCase):
def setUp(self):
ConnectingTestCase.setUp(self)
self.conn = self.connect(connection_factory=NamedTupleConnection)
curs = self.conn.cursor()
if crdb_version(self.conn) is not None:
curs.execute("SET experimental_enable_temp_tables = 'on'")
curs.execute("CREATE TEMPORARY TABLE nttest (i int, s text)")
curs.execute("INSERT INTO nttest VALUES (1, 'foo')")
curs.execute("INSERT INTO nttest VALUES (2, 'bar')")
curs.execute("INSERT INTO nttest VALUES (3, 'baz')")
self.conn.commit()
@skip_if_crdb("named cursor")
def test_cursor_args(self):
cur = self.conn.cursor('foo', cursor_factory=psycopg2.extras.DictCursor)
self.assertEqual(cur.name, 'foo')
self.assert_(isinstance(cur, psycopg2.extras.DictCursor))
def test_fetchone(self):
curs = self.conn.cursor()
curs.execute("select * from nttest order by 1")
t = curs.fetchone()
self.assertEqual(t[0], 1)
self.assertEqual(t.i, 1)
self.assertEqual(t[1], 'foo')
self.assertEqual(t.s, 'foo')
self.assertEqual(curs.rownumber, 1)
self.assertEqual(curs.rowcount, 3)
def test_fetchmany_noarg(self):
curs = self.conn.cursor()
curs.arraysize = 2
curs.execute("select * from nttest order by 1")
res = curs.fetchmany()
self.assertEqual(2, len(res))
self.assertEqual(res[0].i, 1)
self.assertEqual(res[0].s, 'foo')
self.assertEqual(res[1].i, 2)
self.assertEqual(res[1].s, 'bar')
self.assertEqual(curs.rownumber, 2)
self.assertEqual(curs.rowcount, 3)
def test_fetchmany(self):
curs = self.conn.cursor()
curs.execute("select * from nttest order by 1")
res = curs.fetchmany(2)
self.assertEqual(2, len(res))
self.assertEqual(res[0].i, 1)
self.assertEqual(res[0].s, 'foo')
self.assertEqual(res[1].i, 2)
self.assertEqual(res[1].s, 'bar')
self.assertEqual(curs.rownumber, 2)
self.assertEqual(curs.rowcount, 3)
def test_fetchall(self):
curs = self.conn.cursor()
curs.execute("select * from nttest order by 1")
res = curs.fetchall()
self.assertEqual(3, len(res))
self.assertEqual(res[0].i, 1)
self.assertEqual(res[0].s, 'foo')
self.assertEqual(res[1].i, 2)
self.assertEqual(res[1].s, 'bar')
self.assertEqual(res[2].i, 3)
self.assertEqual(res[2].s, 'baz')
self.assertEqual(curs.rownumber, 3)
self.assertEqual(curs.rowcount, 3)
def test_executemany(self):
curs = self.conn.cursor()
curs.executemany("delete from nttest where i = %s",
[(1,), (2,)])
curs.execute("select * from nttest order by 1")
res = curs.fetchall()
self.assertEqual(1, len(res))
self.assertEqual(res[0].i, 3)
self.assertEqual(res[0].s, 'baz')
def test_iter(self):
curs = self.conn.cursor()
curs.execute("select * from nttest order by 1")
i = iter(curs)
self.assertEqual(curs.rownumber, 0)
t = next(i)
self.assertEqual(t.i, 1)
self.assertEqual(t.s, 'foo')
self.assertEqual(curs.rownumber, 1)
self.assertEqual(curs.rowcount, 3)
t = next(i)
self.assertEqual(t.i, 2)
self.assertEqual(t.s, 'bar')
self.assertEqual(curs.rownumber, 2)
self.assertEqual(curs.rowcount, 3)
t = next(i)
self.assertEqual(t.i, 3)
self.assertEqual(t.s, 'baz')
self.assertRaises(StopIteration, next, i)
self.assertEqual(curs.rownumber, 3)
self.assertEqual(curs.rowcount, 3)
def test_record_updated(self):
curs = self.conn.cursor()
curs.execute("select 1 as foo;")
r = curs.fetchone()
self.assertEqual(r.foo, 1)
curs.execute("select 2 as bar;")
r = curs.fetchone()
self.assertEqual(r.bar, 2)
self.assertRaises(AttributeError, getattr, r, 'foo')
def test_no_result_no_surprise(self):
curs = self.conn.cursor()
curs.execute("update nttest set s = s")
self.assertRaises(psycopg2.ProgrammingError, curs.fetchone)
curs.execute("update nttest set s = s")
self.assertRaises(psycopg2.ProgrammingError, curs.fetchall)
def test_bad_col_names(self):
curs = self.conn.cursor()
curs.execute('select 1 as "foo.bar_baz", 2 as "?column?", 3 as "3"')
rv = curs.fetchone()
self.assertEqual(rv.foo_bar_baz, 1)
self.assertEqual(rv.f_column_, 2)
self.assertEqual(rv.f3, 3)
@skip_before_postgres(8)
def test_nonascii_name(self):
curs = self.conn.cursor()
curs.execute('select 1 as \xe5h\xe9')
rv = curs.fetchone()
self.assertEqual(getattr(rv, '\xe5h\xe9'), 1)
def test_minimal_generation(self):
# Instrument the class to verify it gets called the minimum number of times.
f_orig = NamedTupleCursor._make_nt
calls = [0]
def f_patched(self_):
calls[0] += 1
return f_orig(self_)
NamedTupleCursor._make_nt = f_patched
try:
curs = self.conn.cursor()
curs.execute("select * from nttest order by 1")
curs.fetchone()
curs.fetchone()
curs.fetchone()
self.assertEqual(1, calls[0])
curs.execute("select * from nttest order by 1")
curs.fetchone()
curs.fetchall()
self.assertEqual(2, calls[0])
curs.execute("select * from nttest order by 1")
curs.fetchone()
curs.fetchmany(1)
self.assertEqual(3, calls[0])
finally:
NamedTupleCursor._make_nt = f_orig
@skip_if_crdb("named cursor")
@skip_before_postgres(8, 0)
def test_named(self):
curs = self.conn.cursor('tmp')
curs.execute("""select i from generate_series(0,9) i""")
recs = []
recs.extend(curs.fetchmany(5))
recs.append(curs.fetchone())
recs.extend(curs.fetchall())
self.assertEqual(list(range(10)), [t.i for t in recs])
@skip_if_crdb("named cursor")
def test_named_fetchone(self):
curs = self.conn.cursor('tmp')
curs.execute("""select 42 as i""")
t = curs.fetchone()
self.assertEqual(t.i, 42)
@skip_if_crdb("named cursor")
def test_named_fetchmany(self):
curs = self.conn.cursor('tmp')
curs.execute("""select 42 as i""")
recs = curs.fetchmany(10)
self.assertEqual(recs[0].i, 42)
@skip_if_crdb("named cursor")
def test_named_fetchall(self):
curs = self.conn.cursor('tmp')
curs.execute("""select 42 as i""")
recs = curs.fetchall()
self.assertEqual(recs[0].i, 42)
@skip_if_crdb("named cursor")
@skip_before_postgres(8, 2)
def test_not_greedy(self):
curs = self.conn.cursor('tmp')
curs.itersize = 2
curs.execute("""select clock_timestamp() as ts from generate_series(1,3)""")
recs = []
for t in curs:
time.sleep(0.01)
recs.append(t)
# check that the dataset was not fetched in a single gulp
self.assert_(recs[1].ts - recs[0].ts < timedelta(seconds=0.005))
self.assert_(recs[2].ts - recs[1].ts > timedelta(seconds=0.0099))
@skip_if_crdb("named cursor")
@skip_before_postgres(8, 0)
def test_named_rownumber(self):
curs = self.conn.cursor('tmp')
# Only checking for dataset < itersize:
# see CursorTests.test_iter_named_cursor_rownumber
curs.itersize = 4
curs.execute("""select * from generate_series(1,3)""")
for i, t in enumerate(curs):
self.assertEqual(i + 1, curs.rownumber)
def test_cache(self):
NamedTupleCursor._cached_make_nt.cache_clear()
curs = self.conn.cursor()
curs.execute("select 10 as a, 20 as b")
r1 = curs.fetchone()
curs.execute("select 10 as a, 20 as c")
r2 = curs.fetchone()
# Get a new cursor to check that the cache works across multiple ones
curs = self.conn.cursor()
curs.execute("select 10 as a, 30 as b")
r3 = curs.fetchone()
self.assert_(type(r1) is type(r3))
self.assert_(type(r1) is not type(r2))
cache_info = NamedTupleCursor._cached_make_nt.cache_info()
self.assertEqual(cache_info.hits, 1)
self.assertEqual(cache_info.misses, 2)
self.assertEqual(cache_info.currsize, 2)
def test_max_cache(self):
old_func = NamedTupleCursor._cached_make_nt
NamedTupleCursor._cached_make_nt = \
lru_cache(8)(NamedTupleCursor._cached_make_nt.__wrapped__)
try:
recs = []
curs = self.conn.cursor()
for i in range(10):
curs.execute(f"select 1 as f{i}")
recs.append(curs.fetchone())
# Still in cache
curs.execute("select 1 as f9")
rec = curs.fetchone()
self.assert_(any(type(r) is type(rec) for r in recs))
# Gone from cache
curs.execute("select 1 as f0")
rec = curs.fetchone()
self.assert_(all(type(r) is not type(rec) for r in recs))
finally:
NamedTupleCursor._cached_make_nt = old_func
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

269
tests/test_fast_executemany.py Executable file
View File

@ -0,0 +1,269 @@
#!/usr/bin/env python
#
# test_fast_executemany.py - tests for fast executemany implementations
#
# Copyright (C) 2017-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
from datetime import date
from . import testutils
import unittest
import psycopg2
import psycopg2.extras
import psycopg2.extensions as ext
from psycopg2 import sql
class TestPaginate(unittest.TestCase):
def test_paginate(self):
def pag(seq):
return psycopg2.extras._paginate(seq, 100)
self.assertEqual(list(pag([])), [])
self.assertEqual(list(pag([1])), [[1]])
self.assertEqual(list(pag(range(99))), [list(range(99))])
self.assertEqual(list(pag(range(100))), [list(range(100))])
self.assertEqual(list(pag(range(101))), [list(range(100)), [100]])
self.assertEqual(
list(pag(range(200))), [list(range(100)), list(range(100, 200))])
self.assertEqual(
list(pag(range(1000))),
[list(range(i * 100, (i + 1) * 100)) for i in range(10)])
class FastExecuteTestMixin:
def setUp(self):
super().setUp()
cur = self.conn.cursor()
cur.execute("""create table testfast (
id serial primary key, date date, val int, data text)""")
class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase):
def test_empty(self):
cur = self.conn.cursor()
psycopg2.extras.execute_batch(cur,
"insert into testfast (id, val) values (%s, %s)",
[])
cur.execute("select * from testfast order by id")
self.assertEqual(cur.fetchall(), [])
def test_one(self):
cur = self.conn.cursor()
psycopg2.extras.execute_batch(cur,
"insert into testfast (id, val) values (%s, %s)",
iter([(1, 10)]))
cur.execute("select id, val from testfast order by id")
self.assertEqual(cur.fetchall(), [(1, 10)])
def test_tuples(self):
cur = self.conn.cursor()
psycopg2.extras.execute_batch(cur,
"insert into testfast (id, date, val) values (%s, %s, %s)",
((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
cur.execute("select id, date, val from testfast order by id")
self.assertEqual(cur.fetchall(),
[(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
def test_many(self):
cur = self.conn.cursor()
psycopg2.extras.execute_batch(cur,
"insert into testfast (id, val) values (%s, %s)",
((i, i * 10) for i in range(1000)))
cur.execute("select id, val from testfast order by id")
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
def test_composed(self):
cur = self.conn.cursor()
psycopg2.extras.execute_batch(cur,
sql.SQL("insert into {0} (id, val) values (%s, %s)")
.format(sql.Identifier('testfast')),
((i, i * 10) for i in range(1000)))
cur.execute("select id, val from testfast order by id")
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
def test_pages(self):
cur = self.conn.cursor()
psycopg2.extras.execute_batch(cur,
"insert into testfast (id, val) values (%s, %s)",
((i, i * 10) for i in range(25)),
page_size=10)
# last command was 5 statements
self.assertEqual(sum(c == ';' for c in cur.query.decode('ascii')), 4)
cur.execute("select id, val from testfast order by id")
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])
@testutils.skip_before_postgres(8, 0)
def test_unicode(self):
cur = self.conn.cursor()
ext.register_type(ext.UNICODE, cur)
snowman = "\u2603"
# unicode in statement
psycopg2.extras.execute_batch(cur,
"insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman,
[(1, 'x')])
cur.execute("select id, data from testfast where id = 1")
self.assertEqual(cur.fetchone(), (1, 'x'))
# unicode in data
psycopg2.extras.execute_batch(cur,
"insert into testfast (id, data) values (%s, %s)",
[(2, snowman)])
cur.execute("select id, data from testfast where id = 2")
self.assertEqual(cur.fetchone(), (2, snowman))
# unicode in both
psycopg2.extras.execute_batch(cur,
"insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman,
[(3, snowman)])
cur.execute("select id, data from testfast where id = 3")
self.assertEqual(cur.fetchone(), (3, snowman))
@testutils.skip_before_postgres(8, 2)
class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase):
def test_empty(self):
cur = self.conn.cursor()
psycopg2.extras.execute_values(cur,
"insert into testfast (id, val) values %s",
[])
cur.execute("select * from testfast order by id")
self.assertEqual(cur.fetchall(), [])
def test_one(self):
cur = self.conn.cursor()
psycopg2.extras.execute_values(cur,
"insert into testfast (id, val) values %s",
iter([(1, 10)]))
cur.execute("select id, val from testfast order by id")
self.assertEqual(cur.fetchall(), [(1, 10)])
def test_tuples(self):
cur = self.conn.cursor()
psycopg2.extras.execute_values(cur,
"insert into testfast (id, date, val) values %s",
((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
cur.execute("select id, date, val from testfast order by id")
self.assertEqual(cur.fetchall(),
[(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
def test_dicts(self):
cur = self.conn.cursor()
psycopg2.extras.execute_values(cur,
"insert into testfast (id, date, val) values %s",
(dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar")
for i in range(10)),
template='(%(id)s, %(date)s, %(val)s)')
cur.execute("select id, date, val from testfast order by id")
self.assertEqual(cur.fetchall(),
[(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
def test_many(self):
cur = self.conn.cursor()
psycopg2.extras.execute_values(cur,
"insert into testfast (id, val) values %s",
((i, i * 10) for i in range(1000)))
cur.execute("select id, val from testfast order by id")
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
def test_composed(self):
cur = self.conn.cursor()
psycopg2.extras.execute_values(cur,
sql.SQL("insert into {0} (id, val) values %s")
.format(sql.Identifier('testfast')),
((i, i * 10) for i in range(1000)))
cur.execute("select id, val from testfast order by id")
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
def test_pages(self):
cur = self.conn.cursor()
psycopg2.extras.execute_values(cur,
"insert into testfast (id, val) values %s",
((i, i * 10) for i in range(25)),
page_size=10)
# last statement was 5 tuples (one parens is for the fields list)
self.assertEqual(sum(c == '(' for c in cur.query.decode('ascii')), 6)
cur.execute("select id, val from testfast order by id")
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])
def test_unicode(self):
cur = self.conn.cursor()
ext.register_type(ext.UNICODE, cur)
snowman = "\u2603"
# unicode in statement
psycopg2.extras.execute_values(cur,
"insert into testfast (id, data) values %%s -- %s" % snowman,
[(1, 'x')])
cur.execute("select id, data from testfast where id = 1")
self.assertEqual(cur.fetchone(), (1, 'x'))
# unicode in data
psycopg2.extras.execute_values(cur,
"insert into testfast (id, data) values %s",
[(2, snowman)])
cur.execute("select id, data from testfast where id = 2")
self.assertEqual(cur.fetchone(), (2, snowman))
# unicode in both
psycopg2.extras.execute_values(cur,
"insert into testfast (id, data) values %%s -- %s" % snowman,
[(3, snowman)])
cur.execute("select id, data from testfast where id = 3")
self.assertEqual(cur.fetchone(), (3, snowman))
def test_returning(self):
cur = self.conn.cursor()
result = psycopg2.extras.execute_values(cur,
"insert into testfast (id, val) values %s returning id",
((i, i * 10) for i in range(25)),
page_size=10, fetch=True)
# result contains all returned pages
self.assertEqual([r[0] for r in result], list(range(25)))
def test_invalid_sql(self):
cur = self.conn.cursor()
self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
"insert", [])
self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
"insert %s and %s", [])
self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
"insert %f", [])
self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
"insert %f %s", [])
def test_percent_escape(self):
cur = self.conn.cursor()
psycopg2.extras.execute_values(cur,
"insert into testfast (id, data) values %s -- a%%b",
[(1, 'hi')])
self.assert_(b'a%%b' not in cur.query)
self.assert_(b'a%b' in cur.query)
cur.execute("select id, data from testfast")
self.assertEqual(cur.fetchall(), [(1, 'hi')])
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

246
tests/test_green.py Executable file
View File

@ -0,0 +1,246 @@
#!/usr/bin/env python
# test_green.py - unit test for async wait callback
#
# Copyright (C) 2010-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import select
import unittest
import warnings
import psycopg2
import psycopg2.extensions
import psycopg2.extras
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE
from .testutils import ConnectingTestCase, skip_before_postgres, slow
from .testutils import skip_if_crdb
class ConnectionStub:
"""A `connection` wrapper allowing analysis of the `poll()` calls."""
def __init__(self, conn):
self.conn = conn
self.polls = []
def fileno(self):
return self.conn.fileno()
def poll(self):
rv = self.conn.poll()
self.polls.append(rv)
return rv
class GreenTestCase(ConnectingTestCase):
def setUp(self):
self._cb = psycopg2.extensions.get_wait_callback()
psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select)
ConnectingTestCase.setUp(self)
def tearDown(self):
ConnectingTestCase.tearDown(self)
psycopg2.extensions.set_wait_callback(self._cb)
def set_stub_wait_callback(self, conn, cb=None):
stub = ConnectionStub(conn)
psycopg2.extensions.set_wait_callback(
lambda conn: (cb or psycopg2.extras.wait_select)(stub))
return stub
@slow
@skip_if_crdb("flush on write flakey")
def test_flush_on_write(self):
# a very large query requires a flush loop to be sent to the backend
conn = self.conn
stub = self.set_stub_wait_callback(conn)
curs = conn.cursor()
for mb in 1, 5, 10, 20, 50:
size = mb * 1024 * 1024
del stub.polls[:]
curs.execute("select %s;", ('x' * size,))
self.assertEqual(size, len(curs.fetchone()[0]))
if stub.polls.count(psycopg2.extensions.POLL_WRITE) > 1:
return
# This is more a testing glitch than an error: it happens
# on high load on linux: probably because the kernel has more
# buffers ready. A warning may be useful during development,
# but an error is bad during regression testing.
warnings.warn("sending a large query didn't trigger block on write.")
def test_error_in_callback(self):
# behaviour changed after issue #113: if there is an error in the
# callback for the moment we don't have a way to reset the connection
# without blocking (ticket #113) so just close it.
conn = self.conn
curs = conn.cursor()
curs.execute("select 1") # have a BEGIN
curs.fetchone()
# now try to do something that will fail in the callback
psycopg2.extensions.set_wait_callback(lambda conn: 1 // 0)
self.assertRaises(ZeroDivisionError, curs.execute, "select 2")
self.assert_(conn.closed)
def test_dont_freak_out(self):
# if there is an error in a green query, don't freak out and close
# the connection
conn = self.conn
curs = conn.cursor()
self.assertRaises(psycopg2.ProgrammingError,
curs.execute, "select the unselectable")
# check that the connection is left in an usable state
self.assert_(not conn.closed)
conn.rollback()
curs.execute("select 1")
self.assertEqual(curs.fetchone()[0], 1)
@skip_before_postgres(8, 2)
def test_copy_no_hang(self):
cur = self.conn.cursor()
self.assertRaises(psycopg2.ProgrammingError,
cur.execute, "copy (select 1) to stdout")
@slow
@skip_if_crdb("notice")
@skip_before_postgres(9, 0)
def test_non_block_after_notice(self):
def wait(conn):
while 1:
state = conn.poll()
if state == POLL_OK:
break
elif state == POLL_READ:
select.select([conn.fileno()], [], [], 0.1)
elif state == POLL_WRITE:
select.select([], [conn.fileno()], [], 0.1)
else:
raise conn.OperationalError(f"bad state from poll: {state}")
stub = self.set_stub_wait_callback(self.conn, wait)
cur = self.conn.cursor()
cur.execute("""
select 1;
do $$
begin
raise notice 'hello';
end
$$ language plpgsql;
select pg_sleep(1);
""")
polls = stub.polls.count(POLL_READ)
self.assert_(polls > 8, polls)
class CallbackErrorTestCase(ConnectingTestCase):
def setUp(self):
self._cb = psycopg2.extensions.get_wait_callback()
psycopg2.extensions.set_wait_callback(self.crappy_callback)
ConnectingTestCase.setUp(self)
self.to_error = None
def tearDown(self):
ConnectingTestCase.tearDown(self)
psycopg2.extensions.set_wait_callback(self._cb)
def crappy_callback(self, conn):
"""green callback failing after `self.to_error` time it is called"""
while True:
if self.to_error is not None:
self.to_error -= 1
if self.to_error <= 0:
raise ZeroDivisionError("I accidentally the connection")
try:
state = conn.poll()
if state == POLL_OK:
break
elif state == POLL_READ:
select.select([conn.fileno()], [], [])
elif state == POLL_WRITE:
select.select([], [conn.fileno()], [])
else:
raise conn.OperationalError(f"bad state from poll: {state}")
except KeyboardInterrupt:
conn.cancel()
# the loop will be broken by a server error
continue
def test_errors_on_connection(self):
# Test error propagation in the different stages of the connection
for i in range(100):
self.to_error = i
try:
self.connect()
except ZeroDivisionError:
pass
else:
# We managed to connect
return
self.fail("you should have had a success or an error by now")
def test_errors_on_query(self):
for i in range(100):
self.to_error = None
cnn = self.connect()
cur = cnn.cursor()
self.to_error = i
try:
cur.execute("select 1")
cur.fetchone()
except ZeroDivisionError:
pass
else:
# The query completed
return
self.fail("you should have had a success or an error by now")
@skip_if_crdb("named cursor")
def test_errors_named_cursor(self):
for i in range(100):
self.to_error = None
cnn = self.connect()
cur = cnn.cursor('foo')
self.to_error = i
try:
cur.execute("select 1")
cur.fetchone()
except ZeroDivisionError:
pass
else:
# The query completed
return
self.fail("you should have had a success or an error by now")
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

120
tests/test_ipaddress.py Executable file
View File

@ -0,0 +1,120 @@
#!/usr/bin/env python
#
# test_ipaddress.py - tests for ipaddress support
#
# Copyright (C) 2016-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
from . import testutils
import unittest
import psycopg2
import psycopg2.extras
try:
import ipaddress as ip
except ImportError:
# Python 2
ip = None
@unittest.skipIf(ip is None, "'ipaddress' module not available")
class NetworkingTestCase(testutils.ConnectingTestCase):
def test_inet_cast(self):
cur = self.conn.cursor()
psycopg2.extras.register_ipaddress(cur)
cur.execute("select null::inet")
self.assert_(cur.fetchone()[0] is None)
cur.execute("select '127.0.0.1/24'::inet")
obj = cur.fetchone()[0]
self.assert_(isinstance(obj, ip.IPv4Interface), repr(obj))
self.assertEquals(obj, ip.ip_interface('127.0.0.1/24'))
cur.execute("select '::ffff:102:300/128'::inet")
obj = cur.fetchone()[0]
self.assert_(isinstance(obj, ip.IPv6Interface), repr(obj))
self.assertEquals(obj, ip.ip_interface('::ffff:102:300/128'))
@testutils.skip_before_postgres(8, 2)
def test_inet_array_cast(self):
cur = self.conn.cursor()
psycopg2.extras.register_ipaddress(cur)
cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::inet[]")
l = cur.fetchone()[0]
self.assert_(l[0] is None)
self.assertEquals(l[1], ip.ip_interface('127.0.0.1'))
self.assertEquals(l[2], ip.ip_interface('::ffff:102:300/128'))
self.assert_(isinstance(l[1], ip.IPv4Interface), l)
self.assert_(isinstance(l[2], ip.IPv6Interface), l)
def test_inet_adapt(self):
cur = self.conn.cursor()
psycopg2.extras.register_ipaddress(cur)
cur.execute("select %s", [ip.ip_interface('127.0.0.1/24')])
self.assertEquals(cur.fetchone()[0], '127.0.0.1/24')
cur.execute("select %s", [ip.ip_interface('::ffff:102:300/128')])
self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128')
@testutils.skip_if_crdb("cidr")
def test_cidr_cast(self):
cur = self.conn.cursor()
psycopg2.extras.register_ipaddress(cur)
cur.execute("select null::cidr")
self.assert_(cur.fetchone()[0] is None)
cur.execute("select '127.0.0.0/24'::cidr")
obj = cur.fetchone()[0]
self.assert_(isinstance(obj, ip.IPv4Network), repr(obj))
self.assertEquals(obj, ip.ip_network('127.0.0.0/24'))
cur.execute("select '::ffff:102:300/128'::cidr")
obj = cur.fetchone()[0]
self.assert_(isinstance(obj, ip.IPv6Network), repr(obj))
self.assertEquals(obj, ip.ip_network('::ffff:102:300/128'))
@testutils.skip_if_crdb("cidr")
@testutils.skip_before_postgres(8, 2)
def test_cidr_array_cast(self):
cur = self.conn.cursor()
psycopg2.extras.register_ipaddress(cur)
cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::cidr[]")
l = cur.fetchone()[0]
self.assert_(l[0] is None)
self.assertEquals(l[1], ip.ip_network('127.0.0.1'))
self.assertEquals(l[2], ip.ip_network('::ffff:102:300/128'))
self.assert_(isinstance(l[1], ip.IPv4Network), l)
self.assert_(isinstance(l[2], ip.IPv6Network), l)
def test_cidr_adapt(self):
cur = self.conn.cursor()
psycopg2.extras.register_ipaddress(cur)
cur.execute("select %s", [ip.ip_network('127.0.0.0/24')])
self.assertEquals(cur.fetchone()[0], '127.0.0.0/24')
cur.execute("select %s", [ip.ip_network('::ffff:102:300/128')])
self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128')
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

530
tests/test_lobject.py Executable file
View File

@ -0,0 +1,530 @@
#!/usr/bin/env python
# test_lobject.py - unit test for large objects support
#
# Copyright (C) 2008-2019 James Henstridge <james@jamesh.id.au>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import os
import shutil
import tempfile
from functools import wraps
import psycopg2
import psycopg2.extensions
import unittest
from .testutils import (decorate_all_tests, skip_if_tpc_disabled,
skip_before_postgres, ConnectingTestCase, skip_if_green, skip_if_crdb, slow)
def skip_if_no_lo(f):
f = skip_before_postgres(8, 1, "large objects only supported from PG 8.1")(f)
f = skip_if_green("libpq doesn't support LO in async mode")(f)
f = skip_if_crdb("large objects")(f)
return f
class LargeObjectTestCase(ConnectingTestCase):
def setUp(self):
ConnectingTestCase.setUp(self)
self.lo_oid = None
self.tmpdir = None
def tearDown(self):
if self.tmpdir:
shutil.rmtree(self.tmpdir, ignore_errors=True)
if self.conn.closed:
return
if self.lo_oid is not None:
self.conn.rollback()
try:
lo = self.conn.lobject(self.lo_oid, "n")
except psycopg2.OperationalError:
pass
else:
lo.unlink()
ConnectingTestCase.tearDown(self)
@skip_if_no_lo
class LargeObjectTests(LargeObjectTestCase):
def test_create(self):
lo = self.conn.lobject()
self.assertNotEqual(lo, None)
self.assertEqual(lo.mode[0], "w")
def test_connection_needed(self):
self.assertRaises(TypeError,
psycopg2.extensions.lobject, [])
def test_open_non_existent(self):
# By creating then removing a large object, we get an Oid that
# should be unused.
lo = self.conn.lobject()
lo.unlink()
self.assertRaises(psycopg2.OperationalError, self.conn.lobject, lo.oid)
def test_open_existing(self):
lo = self.conn.lobject()
lo2 = self.conn.lobject(lo.oid)
self.assertNotEqual(lo2, None)
self.assertEqual(lo2.oid, lo.oid)
self.assertEqual(lo2.mode[0], "r")
def test_open_for_write(self):
lo = self.conn.lobject()
lo2 = self.conn.lobject(lo.oid, "w")
self.assertEqual(lo2.mode[0], "w")
lo2.write(b"some data")
def test_open_mode_n(self):
# Openning an object in mode "n" gives us a closed lobject.
lo = self.conn.lobject()
lo.close()
lo2 = self.conn.lobject(lo.oid, "n")
self.assertEqual(lo2.oid, lo.oid)
self.assertEqual(lo2.closed, True)
def test_mode_defaults(self):
lo = self.conn.lobject()
lo2 = self.conn.lobject(mode=None)
lo3 = self.conn.lobject(mode="")
self.assertEqual(lo.mode, lo2.mode)
self.assertEqual(lo.mode, lo3.mode)
def test_close_connection_gone(self):
lo = self.conn.lobject()
self.conn.close()
lo.close()
def test_create_with_oid(self):
# Create and delete a large object to get an unused Oid.
lo = self.conn.lobject()
oid = lo.oid
lo.unlink()
lo = self.conn.lobject(0, "w", oid)
self.assertEqual(lo.oid, oid)
def test_create_with_existing_oid(self):
lo = self.conn.lobject()
lo.close()
self.assertRaises(psycopg2.OperationalError,
self.conn.lobject, 0, "w", lo.oid)
self.assert_(not self.conn.closed)
def test_import(self):
self.tmpdir = tempfile.mkdtemp()
filename = os.path.join(self.tmpdir, "data.txt")
fp = open(filename, "wb")
fp.write(b"some data")
fp.close()
lo = self.conn.lobject(0, "r", 0, filename)
self.assertEqual(lo.read(), "some data")
def test_close(self):
lo = self.conn.lobject()
self.assertEqual(lo.closed, False)
lo.close()
self.assertEqual(lo.closed, True)
def test_write(self):
lo = self.conn.lobject()
self.assertEqual(lo.write(b"some data"), len("some data"))
def test_write_large(self):
lo = self.conn.lobject()
data = "data" * 1000000
self.assertEqual(lo.write(data), len(data))
def test_read(self):
lo = self.conn.lobject()
lo.write(b"some data")
lo.close()
lo = self.conn.lobject(lo.oid)
x = lo.read(4)
self.assertEqual(type(x), type(''))
self.assertEqual(x, "some")
self.assertEqual(lo.read(), " data")
def test_read_binary(self):
lo = self.conn.lobject()
lo.write(b"some data")
lo.close()
lo = self.conn.lobject(lo.oid, "rb")
x = lo.read(4)
self.assertEqual(type(x), type(b''))
self.assertEqual(x, b"some")
self.assertEqual(lo.read(), b" data")
def test_read_text(self):
lo = self.conn.lobject()
snowman = "\u2603"
lo.write("some data " + snowman)
lo.close()
lo = self.conn.lobject(lo.oid, "rt")
x = lo.read(4)
self.assertEqual(type(x), type(''))
self.assertEqual(x, "some")
self.assertEqual(lo.read(), " data " + snowman)
@slow
def test_read_large(self):
lo = self.conn.lobject()
data = "data" * 1000000
lo.write("some" + data)
lo.close()
lo = self.conn.lobject(lo.oid)
self.assertEqual(lo.read(4), "some")
data1 = lo.read()
# avoid dumping megacraps in the console in case of error
self.assert_(data == data1,
f"{data[:100]!r}... != {data1[:100]!r}...")
def test_seek_tell(self):
lo = self.conn.lobject()
length = lo.write(b"some data")
self.assertEqual(lo.tell(), length)
lo.close()
lo = self.conn.lobject(lo.oid)
self.assertEqual(lo.seek(5, 0), 5)
self.assertEqual(lo.tell(), 5)
self.assertEqual(lo.read(), "data")
# SEEK_CUR: relative current location
lo.seek(5)
self.assertEqual(lo.seek(2, 1), 7)
self.assertEqual(lo.tell(), 7)
self.assertEqual(lo.read(), "ta")
# SEEK_END: relative to end of file
self.assertEqual(lo.seek(-2, 2), length - 2)
self.assertEqual(lo.read(), "ta")
def test_unlink(self):
lo = self.conn.lobject()
lo.unlink()
# the object doesn't exist now, so we can't reopen it.
self.assertRaises(psycopg2.OperationalError, self.conn.lobject, lo.oid)
# And the object has been closed.
self.assertEquals(lo.closed, True)
def test_export(self):
lo = self.conn.lobject()
lo.write(b"some data")
self.tmpdir = tempfile.mkdtemp()
filename = os.path.join(self.tmpdir, "data.txt")
lo.export(filename)
self.assertTrue(os.path.exists(filename))
f = open(filename, "rb")
try:
self.assertEqual(f.read(), b"some data")
finally:
f.close()
def test_close_twice(self):
lo = self.conn.lobject()
lo.close()
lo.close()
def test_write_after_close(self):
lo = self.conn.lobject()
lo.close()
self.assertRaises(psycopg2.InterfaceError, lo.write, b"some data")
def test_read_after_close(self):
lo = self.conn.lobject()
lo.close()
self.assertRaises(psycopg2.InterfaceError, lo.read, 5)
def test_seek_after_close(self):
lo = self.conn.lobject()
lo.close()
self.assertRaises(psycopg2.InterfaceError, lo.seek, 0)
def test_tell_after_close(self):
lo = self.conn.lobject()
lo.close()
self.assertRaises(psycopg2.InterfaceError, lo.tell)
def test_unlink_after_close(self):
lo = self.conn.lobject()
lo.close()
# Unlink works on closed files.
lo.unlink()
def test_export_after_close(self):
lo = self.conn.lobject()
lo.write(b"some data")
lo.close()
self.tmpdir = tempfile.mkdtemp()
filename = os.path.join(self.tmpdir, "data.txt")
lo.export(filename)
self.assertTrue(os.path.exists(filename))
f = open(filename, "rb")
try:
self.assertEqual(f.read(), b"some data")
finally:
f.close()
def test_close_after_commit(self):
lo = self.conn.lobject()
self.lo_oid = lo.oid
self.conn.commit()
# Closing outside of the transaction is okay.
lo.close()
def test_write_after_commit(self):
lo = self.conn.lobject()
self.lo_oid = lo.oid
self.conn.commit()
self.assertRaises(psycopg2.ProgrammingError, lo.write, b"some data")
def test_read_after_commit(self):
lo = self.conn.lobject()
self.lo_oid = lo.oid
self.conn.commit()
self.assertRaises(psycopg2.ProgrammingError, lo.read, 5)
def test_seek_after_commit(self):
lo = self.conn.lobject()
self.lo_oid = lo.oid
self.conn.commit()
self.assertRaises(psycopg2.ProgrammingError, lo.seek, 0)
def test_tell_after_commit(self):
lo = self.conn.lobject()
self.lo_oid = lo.oid
self.conn.commit()
self.assertRaises(psycopg2.ProgrammingError, lo.tell)
def test_unlink_after_commit(self):
lo = self.conn.lobject()
self.lo_oid = lo.oid
self.conn.commit()
# Unlink of stale lobject is okay
lo.unlink()
def test_export_after_commit(self):
lo = self.conn.lobject()
lo.write(b"some data")
self.conn.commit()
self.tmpdir = tempfile.mkdtemp()
filename = os.path.join(self.tmpdir, "data.txt")
lo.export(filename)
self.assertTrue(os.path.exists(filename))
f = open(filename, "rb")
try:
self.assertEqual(f.read(), b"some data")
finally:
f.close()
@skip_if_tpc_disabled
def test_read_after_tpc_commit(self):
self.conn.tpc_begin('test_lobject')
lo = self.conn.lobject()
self.lo_oid = lo.oid
self.conn.tpc_commit()
self.assertRaises(psycopg2.ProgrammingError, lo.read, 5)
@skip_if_tpc_disabled
def test_read_after_tpc_prepare(self):
self.conn.tpc_begin('test_lobject')
lo = self.conn.lobject()
self.lo_oid = lo.oid
self.conn.tpc_prepare()
try:
self.assertRaises(psycopg2.ProgrammingError, lo.read, 5)
finally:
self.conn.tpc_commit()
def test_large_oid(self):
# Test we don't overflow with an oid not fitting a signed int
try:
self.conn.lobject(0xFFFFFFFE)
except psycopg2.OperationalError:
pass
def test_factory(self):
class lobject_subclass(psycopg2.extensions.lobject):
pass
lo = self.conn.lobject(lobject_factory=lobject_subclass)
self.assert_(isinstance(lo, lobject_subclass))
@decorate_all_tests
def skip_if_no_truncate(f):
@wraps(f)
def skip_if_no_truncate_(self):
if self.conn.info.server_version < 80300:
return self.skipTest(
"the server doesn't support large object truncate")
if not hasattr(psycopg2.extensions.lobject, 'truncate'):
return self.skipTest(
"psycopg2 has been built against a libpq "
"without large object truncate support.")
return f(self)
return skip_if_no_truncate_
@skip_if_no_lo
@skip_if_no_truncate
class LargeObjectTruncateTests(LargeObjectTestCase):
def test_truncate(self):
lo = self.conn.lobject()
lo.write("some data")
lo.close()
lo = self.conn.lobject(lo.oid, "w")
lo.truncate(4)
# seek position unchanged
self.assertEqual(lo.tell(), 0)
# data truncated
self.assertEqual(lo.read(), "some")
lo.truncate(6)
lo.seek(0)
# large object extended with zeroes
self.assertEqual(lo.read(), "some\x00\x00")
lo.truncate()
lo.seek(0)
# large object empty
self.assertEqual(lo.read(), "")
def test_truncate_after_close(self):
lo = self.conn.lobject()
lo.close()
self.assertRaises(psycopg2.InterfaceError, lo.truncate)
def test_truncate_after_commit(self):
lo = self.conn.lobject()
self.lo_oid = lo.oid
self.conn.commit()
self.assertRaises(psycopg2.ProgrammingError, lo.truncate)
def _has_lo64(conn):
"""Return (bool, msg) about the lo64 support"""
if conn.info.server_version < 90300:
return (False, "server version %s doesn't support the lo64 API"
% conn.info.server_version)
if 'lo64' not in psycopg2.__version__:
return False, "this psycopg build doesn't support the lo64 API"
return True, "this server and build support the lo64 API"
@decorate_all_tests
def skip_if_no_lo64(f):
@wraps(f)
def skip_if_no_lo64_(self):
lo64, msg = _has_lo64(self.conn)
if not lo64:
return self.skipTest(msg)
else:
return f(self)
return skip_if_no_lo64_
@skip_if_no_lo
@skip_if_no_truncate
@skip_if_no_lo64
class LargeObject64Tests(LargeObjectTestCase):
def test_seek_tell_truncate_greater_than_2gb(self):
lo = self.conn.lobject()
length = (1 << 31) + (1 << 30) # 2gb + 1gb = 3gb
lo.truncate(length)
self.assertEqual(lo.seek(length, 0), length)
self.assertEqual(lo.tell(), length)
@decorate_all_tests
def skip_if_lo64(f):
@wraps(f)
def skip_if_lo64_(self):
lo64, msg = _has_lo64(self.conn)
if lo64:
return self.skipTest(msg)
else:
return f(self)
return skip_if_lo64_
@skip_if_no_lo
@skip_if_no_truncate
@skip_if_lo64
class LargeObjectNot64Tests(LargeObjectTestCase):
def test_seek_larger_than_2gb(self):
lo = self.conn.lobject()
offset = 1 << 32 # 4gb
self.assertRaises(
(OverflowError, psycopg2.InterfaceError, psycopg2.NotSupportedError),
lo.seek, offset, 0)
def test_truncate_larger_than_2gb(self):
lo = self.conn.lobject()
length = 1 << 32 # 4gb
self.assertRaises(
(OverflowError, psycopg2.InterfaceError, psycopg2.NotSupportedError),
lo.truncate, length)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

367
tests/test_module.py Executable file
View File

@ -0,0 +1,367 @@
#!/usr/bin/env python
# test_module.py - unit test for the module interface
#
# Copyright (C) 2011-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import gc
import os
import sys
import pickle
from subprocess import Popen
from weakref import ref
import unittest
from .testutils import (skip_before_postgres,
ConnectingTestCase, skip_copy_if_green, skip_if_crdb, slow, StringIO)
import psycopg2
class ConnectTestCase(unittest.TestCase):
def setUp(self):
self.args = None
def connect_stub(dsn, connection_factory=None, async_=False):
self.args = (dsn, connection_factory, async_)
self._connect_orig = psycopg2._connect
psycopg2._connect = connect_stub
def tearDown(self):
psycopg2._connect = self._connect_orig
def test_there_might_be_nothing(self):
psycopg2.connect()
self.assertEqual(self.args[0], '')
self.assertEqual(self.args[1], None)
self.assertEqual(self.args[2], False)
psycopg2.connect(
connection_factory=lambda dsn, async_=False: None)
self.assertEqual(self.args[0], '')
self.assertNotEqual(self.args[1], None)
self.assertEqual(self.args[2], False)
psycopg2.connect(async_=True)
self.assertEqual(self.args[0], '')
self.assertEqual(self.args[1], None)
self.assertEqual(self.args[2], True)
def test_no_keywords(self):
psycopg2.connect('')
self.assertEqual(self.args[0], '')
self.assertEqual(self.args[1], None)
self.assertEqual(self.args[2], False)
def test_dsn(self):
psycopg2.connect('dbname=blah host=y')
self.assertEqual(self.args[0], 'dbname=blah host=y')
self.assertEqual(self.args[1], None)
self.assertEqual(self.args[2], False)
def test_supported_keywords(self):
psycopg2.connect(database='foo')
self.assertEqual(self.args[0], 'dbname=foo')
psycopg2.connect(user='postgres')
self.assertEqual(self.args[0], 'user=postgres')
psycopg2.connect(password='secret')
self.assertEqual(self.args[0], 'password=secret')
psycopg2.connect(port=5432)
self.assertEqual(self.args[0], 'port=5432')
psycopg2.connect(sslmode='require')
self.assertEqual(self.args[0], 'sslmode=require')
psycopg2.connect(database='foo',
user='postgres', password='secret', port=5432)
self.assert_('dbname=foo' in self.args[0])
self.assert_('user=postgres' in self.args[0])
self.assert_('password=secret' in self.args[0])
self.assert_('port=5432' in self.args[0])
self.assertEqual(len(self.args[0].split()), 4)
def test_generic_keywords(self):
psycopg2.connect(options='stuff')
self.assertEqual(self.args[0], 'options=stuff')
def test_factory(self):
def f(dsn, async_=False):
pass
psycopg2.connect(database='foo', host='baz', connection_factory=f)
self.assertDsnEqual(self.args[0], 'dbname=foo host=baz')
self.assertEqual(self.args[1], f)
self.assertEqual(self.args[2], False)
psycopg2.connect("dbname=foo host=baz", connection_factory=f)
self.assertDsnEqual(self.args[0], 'dbname=foo host=baz')
self.assertEqual(self.args[1], f)
self.assertEqual(self.args[2], False)
def test_async(self):
psycopg2.connect(database='foo', host='baz', async_=1)
self.assertDsnEqual(self.args[0], 'dbname=foo host=baz')
self.assertEqual(self.args[1], None)
self.assert_(self.args[2])
psycopg2.connect("dbname=foo host=baz", async_=True)
self.assertDsnEqual(self.args[0], 'dbname=foo host=baz')
self.assertEqual(self.args[1], None)
self.assert_(self.args[2])
def test_int_port_param(self):
psycopg2.connect(database='sony', port=6543)
dsn = f" {self.args[0]} "
self.assert_(" dbname=sony " in dsn, dsn)
self.assert_(" port=6543 " in dsn, dsn)
def test_empty_param(self):
psycopg2.connect(database='sony', password='')
self.assertDsnEqual(self.args[0], "dbname=sony password=''")
def test_escape(self):
psycopg2.connect(database='hello world')
self.assertEqual(self.args[0], "dbname='hello world'")
psycopg2.connect(database=r'back\slash')
self.assertEqual(self.args[0], r"dbname=back\\slash")
psycopg2.connect(database="quo'te")
self.assertEqual(self.args[0], r"dbname=quo\'te")
psycopg2.connect(database="with\ttab")
self.assertEqual(self.args[0], "dbname='with\ttab'")
psycopg2.connect(database=r"\every thing'")
self.assertEqual(self.args[0], r"dbname='\\every thing\''")
def test_params_merging(self):
psycopg2.connect('dbname=foo', database='bar')
self.assertEqual(self.args[0], 'dbname=bar')
psycopg2.connect('dbname=foo', user='postgres')
self.assertDsnEqual(self.args[0], 'dbname=foo user=postgres')
class ExceptionsTestCase(ConnectingTestCase):
def test_attributes(self):
cur = self.conn.cursor()
try:
cur.execute("select * from nonexist")
except psycopg2.Error as exc:
e = exc
self.assertEqual(e.pgcode, '42P01')
self.assert_(e.pgerror)
self.assert_(e.cursor is cur)
def test_diagnostics_attributes(self):
cur = self.conn.cursor()
try:
cur.execute("select * from nonexist")
except psycopg2.Error as exc:
e = exc
diag = e.diag
self.assert_(isinstance(diag, psycopg2.extensions.Diagnostics))
for attr in [
'column_name', 'constraint_name', 'context', 'datatype_name',
'internal_position', 'internal_query', 'message_detail',
'message_hint', 'message_primary', 'schema_name', 'severity',
'severity_nonlocalized', 'source_file', 'source_function',
'source_line', 'sqlstate', 'statement_position', 'table_name', ]:
v = getattr(diag, attr)
if v is not None:
self.assert_(isinstance(v, str))
def test_diagnostics_values(self):
cur = self.conn.cursor()
try:
cur.execute("select * from nonexist")
except psycopg2.Error as exc:
e = exc
self.assertEqual(e.diag.sqlstate, '42P01')
self.assertEqual(e.diag.severity, 'ERROR')
def test_diagnostics_life(self):
def tmp():
cur = self.conn.cursor()
try:
cur.execute("select * from nonexist")
except psycopg2.Error as exc:
return cur, exc
cur, e = tmp()
diag = e.diag
w = ref(cur)
del e, cur
gc.collect()
assert(w() is not None)
self.assertEqual(diag.sqlstate, '42P01')
del diag
gc.collect()
gc.collect()
assert(w() is None)
@skip_if_crdb("copy")
@skip_copy_if_green
def test_diagnostics_copy(self):
f = StringIO()
cur = self.conn.cursor()
try:
cur.copy_to(f, 'nonexist')
except psycopg2.Error as exc:
diag = exc.diag
self.assertEqual(diag.sqlstate, '42P01')
def test_diagnostics_independent(self):
cur = self.conn.cursor()
try:
cur.execute("l'acqua e' poca e 'a papera nun galleggia")
except Exception as exc:
diag1 = exc.diag
self.conn.rollback()
try:
cur.execute("select level from water where ducks > 1")
except psycopg2.Error as exc:
diag2 = exc.diag
self.assertEqual(diag1.sqlstate, '42601')
self.assertEqual(diag2.sqlstate, '42P01')
@skip_if_crdb("deferrable")
def test_diagnostics_from_commit(self):
cur = self.conn.cursor()
cur.execute("""
create temp table test_deferred (
data int primary key,
ref int references test_deferred (data)
deferrable initially deferred)
""")
cur.execute("insert into test_deferred values (1,2)")
try:
self.conn.commit()
except psycopg2.Error as exc:
e = exc
self.assertEqual(e.diag.sqlstate, '23503')
@skip_if_crdb("diagnostic")
@skip_before_postgres(9, 3)
def test_9_3_diagnostics(self):
cur = self.conn.cursor()
cur.execute("""
create temp table test_exc (
data int constraint chk_eq1 check (data = 1)
)""")
try:
cur.execute("insert into test_exc values(2)")
except psycopg2.Error as exc:
e = exc
self.assertEqual(e.pgcode, '23514')
self.assertEqual(e.diag.schema_name[:7], "pg_temp")
self.assertEqual(e.diag.table_name, "test_exc")
self.assertEqual(e.diag.column_name, None)
self.assertEqual(e.diag.constraint_name, "chk_eq1")
self.assertEqual(e.diag.datatype_name, None)
@skip_if_crdb("diagnostic")
@skip_before_postgres(9, 6)
def test_9_6_diagnostics(self):
cur = self.conn.cursor()
try:
cur.execute("select 1 from nosuchtable")
except psycopg2.Error as exc:
e = exc
self.assertEqual(e.diag.severity_nonlocalized, 'ERROR')
def test_pickle(self):
cur = self.conn.cursor()
try:
cur.execute("select * from nonexist")
except psycopg2.Error as exc:
e = exc
e1 = pickle.loads(pickle.dumps(e))
self.assertEqual(e.pgerror, e1.pgerror)
self.assertEqual(e.pgcode, e1.pgcode)
self.assert_(e1.cursor is None)
@skip_if_crdb("connect any db")
def test_pickle_connection_error(self):
# segfaults on psycopg 2.5.1 - see ticket #170
try:
psycopg2.connect('dbname=nosuchdatabasemate')
except psycopg2.Error as exc:
e = exc
e1 = pickle.loads(pickle.dumps(e))
self.assertEqual(e.pgerror, e1.pgerror)
self.assertEqual(e.pgcode, e1.pgcode)
self.assert_(e1.cursor is None)
class TestExtensionModule(unittest.TestCase):
@slow
def test_import_internal(self):
# check that the internal package can be imported "naked"
# we may break this property if there is a compelling reason to do so,
# however having it allows for some import juggling such as the one
# required in ticket #201.
pkgdir = os.path.dirname(psycopg2.__file__)
pardir = os.path.dirname(pkgdir)
self.assert_(pardir in sys.path)
script = f"""
import sys
sys.path.remove({pardir!r})
sys.path.insert(0, {pkgdir!r})
import _psycopg
"""
proc = Popen([sys.executable, '-c', script])
proc.communicate()
self.assertEqual(0, proc.returncode)
class TestVersionDiscovery(unittest.TestCase):
def test_libpq_version(self):
self.assertTrue(type(psycopg2.__libpq_version__) is int)
try:
self.assertTrue(type(psycopg2.extensions.libpq_version()) is int)
except psycopg2.NotSupportedError:
self.assertTrue(psycopg2.__libpq_version__ < 90100)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

231
tests/test_notify.py Executable file
View File

@ -0,0 +1,231 @@
#!/usr/bin/env python
# test_notify.py - unit test for async notifications
#
# Copyright (C) 2010-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import unittest
from collections import deque
import psycopg2
from psycopg2 import extensions
from psycopg2.extensions import Notify
from .testutils import ConnectingTestCase, skip_if_crdb, slow
from .testconfig import dsn
import sys
import time
import select
from subprocess import Popen, PIPE
@skip_if_crdb("notify")
class NotifiesTests(ConnectingTestCase):
def autocommit(self, conn):
"""Set a connection in autocommit mode."""
conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT)
def listen(self, name):
"""Start listening for a name on self.conn."""
curs = self.conn.cursor()
curs.execute("LISTEN " + name)
curs.close()
def notify(self, name, sec=0, payload=None):
"""Send a notification to the database, eventually after some time."""
if payload is None:
payload = ''
else:
payload = f", {payload!r}"
script = ("""\
import time
time.sleep({sec})
import {module} as psycopg2
import {module}.extensions as ext
conn = psycopg2.connect({dsn!r})
conn.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT)
print(conn.info.backend_pid)
curs = conn.cursor()
curs.execute("NOTIFY " {name!r} {payload!r})
curs.close()
conn.close()
""".format(
module=psycopg2.__name__,
dsn=dsn, sec=sec, name=name, payload=payload))
return Popen([sys.executable, '-c', script], stdout=PIPE)
@slow
def test_notifies_received_on_poll(self):
self.autocommit(self.conn)
self.listen('foo')
proc = self.notify('foo', 1)
t0 = time.time()
select.select([self.conn], [], [], 5)
t1 = time.time()
self.assert_(0.99 < t1 - t0 < 4, t1 - t0)
pid = int(proc.communicate()[0])
self.assertEqual(0, len(self.conn.notifies))
self.assertEqual(extensions.POLL_OK, self.conn.poll())
self.assertEqual(1, len(self.conn.notifies))
self.assertEqual(pid, self.conn.notifies[0][0])
self.assertEqual('foo', self.conn.notifies[0][1])
@slow
def test_many_notifies(self):
self.autocommit(self.conn)
for name in ['foo', 'bar', 'baz']:
self.listen(name)
pids = {}
for name in ['foo', 'bar', 'baz', 'qux']:
pids[name] = int(self.notify(name).communicate()[0])
self.assertEqual(0, len(self.conn.notifies))
for i in range(10):
self.assertEqual(extensions.POLL_OK, self.conn.poll())
self.assertEqual(3, len(self.conn.notifies))
names = dict.fromkeys(['foo', 'bar', 'baz'])
for (pid, name) in self.conn.notifies:
self.assertEqual(pids[name], pid)
names.pop(name) # raise if name found twice
@slow
def test_notifies_received_on_execute(self):
self.autocommit(self.conn)
self.listen('foo')
pid = int(self.notify('foo').communicate()[0])
self.assertEqual(0, len(self.conn.notifies))
self.conn.cursor().execute('select 1;')
self.assertEqual(1, len(self.conn.notifies))
self.assertEqual(pid, self.conn.notifies[0][0])
self.assertEqual('foo', self.conn.notifies[0][1])
@slow
def test_notify_object(self):
self.autocommit(self.conn)
self.listen('foo')
self.notify('foo').communicate()
time.sleep(0.5)
self.conn.poll()
notify = self.conn.notifies[0]
self.assert_(isinstance(notify, psycopg2.extensions.Notify))
@slow
def test_notify_attributes(self):
self.autocommit(self.conn)
self.listen('foo')
pid = int(self.notify('foo').communicate()[0])
time.sleep(0.5)
self.conn.poll()
self.assertEqual(1, len(self.conn.notifies))
notify = self.conn.notifies[0]
self.assertEqual(pid, notify.pid)
self.assertEqual('foo', notify.channel)
self.assertEqual('', notify.payload)
@slow
def test_notify_payload(self):
if self.conn.info.server_version < 90000:
return self.skipTest("server version %s doesn't support notify payload"
% self.conn.info.server_version)
self.autocommit(self.conn)
self.listen('foo')
pid = int(self.notify('foo', payload="Hello, world!").communicate()[0])
time.sleep(0.5)
self.conn.poll()
self.assertEqual(1, len(self.conn.notifies))
notify = self.conn.notifies[0]
self.assertEqual(pid, notify.pid)
self.assertEqual('foo', notify.channel)
self.assertEqual('Hello, world!', notify.payload)
@slow
def test_notify_deque(self):
self.autocommit(self.conn)
self.conn.notifies = deque()
self.listen('foo')
self.notify('foo').communicate()
time.sleep(0.5)
self.conn.poll()
notify = self.conn.notifies.popleft()
self.assert_(isinstance(notify, psycopg2.extensions.Notify))
self.assertEqual(len(self.conn.notifies), 0)
@slow
def test_notify_noappend(self):
self.autocommit(self.conn)
self.conn.notifies = None
self.listen('foo')
self.notify('foo').communicate()
time.sleep(0.5)
self.conn.poll()
self.assertEqual(self.conn.notifies, None)
def test_notify_init(self):
n = psycopg2.extensions.Notify(10, 'foo')
self.assertEqual(10, n.pid)
self.assertEqual('foo', n.channel)
self.assertEqual('', n.payload)
(pid, channel) = n
self.assertEqual((pid, channel), (10, 'foo'))
n = psycopg2.extensions.Notify(42, 'bar', 'baz')
self.assertEqual(42, n.pid)
self.assertEqual('bar', n.channel)
self.assertEqual('baz', n.payload)
(pid, channel) = n
self.assertEqual((pid, channel), (42, 'bar'))
def test_compare(self):
data = [(10, 'foo'), (20, 'foo'), (10, 'foo', 'bar'), (10, 'foo', 'baz')]
for d1 in data:
for d2 in data:
n1 = psycopg2.extensions.Notify(*d1)
n2 = psycopg2.extensions.Notify(*d2)
self.assertEqual((n1 == n2), (d1 == d2))
self.assertEqual((n1 != n2), (d1 != d2))
def test_compare_tuple(self):
self.assertEqual((10, 'foo'), Notify(10, 'foo'))
self.assertEqual((10, 'foo'), Notify(10, 'foo', 'bar'))
self.assertNotEqual((10, 'foo'), Notify(20, 'foo'))
self.assertNotEqual((10, 'foo'), Notify(10, 'bar'))
def test_hash(self):
self.assertEqual(hash((10, 'foo')), hash(Notify(10, 'foo')))
self.assertNotEqual(hash(Notify(10, 'foo', 'bar')),
hash(Notify(10, 'foo')))
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

86
tests/test_psycopg2_dbapi20.py Executable file
View File

@ -0,0 +1,86 @@
#!/usr/bin/env python
# test_psycopg2_dbapi20.py - DB API conformance test for psycopg2
#
# Copyright (C) 2006-2019 Federico Di Gregorio <fog@debian.org>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
from . import dbapi20
from . import dbapi20_tpc
from .testutils import skip_if_tpc_disabled
import unittest
import psycopg2
from .testconfig import dsn
class Psycopg2Tests(dbapi20.DatabaseAPI20Test):
driver = psycopg2
connect_args = ()
connect_kw_args = {'dsn': dsn}
lower_func = 'lower' # For stored procedure test
def test_callproc(self):
# Until DBAPI 2.0 compliance, callproc should return None or it's just
# misleading. Therefore, we will skip the return value test for
# callproc and only perform the fetch test.
#
# For what it's worth, the DBAPI2.0 test_callproc doesn't actually
# test for DBAPI2.0 compliance! It doesn't check for modified OUT and
# IN/OUT parameters in the return values!
con = self._connect()
try:
cur = con.cursor()
if self.lower_func and hasattr(cur, 'callproc'):
cur.callproc(self.lower_func, ('FOO',))
r = cur.fetchall()
self.assertEqual(len(r), 1, 'callproc produced no result set')
self.assertEqual(len(r[0]), 1,
'callproc produced invalid result set')
self.assertEqual(r[0][0], 'foo',
'callproc produced invalid results')
finally:
con.close()
def test_setoutputsize(self):
# psycopg2's setoutputsize() is a no-op
pass
def test_nextset(self):
# psycopg2 does not implement nextset()
pass
@skip_if_tpc_disabled
class Psycopg2TPCTests(dbapi20_tpc.TwoPhaseCommitTests, unittest.TestCase):
driver = psycopg2
def connect(self):
return psycopg2.connect(dsn=dsn)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == '__main__':
unittest.main()

229
tests/test_quote.py Executable file
View File

@ -0,0 +1,229 @@
#!/usr/bin/env python
# test_quote.py - unit test for strings quoting
#
# Copyright (C) 2007-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
from . import testutils
import unittest
from .testutils import ConnectingTestCase, skip_if_crdb
import psycopg2
import psycopg2.extensions
from psycopg2.extensions import adapt, quote_ident
class QuotingTestCase(ConnectingTestCase):
r"""Checks the correct quoting of strings and binary objects.
Since ver. 8.1, PostgreSQL is moving towards SQL standard conforming
strings, where the backslash (\) is treated as literal character,
not as escape. To treat the backslash as a C-style escapes, PG supports
the E'' quotes.
This test case checks that the E'' quotes are used whenever they are
needed. The tests are expected to pass with all PostgreSQL server versions
(currently tested with 7.4 <= PG <= 8.3beta) and with any
'standard_conforming_strings' server parameter value.
The tests also check that no warning is raised ('escape_string_warning'
should be on).
https://www.postgresql.org/docs/current/static/sql-syntax-lexical.html#SQL-SYNTAX-STRINGS
https://www.postgresql.org/docs/current/static/runtime-config-compatible.html
"""
def test_string(self):
data = """some data with \t chars
to escape into, 'quotes' and \\ a backslash too.
"""
data += "".join(map(chr, range(1, 127)))
curs = self.conn.cursor()
curs.execute("SELECT %s;", (data,))
res = curs.fetchone()[0]
self.assertEqual(res, data)
self.assert_(not self.conn.notices)
def test_string_null_terminator(self):
curs = self.conn.cursor()
data = 'abcd\x01\x00cdefg'
try:
curs.execute("SELECT %s", (data,))
except ValueError as e:
self.assertEquals(str(e),
'A string literal cannot contain NUL (0x00) characters.')
else:
self.fail("ValueError not raised")
def test_binary(self):
data = b"""some data with \000\013 binary
stuff into, 'quotes' and \\ a backslash too.
"""
data += bytes(list(range(256)))
curs = self.conn.cursor()
curs.execute("SELECT %s::bytea;", (psycopg2.Binary(data),))
res = curs.fetchone()[0].tobytes()
if res[0] in (b'x', ord(b'x')) and self.conn.info.server_version >= 90000:
return self.skipTest(
"bytea broken with server >= 9.0, libpq < 9")
self.assertEqual(res, data)
self.assert_(not self.conn.notices)
def test_unicode(self):
curs = self.conn.cursor()
curs.execute("SHOW server_encoding")
server_encoding = curs.fetchone()[0]
if server_encoding != "UTF8":
return self.skipTest(
f"Unicode test skipped since server encoding is {server_encoding}")
data = """some data with \t chars
to escape into, 'quotes', \u20ac euro sign and \\ a backslash too.
"""
data += "".join(map(chr, [u for u in range(1, 65536)
if not 0xD800 <= u <= 0xDFFF])) # surrogate area
self.conn.set_client_encoding('UNICODE')
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn)
curs.execute("SELECT %s::text;", (data,))
res = curs.fetchone()[0]
self.assertEqual(res, data)
self.assert_(not self.conn.notices)
@skip_if_crdb("encoding")
def test_latin1(self):
self.conn.set_client_encoding('LATIN1')
curs = self.conn.cursor()
data = bytes(list(range(32, 127))
+ list(range(160, 256))).decode('latin1')
# as string
curs.execute("SELECT %s::text;", (data,))
res = curs.fetchone()[0]
self.assertEqual(res, data)
self.assert_(not self.conn.notices)
@skip_if_crdb("encoding")
def test_koi8(self):
self.conn.set_client_encoding('KOI8')
curs = self.conn.cursor()
data = bytes(list(range(32, 127))
+ list(range(128, 256))).decode('koi8_r')
# as string
curs.execute("SELECT %s::text;", (data,))
res = curs.fetchone()[0]
self.assertEqual(res, data)
self.assert_(not self.conn.notices)
def test_bytes(self):
snowman = "\u2603"
conn = self.connect()
conn.set_client_encoding('UNICODE')
psycopg2.extensions.register_type(psycopg2.extensions.BYTES, conn)
curs = conn.cursor()
curs.execute("select %s::text", (snowman,))
x = curs.fetchone()[0]
self.assert_(isinstance(x, bytes))
self.assertEqual(x, snowman.encode('utf8'))
class TestQuotedString(ConnectingTestCase):
def test_encoding_from_conn(self):
q = psycopg2.extensions.QuotedString('hi')
self.assertEqual(q.encoding, 'latin1')
self.conn.set_client_encoding('utf_8')
q.prepare(self.conn)
self.assertEqual(q.encoding, 'utf_8')
class TestQuotedIdentifier(ConnectingTestCase):
def test_identifier(self):
self.assertEqual(quote_ident('blah-blah', self.conn), '"blah-blah"')
self.assertEqual(quote_ident('quote"inside', self.conn), '"quote""inside"')
@testutils.skip_before_postgres(8, 0)
def test_unicode_ident(self):
snowman = "\u2603"
quoted = '"' + snowman + '"'
self.assertEqual(quote_ident(snowman, self.conn), quoted)
class TestStringAdapter(ConnectingTestCase):
def test_encoding_default(self):
a = adapt("hello")
self.assertEqual(a.encoding, 'latin1')
self.assertEqual(a.getquoted(), b"'hello'")
# NOTE: we can't really test an encoding different from utf8, because
# when encoding without connection the libpq will use parameters from
# a previous one, so what would happens depends jn the tests run order.
# egrave = u'\xe8'
# self.assertEqual(adapt(egrave).getquoted(), "'\xe8'")
def test_encoding_error(self):
snowman = "\u2603"
a = adapt(snowman)
self.assertRaises(UnicodeEncodeError, a.getquoted)
def test_set_encoding(self):
# Note: this works-ish mostly in case when the standard db connection
# we test with is utf8, otherwise the encoding chosen by PQescapeString
# may give bad results.
snowman = "\u2603"
a = adapt(snowman)
a.encoding = 'utf8'
self.assertEqual(a.encoding, 'utf8')
self.assertEqual(a.getquoted(), b"'\xe2\x98\x83'")
def test_connection_wins_anyway(self):
snowman = "\u2603"
a = adapt(snowman)
a.encoding = 'latin9'
self.conn.set_client_encoding('utf8')
a.prepare(self.conn)
self.assertEqual(a.encoding, 'utf_8')
self.assertQuotedEqual(a.getquoted(), b"'\xe2\x98\x83'")
def test_adapt_bytes(self):
snowman = "\u2603"
self.conn.set_client_encoding('utf8')
a = psycopg2.extensions.QuotedString(snowman.encode('utf8'))
a.prepare(self.conn)
self.assertQuotedEqual(a.getquoted(), b"'\xe2\x98\x83'")
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

276
tests/test_replication.py Executable file
View File

@ -0,0 +1,276 @@
#!/usr/bin/env python
# test_replication.py - unit test for replication protocol
#
# Copyright (C) 2015-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import time
from select import select
import psycopg2
from psycopg2 import sql
from psycopg2.extras import (
PhysicalReplicationConnection, LogicalReplicationConnection, StopReplication)
from . import testconfig
import unittest
from .testutils import ConnectingTestCase
from .testutils import skip_before_postgres, skip_if_green
skip_repl_if_green = skip_if_green("replication not supported in green mode")
class ReplicationTestCase(ConnectingTestCase):
def setUp(self):
super().setUp()
self.slot = testconfig.repl_slot
self._slots = []
def tearDown(self):
# first close all connections, as they might keep the slot(s) active
super().tearDown()
time.sleep(0.025) # sometimes the slot is still active, wait a little
if self._slots:
kill_conn = self.connect()
if kill_conn:
kill_cur = kill_conn.cursor()
for slot in self._slots:
kill_cur.execute("SELECT pg_drop_replication_slot(%s)", (slot,))
kill_conn.commit()
kill_conn.close()
def create_replication_slot(self, cur, slot_name=testconfig.repl_slot, **kwargs):
cur.create_replication_slot(slot_name, **kwargs)
self._slots.append(slot_name)
def drop_replication_slot(self, cur, slot_name=testconfig.repl_slot):
cur.drop_replication_slot(slot_name)
self._slots.remove(slot_name)
# generate some events for our replication stream
def make_replication_events(self):
conn = self.connect()
if conn is None:
return
cur = conn.cursor()
try:
cur.execute("DROP TABLE dummy1")
except psycopg2.ProgrammingError:
conn.rollback()
cur.execute(
"CREATE TABLE dummy1 AS SELECT * FROM generate_series(1, 5) AS id")
conn.commit()
class ReplicationTest(ReplicationTestCase):
@skip_before_postgres(9, 0)
def test_physical_replication_connection(self):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None:
return
cur = conn.cursor()
cur.execute("IDENTIFY_SYSTEM")
cur.fetchall()
@skip_before_postgres(9, 0)
def test_datestyle(self):
if testconfig.repl_dsn is None:
return self.skipTest("replication tests disabled by default")
conn = self.repl_connect(
dsn=testconfig.repl_dsn, options='-cdatestyle=german',
connection_factory=PhysicalReplicationConnection)
if conn is None:
return
cur = conn.cursor()
cur.execute("IDENTIFY_SYSTEM")
cur.fetchall()
@skip_before_postgres(9, 4)
def test_logical_replication_connection(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None:
return
cur = conn.cursor()
cur.execute("IDENTIFY_SYSTEM")
cur.fetchall()
@skip_before_postgres(9, 4) # slots require 9.4
def test_create_replication_slot(self):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None:
return
cur = conn.cursor()
self.create_replication_slot(cur)
self.assertRaises(
psycopg2.ProgrammingError, self.create_replication_slot, cur)
@skip_before_postgres(9, 4) # slots require 9.4
@skip_repl_if_green
def test_start_on_missing_replication_slot(self):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None:
return
cur = conn.cursor()
self.assertRaises(psycopg2.ProgrammingError,
cur.start_replication, self.slot)
self.create_replication_slot(cur)
cur.start_replication(self.slot)
@skip_before_postgres(9, 4) # slots require 9.4
@skip_repl_if_green
def test_start_replication_expert_sql(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None:
return
cur = conn.cursor()
self.create_replication_slot(cur, output_plugin='test_decoding')
cur.start_replication_expert(
sql.SQL("START_REPLICATION SLOT {slot} LOGICAL 0/00000000").format(
slot=sql.Identifier(self.slot)))
@skip_before_postgres(9, 4) # slots require 9.4
@skip_repl_if_green
def test_start_and_recover_from_error(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None:
return
cur = conn.cursor()
self.create_replication_slot(cur, output_plugin='test_decoding')
self.make_replication_events()
def consume(msg):
raise StopReplication()
with self.assertRaises(psycopg2.DataError):
# try with invalid options
cur.start_replication(
slot_name=self.slot, options={'invalid_param': 'value'})
cur.consume_stream(consume)
# try with correct command
cur.start_replication(slot_name=self.slot)
self.assertRaises(StopReplication, cur.consume_stream, consume)
@skip_before_postgres(9, 4) # slots require 9.4
@skip_repl_if_green
def test_keepalive(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None:
return
cur = conn.cursor()
self.create_replication_slot(cur, output_plugin='test_decoding')
self.make_replication_events()
cur.start_replication(self.slot)
def consume(msg):
raise StopReplication()
self.assertRaises(StopReplication,
cur.consume_stream, consume, keepalive_interval=2)
conn.close()
@skip_before_postgres(9, 4) # slots require 9.4
@skip_repl_if_green
def test_stop_replication(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None:
return
cur = conn.cursor()
self.create_replication_slot(cur, output_plugin='test_decoding')
self.make_replication_events()
cur.start_replication(self.slot)
def consume(msg):
raise StopReplication()
self.assertRaises(StopReplication, cur.consume_stream, consume)
class AsyncReplicationTest(ReplicationTestCase):
@skip_before_postgres(9, 4) # slots require 9.4
@skip_repl_if_green
def test_async_replication(self):
conn = self.repl_connect(
connection_factory=LogicalReplicationConnection, async_=1)
if conn is None:
return
cur = conn.cursor()
self.create_replication_slot(cur, output_plugin='test_decoding')
self.wait(cur)
cur.start_replication(self.slot)
self.wait(cur)
self.make_replication_events()
self.msg_count = 0
def consume(msg):
# just check the methods
f"{cur.io_timestamp}: {repr(msg)}"
f"{cur.feedback_timestamp}: {repr(msg)}"
f"{cur.wal_end}: {repr(msg)}"
self.msg_count += 1
if self.msg_count > 3:
cur.send_feedback(reply=True)
raise StopReplication()
cur.send_feedback(flush_lsn=msg.data_start)
# cannot be used in asynchronous mode
self.assertRaises(psycopg2.ProgrammingError, cur.consume_stream, consume)
def process_stream():
while True:
msg = cur.read_message()
if msg:
consume(msg)
else:
select([cur], [], [])
self.assertRaises(StopReplication, process_stream)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

412
tests/test_sql.py Executable file
View File

@ -0,0 +1,412 @@
#!/usr/bin/env python
# test_sql.py - tests for the psycopg2.sql module
#
# Copyright (C) 2016-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import datetime as dt
import unittest
from .testutils import (
ConnectingTestCase, skip_before_postgres, skip_copy_if_green, StringIO,
skip_if_crdb)
import psycopg2
from psycopg2 import sql
class SqlFormatTests(ConnectingTestCase):
def test_pos(self):
s = sql.SQL("select {} from {}").format(
sql.Identifier('field'), sql.Identifier('table'))
s1 = s.as_string(self.conn)
self.assert_(isinstance(s1, str))
self.assertEqual(s1, 'select "field" from "table"')
def test_pos_spec(self):
s = sql.SQL("select {0} from {1}").format(
sql.Identifier('field'), sql.Identifier('table'))
s1 = s.as_string(self.conn)
self.assert_(isinstance(s1, str))
self.assertEqual(s1, 'select "field" from "table"')
s = sql.SQL("select {1} from {0}").format(
sql.Identifier('table'), sql.Identifier('field'))
s1 = s.as_string(self.conn)
self.assert_(isinstance(s1, str))
self.assertEqual(s1, 'select "field" from "table"')
def test_dict(self):
s = sql.SQL("select {f} from {t}").format(
f=sql.Identifier('field'), t=sql.Identifier('table'))
s1 = s.as_string(self.conn)
self.assert_(isinstance(s1, str))
self.assertEqual(s1, 'select "field" from "table"')
def test_compose_literal(self):
s = sql.SQL("select {0};").format(sql.Literal(dt.date(2016, 12, 31)))
s1 = s.as_string(self.conn)
self.assertEqual(s1, "select '2016-12-31'::date;")
def test_compose_empty(self):
s = sql.SQL("select foo;").format()
s1 = s.as_string(self.conn)
self.assertEqual(s1, "select foo;")
def test_percent_escape(self):
s = sql.SQL("42 % {0}").format(sql.Literal(7))
s1 = s.as_string(self.conn)
self.assertEqual(s1, "42 % 7")
def test_braces_escape(self):
s = sql.SQL("{{{0}}}").format(sql.Literal(7))
self.assertEqual(s.as_string(self.conn), "{7}")
s = sql.SQL("{{1,{0}}}").format(sql.Literal(7))
self.assertEqual(s.as_string(self.conn), "{1,7}")
def test_compose_badnargs(self):
self.assertRaises(IndexError, sql.SQL("select {0};").format)
def test_compose_badnargs_auto(self):
self.assertRaises(IndexError, sql.SQL("select {};").format)
self.assertRaises(ValueError, sql.SQL("select {} {1};").format, 10, 20)
self.assertRaises(ValueError, sql.SQL("select {0} {};").format, 10, 20)
def test_compose_bad_args_type(self):
self.assertRaises(IndexError, sql.SQL("select {0};").format, a=10)
self.assertRaises(KeyError, sql.SQL("select {x};").format, 10)
def test_must_be_composable(self):
self.assertRaises(TypeError, sql.SQL("select {0};").format, 'foo')
self.assertRaises(TypeError, sql.SQL("select {0};").format, 10)
def test_no_modifiers(self):
self.assertRaises(ValueError, sql.SQL("select {a!r};").format, a=10)
self.assertRaises(ValueError, sql.SQL("select {a:<};").format, a=10)
def test_must_be_adaptable(self):
class Foo:
pass
self.assertRaises(psycopg2.ProgrammingError,
sql.SQL("select {0};").format(sql.Literal(Foo())).as_string, self.conn)
def test_execute(self):
cur = self.conn.cursor()
cur.execute("""
create table test_compose (
id serial primary key,
foo text, bar text, "ba'z" text)
""")
cur.execute(
sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
sql.Identifier('test_compose'),
sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
(sql.Placeholder() * 3).join(', ')),
(10, 'a', 'b', 'c'))
cur.execute("select * from test_compose")
self.assertEqual(cur.fetchall(), [(10, 'a', 'b', 'c')])
def test_executemany(self):
cur = self.conn.cursor()
cur.execute("""
create table test_compose (
id serial primary key,
foo text, bar text, "ba'z" text)
""")
cur.executemany(
sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
sql.Identifier('test_compose'),
sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
(sql.Placeholder() * 3).join(', ')),
[(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')])
cur.execute("select * from test_compose")
self.assertEqual(cur.fetchall(),
[(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')])
@skip_if_crdb("copy")
@skip_copy_if_green
@skip_before_postgres(8, 2)
def test_copy(self):
cur = self.conn.cursor()
cur.execute("""
create table test_compose (
id serial primary key,
foo text, bar text, "ba'z" text)
""")
s = StringIO("10\ta\tb\tc\n20\td\te\tf\n")
cur.copy_expert(
sql.SQL("copy {t} (id, foo, bar, {f}) from stdin").format(
t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")), s)
s1 = StringIO()
cur.copy_expert(
sql.SQL("copy (select {f} from {t} order by id) to stdout").format(
t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")), s1)
s1.seek(0)
self.assertEqual(s1.read(), 'c\nf\n')
class IdentifierTests(ConnectingTestCase):
def test_class(self):
self.assert_(issubclass(sql.Identifier, sql.Composable))
def test_init(self):
self.assert_(isinstance(sql.Identifier('foo'), sql.Identifier))
self.assert_(isinstance(sql.Identifier('foo'), sql.Identifier))
self.assert_(isinstance(sql.Identifier('foo', 'bar', 'baz'), sql.Identifier))
self.assertRaises(TypeError, sql.Identifier)
self.assertRaises(TypeError, sql.Identifier, 10)
self.assertRaises(TypeError, sql.Identifier, dt.date(2016, 12, 31))
def test_strings(self):
self.assertEqual(sql.Identifier('foo').strings, ('foo',))
self.assertEqual(sql.Identifier('foo', 'bar').strings, ('foo', 'bar'))
# Legacy method
self.assertEqual(sql.Identifier('foo').string, 'foo')
self.assertRaises(AttributeError,
getattr, sql.Identifier('foo', 'bar'), 'string')
def test_repr(self):
obj = sql.Identifier("fo'o")
self.assertEqual(repr(obj), 'Identifier("fo\'o")')
self.assertEqual(repr(obj), str(obj))
obj = sql.Identifier("fo'o", 'ba"r')
self.assertEqual(repr(obj), 'Identifier("fo\'o", \'ba"r\')')
self.assertEqual(repr(obj), str(obj))
def test_eq(self):
self.assert_(sql.Identifier('foo') == sql.Identifier('foo'))
self.assert_(sql.Identifier('foo', 'bar') == sql.Identifier('foo', 'bar'))
self.assert_(sql.Identifier('foo') != sql.Identifier('bar'))
self.assert_(sql.Identifier('foo') != 'foo')
self.assert_(sql.Identifier('foo') != sql.SQL('foo'))
def test_as_str(self):
self.assertEqual(
sql.Identifier('foo').as_string(self.conn), '"foo"')
self.assertEqual(
sql.Identifier('foo', 'bar').as_string(self.conn), '"foo"."bar"')
self.assertEqual(
sql.Identifier("fo'o", 'ba"r').as_string(self.conn), '"fo\'o"."ba""r"')
def test_join(self):
self.assert_(not hasattr(sql.Identifier('foo'), 'join'))
class LiteralTests(ConnectingTestCase):
def test_class(self):
self.assert_(issubclass(sql.Literal, sql.Composable))
def test_init(self):
self.assert_(isinstance(sql.Literal('foo'), sql.Literal))
self.assert_(isinstance(sql.Literal('foo'), sql.Literal))
self.assert_(isinstance(sql.Literal(b'foo'), sql.Literal))
self.assert_(isinstance(sql.Literal(42), sql.Literal))
self.assert_(isinstance(
sql.Literal(dt.date(2016, 12, 31)), sql.Literal))
def test_wrapped(self):
self.assertEqual(sql.Literal('foo').wrapped, 'foo')
def test_repr(self):
self.assertEqual(repr(sql.Literal("foo")), "Literal('foo')")
self.assertEqual(str(sql.Literal("foo")), "Literal('foo')")
self.assertQuotedEqual(sql.Literal("foo").as_string(self.conn), "'foo'")
self.assertEqual(sql.Literal(42).as_string(self.conn), "42")
self.assertEqual(
sql.Literal(dt.date(2017, 1, 1)).as_string(self.conn),
"'2017-01-01'::date")
def test_eq(self):
self.assert_(sql.Literal('foo') == sql.Literal('foo'))
self.assert_(sql.Literal('foo') != sql.Literal('bar'))
self.assert_(sql.Literal('foo') != 'foo')
self.assert_(sql.Literal('foo') != sql.SQL('foo'))
def test_must_be_adaptable(self):
class Foo:
pass
self.assertRaises(psycopg2.ProgrammingError,
sql.Literal(Foo()).as_string, self.conn)
class SQLTests(ConnectingTestCase):
def test_class(self):
self.assert_(issubclass(sql.SQL, sql.Composable))
def test_init(self):
self.assert_(isinstance(sql.SQL('foo'), sql.SQL))
self.assert_(isinstance(sql.SQL('foo'), sql.SQL))
self.assertRaises(TypeError, sql.SQL, 10)
self.assertRaises(TypeError, sql.SQL, dt.date(2016, 12, 31))
def test_string(self):
self.assertEqual(sql.SQL('foo').string, 'foo')
def test_repr(self):
self.assertEqual(repr(sql.SQL("foo")), "SQL('foo')")
self.assertEqual(str(sql.SQL("foo")), "SQL('foo')")
self.assertEqual(sql.SQL("foo").as_string(self.conn), "foo")
def test_eq(self):
self.assert_(sql.SQL('foo') == sql.SQL('foo'))
self.assert_(sql.SQL('foo') != sql.SQL('bar'))
self.assert_(sql.SQL('foo') != 'foo')
self.assert_(sql.SQL('foo') != sql.Literal('foo'))
def test_sum(self):
obj = sql.SQL("foo") + sql.SQL("bar")
self.assert_(isinstance(obj, sql.Composed))
self.assertEqual(obj.as_string(self.conn), "foobar")
def test_sum_inplace(self):
obj = sql.SQL("foo")
obj += sql.SQL("bar")
self.assert_(isinstance(obj, sql.Composed))
self.assertEqual(obj.as_string(self.conn), "foobar")
def test_multiply(self):
obj = sql.SQL("foo") * 3
self.assert_(isinstance(obj, sql.Composed))
self.assertEqual(obj.as_string(self.conn), "foofoofoo")
def test_join(self):
obj = sql.SQL(", ").join(
[sql.Identifier('foo'), sql.SQL('bar'), sql.Literal(42)])
self.assert_(isinstance(obj, sql.Composed))
self.assertEqual(obj.as_string(self.conn), '"foo", bar, 42')
obj = sql.SQL(", ").join(
sql.Composed([sql.Identifier('foo'), sql.SQL('bar'), sql.Literal(42)]))
self.assert_(isinstance(obj, sql.Composed))
self.assertEqual(obj.as_string(self.conn), '"foo", bar, 42')
obj = sql.SQL(", ").join([])
self.assertEqual(obj, sql.Composed([]))
class ComposedTest(ConnectingTestCase):
def test_class(self):
self.assert_(issubclass(sql.Composed, sql.Composable))
def test_repr(self):
obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
self.assertEqual(repr(obj),
"""Composed([Literal('foo'), Identifier("b'ar")])""")
self.assertEqual(str(obj), repr(obj))
def test_seq(self):
l = [sql.SQL('foo'), sql.Literal('bar'), sql.Identifier('baz')]
self.assertEqual(sql.Composed(l).seq, l)
def test_eq(self):
l = [sql.Literal("foo"), sql.Identifier("b'ar")]
l2 = [sql.Literal("foo"), sql.Literal("b'ar")]
self.assert_(sql.Composed(l) == sql.Composed(list(l)))
self.assert_(sql.Composed(l) != l)
self.assert_(sql.Composed(l) != sql.Composed(l2))
def test_join(self):
obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
obj = obj.join(", ")
self.assert_(isinstance(obj, sql.Composed))
self.assertQuotedEqual(obj.as_string(self.conn), "'foo', \"b'ar\"")
def test_sum(self):
obj = sql.Composed([sql.SQL("foo ")])
obj = obj + sql.Literal("bar")
self.assert_(isinstance(obj, sql.Composed))
self.assertQuotedEqual(obj.as_string(self.conn), "foo 'bar'")
def test_sum_inplace(self):
obj = sql.Composed([sql.SQL("foo ")])
obj += sql.Literal("bar")
self.assert_(isinstance(obj, sql.Composed))
self.assertQuotedEqual(obj.as_string(self.conn), "foo 'bar'")
obj = sql.Composed([sql.SQL("foo ")])
obj += sql.Composed([sql.Literal("bar")])
self.assert_(isinstance(obj, sql.Composed))
self.assertQuotedEqual(obj.as_string(self.conn), "foo 'bar'")
def test_iter(self):
obj = sql.Composed([sql.SQL("foo"), sql.SQL('bar')])
it = iter(obj)
i = next(it)
self.assertEqual(i, sql.SQL('foo'))
i = next(it)
self.assertEqual(i, sql.SQL('bar'))
self.assertRaises(StopIteration, next, it)
class PlaceholderTest(ConnectingTestCase):
def test_class(self):
self.assert_(issubclass(sql.Placeholder, sql.Composable))
def test_name(self):
self.assertEqual(sql.Placeholder().name, None)
self.assertEqual(sql.Placeholder('foo').name, 'foo')
def test_repr(self):
self.assert_(str(sql.Placeholder()), 'Placeholder()')
self.assert_(repr(sql.Placeholder()), 'Placeholder()')
self.assert_(sql.Placeholder().as_string(self.conn), '%s')
def test_repr_name(self):
self.assert_(str(sql.Placeholder('foo')), "Placeholder('foo')")
self.assert_(repr(sql.Placeholder('foo')), "Placeholder('foo')")
self.assert_(sql.Placeholder('foo').as_string(self.conn), '%(foo)s')
def test_bad_name(self):
self.assertRaises(ValueError, sql.Placeholder, ')')
def test_eq(self):
self.assert_(sql.Placeholder('foo') == sql.Placeholder('foo'))
self.assert_(sql.Placeholder('foo') != sql.Placeholder('bar'))
self.assert_(sql.Placeholder('foo') != 'foo')
self.assert_(sql.Placeholder() == sql.Placeholder())
self.assert_(sql.Placeholder('foo') != sql.Placeholder())
self.assert_(sql.Placeholder('foo') != sql.Literal('foo'))
class ValuesTest(ConnectingTestCase):
def test_null(self):
self.assert_(isinstance(sql.NULL, sql.SQL))
self.assertEqual(sql.NULL.as_string(self.conn), "NULL")
def test_default(self):
self.assert_(isinstance(sql.DEFAULT, sql.SQL))
self.assertEqual(sql.DEFAULT.as_string(self.conn), "DEFAULT")
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

258
tests/test_transaction.py Executable file
View File

@ -0,0 +1,258 @@
#!/usr/bin/env python
# test_transaction - unit test on transaction behaviour
#
# Copyright (C) 2007-2019 Federico Di Gregorio <fog@debian.org>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import threading
import unittest
from .testutils import ConnectingTestCase, skip_before_postgres, slow
from .testutils import skip_if_crdb
import psycopg2
from psycopg2.extensions import (
ISOLATION_LEVEL_SERIALIZABLE, STATUS_BEGIN, STATUS_READY)
class TransactionTests(ConnectingTestCase):
def setUp(self):
ConnectingTestCase.setUp(self)
skip_if_crdb("isolation level", self.conn)
self.conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE)
curs = self.conn.cursor()
curs.execute('''
CREATE TEMPORARY TABLE table1 (
id int PRIMARY KEY
)''')
# The constraint is set to deferrable for the commit_failed test
curs.execute('''
CREATE TEMPORARY TABLE table2 (
id int PRIMARY KEY,
table1_id int,
CONSTRAINT table2__table1_id__fk
FOREIGN KEY (table1_id) REFERENCES table1(id) DEFERRABLE)''')
curs.execute('INSERT INTO table1 VALUES (1)')
curs.execute('INSERT INTO table2 VALUES (1, 1)')
self.conn.commit()
def test_rollback(self):
# Test that rollback undoes changes
curs = self.conn.cursor()
curs.execute('INSERT INTO table2 VALUES (2, 1)')
# Rollback takes us from BEGIN state to READY state
self.assertEqual(self.conn.status, STATUS_BEGIN)
self.conn.rollback()
self.assertEqual(self.conn.status, STATUS_READY)
curs.execute('SELECT id, table1_id FROM table2 WHERE id = 2')
self.assertEqual(curs.fetchall(), [])
def test_commit(self):
# Test that commit stores changes
curs = self.conn.cursor()
curs.execute('INSERT INTO table2 VALUES (2, 1)')
# Rollback takes us from BEGIN state to READY state
self.assertEqual(self.conn.status, STATUS_BEGIN)
self.conn.commit()
self.assertEqual(self.conn.status, STATUS_READY)
# Now rollback and show that the new record is still there:
self.conn.rollback()
curs.execute('SELECT id, table1_id FROM table2 WHERE id = 2')
self.assertEqual(curs.fetchall(), [(2, 1)])
def test_failed_commit(self):
# Test that we can recover from a failed commit.
# We use a deferred constraint to cause a failure on commit.
curs = self.conn.cursor()
curs.execute('SET CONSTRAINTS table2__table1_id__fk DEFERRED')
curs.execute('INSERT INTO table2 VALUES (2, 42)')
# The commit should fail, and move the cursor back to READY state
self.assertEqual(self.conn.status, STATUS_BEGIN)
self.assertRaises(psycopg2.IntegrityError, self.conn.commit)
self.assertEqual(self.conn.status, STATUS_READY)
# The connection should be ready to use for the next transaction:
curs.execute('SELECT 1')
self.assertEqual(curs.fetchone()[0], 1)
class DeadlockSerializationTests(ConnectingTestCase):
"""Test deadlock and serialization failure errors."""
def connect(self):
conn = ConnectingTestCase.connect(self)
conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE)
return conn
def setUp(self):
ConnectingTestCase.setUp(self)
skip_if_crdb("isolation level", self.conn)
curs = self.conn.cursor()
# Drop table if it already exists
try:
curs.execute("DROP TABLE table1")
self.conn.commit()
except psycopg2.DatabaseError:
self.conn.rollback()
try:
curs.execute("DROP TABLE table2")
self.conn.commit()
except psycopg2.DatabaseError:
self.conn.rollback()
# Create sample data
curs.execute("""
CREATE TABLE table1 (
id int PRIMARY KEY,
name text)
""")
curs.execute("INSERT INTO table1 VALUES (1, 'hello')")
curs.execute("CREATE TABLE table2 (id int PRIMARY KEY)")
self.conn.commit()
def tearDown(self):
curs = self.conn.cursor()
curs.execute("DROP TABLE table1")
curs.execute("DROP TABLE table2")
self.conn.commit()
ConnectingTestCase.tearDown(self)
@slow
def test_deadlock(self):
self.thread1_error = self.thread2_error = None
step1 = threading.Event()
step2 = threading.Event()
def task1():
try:
conn = self.connect()
curs = conn.cursor()
curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE")
step1.set()
step2.wait()
curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE")
except psycopg2.DatabaseError as exc:
self.thread1_error = exc
step1.set()
conn.close()
def task2():
try:
conn = self.connect()
curs = conn.cursor()
step1.wait()
curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE")
step2.set()
curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE")
except psycopg2.DatabaseError as exc:
self.thread2_error = exc
step2.set()
conn.close()
# Run the threads in parallel. The "step1" and "step2" events
# ensure that the two transactions overlap.
thread1 = threading.Thread(target=task1)
thread2 = threading.Thread(target=task2)
thread1.start()
thread2.start()
thread1.join()
thread2.join()
# Exactly one of the threads should have failed with
# TransactionRollbackError:
self.assertFalse(self.thread1_error and self.thread2_error)
error = self.thread1_error or self.thread2_error
self.assertTrue(isinstance(
error, psycopg2.extensions.TransactionRollbackError))
@slow
def test_serialisation_failure(self):
self.thread1_error = self.thread2_error = None
step1 = threading.Event()
step2 = threading.Event()
def task1():
try:
conn = self.connect()
curs = conn.cursor()
curs.execute("SELECT name FROM table1 WHERE id = 1")
curs.fetchall()
step1.set()
step2.wait()
curs.execute("UPDATE table1 SET name='task1' WHERE id = 1")
conn.commit()
except psycopg2.DatabaseError as exc:
self.thread1_error = exc
step1.set()
conn.close()
def task2():
try:
conn = self.connect()
curs = conn.cursor()
step1.wait()
curs.execute("UPDATE table1 SET name='task2' WHERE id = 1")
conn.commit()
except psycopg2.DatabaseError as exc:
self.thread2_error = exc
step2.set()
conn.close()
# Run the threads in parallel. The "step1" and "step2" events
# ensure that the two transactions overlap.
thread1 = threading.Thread(target=task1)
thread2 = threading.Thread(target=task2)
thread1.start()
thread2.start()
thread1.join()
thread2.join()
# Exactly one of the threads should have failed with
# TransactionRollbackError:
self.assertFalse(self.thread1_error and self.thread2_error)
error = self.thread1_error or self.thread2_error
self.assertTrue(isinstance(
error, psycopg2.extensions.TransactionRollbackError))
class QueryCancellationTests(ConnectingTestCase):
"""Tests for query cancellation."""
def setUp(self):
ConnectingTestCase.setUp(self)
self.conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE)
@skip_before_postgres(8, 2)
def test_statement_timeout(self):
curs = self.conn.cursor()
# Set a low statement timeout, then sleep for a longer period.
curs.execute('SET statement_timeout TO 10')
self.assertRaises(psycopg2.extensions.QueryCanceledError,
curs.execute, 'SELECT pg_sleep(50)')
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

492
tests/test_types_basic.py Executable file
View File

@ -0,0 +1,492 @@
#!/usr/bin/env python
#
# types_basic.py - tests for basic types conversions
#
# Copyright (C) 2004-2019 Federico Di Gregorio <fog@debian.org>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import string
import ctypes
import decimal
import datetime
import platform
from . import testutils
import unittest
from .testutils import ConnectingTestCase, restore_types
from .testutils import skip_if_crdb
import psycopg2
from psycopg2.extensions import AsIs, adapt, register_adapter
class TypesBasicTests(ConnectingTestCase):
"""Test that all type conversions are working."""
def execute(self, *args):
curs = self.conn.cursor()
curs.execute(*args)
return curs.fetchone()[0]
def testQuoting(self):
s = "Quote'this\\! ''ok?''"
self.failUnless(self.execute("SELECT %s AS foo", (s,)) == s,
"wrong quoting: " + s)
def testUnicode(self):
s = "Quote'this\\! ''ok?''"
self.failUnless(self.execute("SELECT %s AS foo", (s,)) == s,
"wrong unicode quoting: " + s)
def testNumber(self):
s = self.execute("SELECT %s AS foo", (1971,))
self.failUnless(s == 1971, "wrong integer quoting: " + str(s))
def testBoolean(self):
x = self.execute("SELECT %s as foo", (False,))
self.assert_(x is False)
x = self.execute("SELECT %s as foo", (True,))
self.assert_(x is True)
def testDecimal(self):
s = self.execute("SELECT %s AS foo", (decimal.Decimal("19.10"),))
self.failUnless(s - decimal.Decimal("19.10") == 0,
"wrong decimal quoting: " + str(s))
s = self.execute("SELECT %s AS foo", (decimal.Decimal("NaN"),))
self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s))
self.failUnless(type(s) == decimal.Decimal,
"wrong decimal conversion: " + repr(s))
s = self.execute("SELECT %s AS foo", (decimal.Decimal("infinity"),))
self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s))
self.failUnless(type(s) == decimal.Decimal,
"wrong decimal conversion: " + repr(s))
s = self.execute("SELECT %s AS foo", (decimal.Decimal("-infinity"),))
self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s))
self.failUnless(type(s) == decimal.Decimal,
"wrong decimal conversion: " + repr(s))
def testFloatNan(self):
try:
float("nan")
except ValueError:
return self.skipTest("nan not available on this platform")
s = self.execute("SELECT %s AS foo", (float("nan"),))
self.failUnless(str(s) == "nan", "wrong float quoting: " + str(s))
self.failUnless(type(s) == float, "wrong float conversion: " + repr(s))
def testFloatInf(self):
try:
self.execute("select 'inf'::float")
except psycopg2.DataError:
return self.skipTest("inf::float not available on the server")
except ValueError:
return self.skipTest("inf not available on this platform")
s = self.execute("SELECT %s AS foo", (float("inf"),))
self.failUnless(str(s) == "inf", "wrong float quoting: " + str(s))
self.failUnless(type(s) == float, "wrong float conversion: " + repr(s))
s = self.execute("SELECT %s AS foo", (float("-inf"),))
self.failUnless(str(s) == "-inf", "wrong float quoting: " + str(s))
def testBinary(self):
s = bytes(range(256))
b = psycopg2.Binary(s)
buf = self.execute("SELECT %s::bytea AS foo", (b,))
self.assertEqual(s, buf.tobytes())
def testBinaryNone(self):
b = psycopg2.Binary(None)
buf = self.execute("SELECT %s::bytea AS foo", (b,))
self.assertEqual(buf, None)
def testBinaryEmptyString(self):
# test to make sure an empty Binary is converted to an empty string
b = psycopg2.Binary(bytes([]))
self.assertEqual(str(b), "''::bytea")
def testBinaryRoundTrip(self):
# test to make sure buffers returned by psycopg2 are
# understood by execute:
s = bytes(range(256))
buf = self.execute("SELECT %s::bytea AS foo", (psycopg2.Binary(s),))
buf2 = self.execute("SELECT %s::bytea AS foo", (buf,))
self.assertEqual(s, buf2.tobytes())
@skip_if_crdb("nested array")
def testArray(self):
s = self.execute("SELECT %s AS foo", ([[1, 2], [3, 4]],))
self.failUnlessEqual(s, [[1, 2], [3, 4]])
s = self.execute("SELECT %s AS foo", (['one', 'two', 'three'],))
self.failUnlessEqual(s, ['one', 'two', 'three'])
@skip_if_crdb("nested array")
def testEmptyArrayRegression(self):
# ticket #42
curs = self.conn.cursor()
curs.execute(
"create table array_test "
"(id integer, col timestamp without time zone[])")
curs.execute("insert into array_test values (%s, %s)",
(1, [datetime.date(2011, 2, 14)]))
curs.execute("select col from array_test where id = 1")
self.assertEqual(curs.fetchone()[0], [datetime.datetime(2011, 2, 14, 0, 0)])
curs.execute("insert into array_test values (%s, %s)", (2, []))
curs.execute("select col from array_test where id = 2")
self.assertEqual(curs.fetchone()[0], [])
@skip_if_crdb("nested array")
@testutils.skip_before_postgres(8, 4)
def testNestedEmptyArray(self):
# issue #788
curs = self.conn.cursor()
curs.execute("select 10 = any(%s::int[])", ([[]], ))
self.assertFalse(curs.fetchone()[0])
def testEmptyArrayNoCast(self):
s = self.execute("SELECT '{}' AS foo")
self.assertEqual(s, '{}')
s = self.execute("SELECT %s AS foo", ([],))
self.assertEqual(s, '{}')
def testEmptyArray(self):
s = self.execute("SELECT '{}'::text[] AS foo")
self.failUnlessEqual(s, [])
s = self.execute("SELECT 1 != ALL(%s)", ([],))
self.failUnlessEqual(s, True)
# but don't break the strings :)
s = self.execute("SELECT '{}'::text AS foo")
self.failUnlessEqual(s, "{}")
def testArrayEscape(self):
ss = ['', '\\', '"', '\\\\', '\\"']
for s in ss:
r = self.execute("SELECT %s AS foo", (s,))
self.failUnlessEqual(s, r)
r = self.execute("SELECT %s AS foo", ([s],))
self.failUnlessEqual([s], r)
r = self.execute("SELECT %s AS foo", (ss,))
self.failUnlessEqual(ss, r)
def testArrayMalformed(self):
curs = self.conn.cursor()
ss = ['', '{', '{}}', '{' * 20 + '}' * 20]
for s in ss:
self.assertRaises(psycopg2.DataError,
psycopg2.extensions.STRINGARRAY, s.encode('utf8'), curs)
def testTextArray(self):
curs = self.conn.cursor()
curs.execute("select '{a,b,c}'::text[]")
x = curs.fetchone()[0]
self.assert_(isinstance(x[0], str))
self.assertEqual(x, ['a', 'b', 'c'])
def testUnicodeArray(self):
psycopg2.extensions.register_type(
psycopg2.extensions.UNICODEARRAY, self.conn)
curs = self.conn.cursor()
curs.execute("select '{a,b,c}'::text[]")
x = curs.fetchone()[0]
self.assert_(isinstance(x[0], str))
self.assertEqual(x, ['a', 'b', 'c'])
def testBytesArray(self):
psycopg2.extensions.register_type(
psycopg2.extensions.BYTESARRAY, self.conn)
curs = self.conn.cursor()
curs.execute("select '{a,b,c}'::text[]")
x = curs.fetchone()[0]
self.assert_(isinstance(x[0], bytes))
self.assertEqual(x, [b'a', b'b', b'c'])
@skip_if_crdb("nested array")
@testutils.skip_before_postgres(8, 2)
def testArrayOfNulls(self):
curs = self.conn.cursor()
curs.execute("""
create table na (
texta text[],
inta int[],
boola boolean[],
textaa text[][],
intaa int[][],
boolaa boolean[][]
)""")
curs.execute("insert into na (texta) values (%s)", ([None],))
curs.execute("insert into na (texta) values (%s)", (['a', None],))
curs.execute("insert into na (texta) values (%s)", ([None, None],))
curs.execute("insert into na (inta) values (%s)", ([None],))
curs.execute("insert into na (inta) values (%s)", ([42, None],))
curs.execute("insert into na (inta) values (%s)", ([None, None],))
curs.execute("insert into na (boola) values (%s)", ([None],))
curs.execute("insert into na (boola) values (%s)", ([True, None],))
curs.execute("insert into na (boola) values (%s)", ([None, None],))
curs.execute("insert into na (textaa) values (%s)", ([[None]],))
curs.execute("insert into na (textaa) values (%s)", ([['a', None]],))
curs.execute("insert into na (textaa) values (%s)", ([[None, None]],))
curs.execute("insert into na (intaa) values (%s)", ([[None]],))
curs.execute("insert into na (intaa) values (%s)", ([[42, None]],))
curs.execute("insert into na (intaa) values (%s)", ([[None, None]],))
curs.execute("insert into na (boolaa) values (%s)", ([[None]],))
curs.execute("insert into na (boolaa) values (%s)", ([[True, None]],))
curs.execute("insert into na (boolaa) values (%s)", ([[None, None]],))
@skip_if_crdb("nested array")
@testutils.skip_before_postgres(8, 2)
def testNestedArrays(self):
curs = self.conn.cursor()
for a in [
[[1]],
[[None]],
[[None, None, None]],
[[None, None], [1, None]],
[[None, None], [None, None]],
[[[None, None], [None, None]]],
]:
curs.execute("select %s::int[]", (a,))
self.assertEqual(curs.fetchone()[0], a)
def testTypeRoundtripBytes(self):
o1 = bytes(range(256))
o2 = self.execute("select %s;", (o1,))
self.assertEqual(memoryview, type(o2))
# Test with an empty buffer
o1 = bytes([])
o2 = self.execute("select %s;", (o1,))
self.assertEqual(memoryview, type(o2))
def testTypeRoundtripBytesArray(self):
o1 = bytes(range(256))
o1 = [o1]
o2 = self.execute("select %s;", (o1,))
self.assertEqual(memoryview, type(o2[0]))
def testAdaptBytearray(self):
o1 = bytearray(range(256))
o2 = self.execute("select %s;", (o1,))
self.assertEqual(memoryview, type(o2))
self.assertEqual(len(o1), len(o2))
for c1, c2 in zip(o1, o2):
self.assertEqual(c1, ord(c2))
# Test with an empty buffer
o1 = bytearray([])
o2 = self.execute("select %s;", (o1,))
self.assertEqual(len(o2), 0)
self.assertEqual(memoryview, type(o2))
def testAdaptMemoryview(self):
o1 = memoryview(bytearray(range(256)))
o2 = self.execute("select %s;", (o1,))
self.assertEqual(memoryview, type(o2))
# Test with an empty buffer
o1 = memoryview(bytearray([]))
o2 = self.execute("select %s;", (o1,))
self.assertEqual(memoryview, type(o2))
def testByteaHexCheckFalsePositive(self):
# the check \x -> x to detect bad bytea decode
# may be fooled if the first char is really an 'x'
o1 = psycopg2.Binary(b'x')
o2 = self.execute("SELECT %s::bytea AS foo", (o1,))
self.assertEqual(b'x', o2[0])
def testNegNumber(self):
d1 = self.execute("select -%s;", (decimal.Decimal('-1.0'),))
self.assertEqual(1, d1)
f1 = self.execute("select -%s;", (-1.0,))
self.assertEqual(1, f1)
i1 = self.execute("select -%s;", (-1,))
self.assertEqual(1, i1)
def testGenericArray(self):
a = self.execute("select '{1, 2, 3}'::int4[]")
self.assertEqual(a, [1, 2, 3])
a = self.execute("select array['a', 'b', '''']::text[]")
self.assertEqual(a, ['a', 'b', "'"])
@testutils.skip_before_postgres(8, 2)
def testGenericArrayNull(self):
def caster(s, cur):
if s is None:
return "nada"
return int(s) * 2
base = psycopg2.extensions.new_type((23,), "INT4", caster)
array = psycopg2.extensions.new_array_type((1007,), "INT4ARRAY", base)
psycopg2.extensions.register_type(array, self.conn)
a = self.execute("select '{1, 2, 3}'::int4[]")
self.assertEqual(a, [2, 4, 6])
a = self.execute("select '{1, 2, NULL}'::int4[]")
self.assertEqual(a, [2, 4, 'nada'])
@skip_if_crdb("cidr")
@testutils.skip_before_postgres(8, 2)
def testNetworkArray(self):
# we don't know these types, but we know their arrays
a = self.execute("select '{192.168.0.1/24}'::inet[]")
self.assertEqual(a, ['192.168.0.1/24'])
a = self.execute("select '{192.168.0.0/24}'::cidr[]")
self.assertEqual(a, ['192.168.0.0/24'])
a = self.execute("select '{10:20:30:40:50:60}'::macaddr[]")
self.assertEqual(a, ['10:20:30:40:50:60'])
def testIntEnum(self):
from enum import IntEnum
class Color(IntEnum):
RED = 1
GREEN = 2
BLUE = 4
a = self.execute("select %s", (Color.GREEN,))
self.assertEqual(a, Color.GREEN)
class AdaptSubclassTest(unittest.TestCase):
def test_adapt_subtype(self):
class Sub(str):
pass
s1 = "hel'lo"
s2 = Sub(s1)
self.assertEqual(adapt(s1).getquoted(), adapt(s2).getquoted())
@restore_types
def test_adapt_most_specific(self):
class A:
pass
class B(A):
pass
class C(B):
pass
register_adapter(A, lambda a: AsIs("a"))
register_adapter(B, lambda b: AsIs("b"))
self.assertEqual(b'b', adapt(C()).getquoted())
@restore_types
def test_adapt_subtype_3(self):
class A:
pass
class B(A):
pass
register_adapter(A, lambda a: AsIs("a"))
self.assertEqual(b"a", adapt(B()).getquoted())
def test_conform_subclass_precedence(self):
class foo(tuple):
def __conform__(self, proto):
return self
def getquoted(self):
return 'bar'
self.assertEqual(adapt(foo((1, 2, 3))).getquoted(), 'bar')
@unittest.skipIf(
platform.system() == 'Windows',
"Not testing because we are useless with ctypes on Windows")
class ByteaParserTest(unittest.TestCase):
"""Unit test for our bytea format parser."""
def setUp(self):
self._cast = self._import_cast()
def _import_cast(self):
"""Use ctypes to access the C function."""
lib = ctypes.pydll.LoadLibrary(psycopg2._psycopg.__file__)
cast = lib.typecast_BINARY_cast
cast.argtypes = [ctypes.c_char_p, ctypes.c_size_t, ctypes.py_object]
cast.restype = ctypes.py_object
return cast
def cast(self, buffer):
"""Cast a buffer from the output format"""
l = buffer and len(buffer) or 0
rv = self._cast(buffer, l, None)
if rv is None:
return None
return rv.tobytes()
def test_null(self):
rv = self.cast(None)
self.assertEqual(rv, None)
def test_blank(self):
rv = self.cast(b'')
self.assertEqual(rv, b'')
def test_blank_hex(self):
# Reported as problematic in ticket #48
rv = self.cast(b'\\x')
self.assertEqual(rv, b'')
def test_full_hex(self, upper=False):
buf = ''.join(("%02x" % i) for i in range(256))
if upper:
buf = buf.upper()
buf = '\\x' + buf
rv = self.cast(buf.encode('utf8'))
self.assertEqual(rv, bytes(range(256)))
def test_full_hex_upper(self):
return self.test_full_hex(upper=True)
def test_full_escaped_octal(self):
buf = ''.join(("\\%03o" % i) for i in range(256))
rv = self.cast(buf.encode('utf8'))
self.assertEqual(rv, bytes(range(256)))
def test_escaped_mixed(self):
buf = ''.join(("\\%03o" % i) for i in range(32))
buf += string.ascii_letters
buf += ''.join('\\' + c for c in string.ascii_letters)
buf += '\\\\'
rv = self.cast(buf.encode('utf8'))
tgt = bytes(range(32)) + \
(string.ascii_letters * 2 + '\\').encode('ascii')
self.assertEqual(rv, tgt)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

1597
tests/test_types_extras.py Executable file

File diff suppressed because it is too large Load Diff

319
tests/test_with.py Executable file
View File

@ -0,0 +1,319 @@
#!/usr/bin/env python
# test_ctxman.py - unit test for connection and cursor used as context manager
#
# Copyright (C) 2012-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import psycopg2
import psycopg2.extensions as ext
import unittest
from .testutils import ConnectingTestCase, skip_before_postgres, skip_if_crdb
class WithTestCase(ConnectingTestCase):
def setUp(self):
ConnectingTestCase.setUp(self)
curs = self.conn.cursor()
try:
curs.execute("delete from test_with")
self.conn.commit()
except psycopg2.ProgrammingError:
# assume table doesn't exist
self.conn.rollback()
curs.execute("create table test_with (id integer primary key)")
self.conn.commit()
class WithConnectionTestCase(WithTestCase):
def test_with_ok(self):
with self.conn as conn:
self.assert_(self.conn is conn)
self.assertEqual(conn.status, ext.STATUS_READY)
curs = conn.cursor()
curs.execute("insert into test_with values (1)")
self.assertEqual(conn.status, ext.STATUS_BEGIN)
self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed)
curs = self.conn.cursor()
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [(1,)])
def test_with_connect_idiom(self):
with self.connect() as conn:
self.assertEqual(conn.status, ext.STATUS_READY)
curs = conn.cursor()
curs.execute("insert into test_with values (2)")
self.assertEqual(conn.status, ext.STATUS_BEGIN)
self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed)
curs = self.conn.cursor()
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [(2,)])
def test_with_error_db(self):
def f():
with self.conn as conn:
curs = conn.cursor()
curs.execute("insert into test_with values ('a')")
self.assertRaises(psycopg2.DataError, f)
self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed)
curs = self.conn.cursor()
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [])
def test_with_error_python(self):
def f():
with self.conn as conn:
curs = conn.cursor()
curs.execute("insert into test_with values (3)")
1 / 0
self.assertRaises(ZeroDivisionError, f)
self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed)
curs = self.conn.cursor()
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [])
def test_with_closed(self):
def f():
with self.conn:
pass
self.conn.close()
self.assertRaises(psycopg2.InterfaceError, f)
def test_subclass_commit(self):
commits = []
class MyConn(ext.connection):
def commit(self):
commits.append(None)
super().commit()
with self.connect(connection_factory=MyConn) as conn:
curs = conn.cursor()
curs.execute("insert into test_with values (10)")
self.assertEqual(conn.status, ext.STATUS_READY)
self.assert_(commits)
curs = self.conn.cursor()
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [(10,)])
def test_subclass_rollback(self):
rollbacks = []
class MyConn(ext.connection):
def rollback(self):
rollbacks.append(None)
super().rollback()
try:
with self.connect(connection_factory=MyConn) as conn:
curs = conn.cursor()
curs.execute("insert into test_with values (11)")
1 / 0
except ZeroDivisionError:
pass
else:
self.assert_("exception not raised")
self.assertEqual(conn.status, ext.STATUS_READY)
self.assert_(rollbacks)
curs = conn.cursor()
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [])
def test_cant_reenter(self):
raised_ok = False
with self.conn:
try:
with self.conn:
pass
except psycopg2.ProgrammingError:
raised_ok = True
self.assert_(raised_ok)
# Still good
with self.conn:
pass
def test_with_autocommit(self):
self.conn.autocommit = True
self.assertEqual(
self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE
)
with self.conn:
curs = self.conn.cursor()
curs.execute("insert into test_with values (1)")
self.assertEqual(
self.conn.info.transaction_status,
ext.TRANSACTION_STATUS_INTRANS,
)
self.assertEqual(
self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE
)
curs.execute("select count(*) from test_with")
self.assertEqual(curs.fetchone()[0], 1)
self.assertEqual(
self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE
)
def test_with_autocommit_pyerror(self):
self.conn.autocommit = True
raised_ok = False
try:
with self.conn:
curs = self.conn.cursor()
curs.execute("insert into test_with values (2)")
self.assertEqual(
self.conn.info.transaction_status,
ext.TRANSACTION_STATUS_INTRANS,
)
1 / 0
except ZeroDivisionError:
raised_ok = True
self.assert_(raised_ok)
self.assertEqual(
self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE
)
curs.execute("select count(*) from test_with")
self.assertEqual(curs.fetchone()[0], 0)
self.assertEqual(
self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE
)
def test_with_autocommit_pgerror(self):
self.conn.autocommit = True
raised_ok = False
try:
with self.conn:
curs = self.conn.cursor()
curs.execute("insert into test_with values (2)")
self.assertEqual(
self.conn.info.transaction_status,
ext.TRANSACTION_STATUS_INTRANS,
)
curs.execute("insert into test_with values ('x')")
except psycopg2.errors.InvalidTextRepresentation:
raised_ok = True
self.assert_(raised_ok)
self.assertEqual(
self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE
)
curs.execute("select count(*) from test_with")
self.assertEqual(curs.fetchone()[0], 0)
self.assertEqual(
self.conn.info.transaction_status, ext.TRANSACTION_STATUS_IDLE
)
class WithCursorTestCase(WithTestCase):
def test_with_ok(self):
with self.conn as conn:
with conn.cursor() as curs:
curs.execute("insert into test_with values (4)")
self.assert_(not curs.closed)
self.assertEqual(self.conn.status, ext.STATUS_BEGIN)
self.assert_(curs.closed)
self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed)
curs = self.conn.cursor()
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [(4,)])
def test_with_error(self):
try:
with self.conn as conn:
with conn.cursor() as curs:
curs.execute("insert into test_with values (5)")
1 / 0
except ZeroDivisionError:
pass
self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed)
self.assert_(curs.closed)
curs = self.conn.cursor()
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [])
def test_subclass(self):
closes = []
class MyCurs(ext.cursor):
def close(self):
closes.append(None)
super().close()
with self.conn.cursor(cursor_factory=MyCurs) as curs:
self.assert_(isinstance(curs, MyCurs))
self.assert_(curs.closed)
self.assert_(closes)
@skip_if_crdb("named cursor")
def test_exception_swallow(self):
# bug #262: __exit__ calls cur.close() that hides the exception
# with another error.
try:
with self.conn as conn:
with conn.cursor('named') as cur:
cur.execute("select 1/0")
cur.fetchone()
except psycopg2.DataError as e:
self.assertEqual(e.pgcode, '22012')
else:
self.fail("where is my exception?")
@skip_if_crdb("named cursor")
@skip_before_postgres(8, 2)
def test_named_with_noop(self):
with self.conn.cursor('named'):
pass
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

42
tests/testconfig.py Normal file
View File

@ -0,0 +1,42 @@
# Configure the test suite from the env variables.
import os
dbname = os.environ.get('PSYCOPG2_TESTDB', 'psycopg2_test')
dbhost = os.environ.get('PSYCOPG2_TESTDB_HOST', os.environ.get('PGHOST'))
dbport = os.environ.get('PSYCOPG2_TESTDB_PORT', os.environ.get('PGPORT'))
dbuser = os.environ.get('PSYCOPG2_TESTDB_USER', os.environ.get('PGUSER'))
dbpass = os.environ.get('PSYCOPG2_TESTDB_PASSWORD', os.environ.get('PGPASSWORD'))
# Check if we want to test psycopg's green path.
green = os.environ.get('PSYCOPG2_TEST_GREEN', None)
if green:
if green == '1':
from psycopg2.extras import wait_select as wait_callback
elif green == 'eventlet':
from eventlet.support.psycopg2_patcher import eventlet_wait_callback \
as wait_callback
else:
raise ValueError("please set 'PSYCOPG2_TEST_GREEN' to a valid value")
import psycopg2.extensions
psycopg2.extensions.set_wait_callback(wait_callback)
# Construct a DSN to connect to the test database:
dsn = f'dbname={dbname}'
if dbhost is not None:
dsn += f' host={dbhost}'
if dbport is not None:
dsn += f' port={dbport}'
if dbuser is not None:
dsn += f' user={dbuser}'
if dbpass is not None:
dsn += f' password={dbpass}'
# Don't run replication tests if REPL_DSN is not set, default to normal DSN if
# set to empty string.
repl_dsn = os.environ.get('PSYCOPG2_TEST_REPL_DSN', None)
if repl_dsn == '':
repl_dsn = dsn
repl_slot = os.environ.get('PSYCOPG2_TEST_REPL_SLOT', 'psycopg2_test_slot')

544
tests/testutils.py Normal file
View File

@ -0,0 +1,544 @@
# testutils.py - utility module for psycopg2 testing.
#
# Copyright (C) 2010-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import re
import os
import sys
import types
import ctypes
import select
import operator
import platform
import unittest
from functools import wraps
from ctypes.util import find_library
from io import StringIO # noqa
from io import TextIOBase # noqa
from importlib import reload # noqa
import psycopg2
import psycopg2.errors
import psycopg2.extensions
from .testconfig import green, dsn, repl_dsn
# Silence warnings caused by the stubbornness of the Python unittest
# maintainers
# https://bugs.python.org/issue9424
if (not hasattr(unittest.TestCase, 'assert_')
or unittest.TestCase.assert_ is not unittest.TestCase.assertTrue):
# mavaff...
unittest.TestCase.assert_ = unittest.TestCase.assertTrue
unittest.TestCase.failUnless = unittest.TestCase.assertTrue
unittest.TestCase.assertEquals = unittest.TestCase.assertEqual
unittest.TestCase.failUnlessEqual = unittest.TestCase.assertEqual
def assertDsnEqual(self, dsn1, dsn2, msg=None):
"""Check that two conninfo string have the same content"""
self.assertEqual(set(dsn1.split()), set(dsn2.split()), msg)
unittest.TestCase.assertDsnEqual = assertDsnEqual
class ConnectingTestCase(unittest.TestCase):
"""A test case providing connections for tests.
A connection for the test is always available as `self.conn`. Others can be
created with `self.connect()`. All are closed on tearDown.
Subclasses needing to customize setUp and tearDown should remember to call
the base class implementations.
"""
def setUp(self):
self._conns = []
def tearDown(self):
# close the connections used in the test
for conn in self._conns:
if not conn.closed:
conn.close()
def assertQuotedEqual(self, first, second, msg=None):
"""Compare two quoted strings disregarding eventual E'' quotes"""
def f(s):
if isinstance(s, str):
return re.sub(r"\bE'", "'", s)
elif isinstance(first, bytes):
return re.sub(br"\bE'", b"'", s)
else:
return s
return self.assertEqual(f(first), f(second), msg)
def connect(self, **kwargs):
try:
self._conns
except AttributeError as e:
raise AttributeError(
f"{e} (did you forget to call ConnectingTestCase.setUp()?)")
if 'dsn' in kwargs:
conninfo = kwargs.pop('dsn')
else:
conninfo = dsn
conn = psycopg2.connect(conninfo, **kwargs)
self._conns.append(conn)
return conn
def repl_connect(self, **kwargs):
"""Return a connection set up for replication
The connection is on "PSYCOPG2_TEST_REPL_DSN" unless overridden by
a *dsn* kwarg.
Should raise a skip test if not available, but guard for None on
old Python versions.
"""
if repl_dsn is None:
return self.skipTest("replication tests disabled by default")
if 'dsn' not in kwargs:
kwargs['dsn'] = repl_dsn
try:
conn = self.connect(**kwargs)
if conn.async_ == 1:
self.wait(conn)
except psycopg2.OperationalError as e:
# If pgcode is not set it is a genuine connection error
# Otherwise we tried to run some bad operation in the connection
# (e.g. bug #482) and we'd rather know that.
if e.pgcode is None:
return self.skipTest(f"replication db not configured: {e}")
else:
raise
return conn
def _get_conn(self):
if not hasattr(self, '_the_conn'):
self._the_conn = self.connect()
return self._the_conn
def _set_conn(self, conn):
self._the_conn = conn
conn = property(_get_conn, _set_conn)
# for use with async connections only
def wait(self, cur_or_conn):
pollable = cur_or_conn
if not hasattr(pollable, 'poll'):
pollable = cur_or_conn.connection
while True:
state = pollable.poll()
if state == psycopg2.extensions.POLL_OK:
break
elif state == psycopg2.extensions.POLL_READ:
select.select([pollable], [], [], 1)
elif state == psycopg2.extensions.POLL_WRITE:
select.select([], [pollable], [], 1)
else:
raise Exception("Unexpected result from poll: %r", state)
_libpq = None
@property
def libpq(self):
"""Return a ctypes wrapper for the libpq library"""
if ConnectingTestCase._libpq is not None:
return ConnectingTestCase._libpq
libname = find_library('pq')
if libname is None and platform.system() == 'Windows':
raise self.skipTest("can't import libpq on windows")
try:
rv = ConnectingTestCase._libpq = ctypes.pydll.LoadLibrary(libname)
except OSError as e:
raise self.skipTest("couldn't open libpq for testing: %s" % e)
return rv
def decorate_all_tests(obj, *decorators):
"""
Apply all the *decorators* to all the tests defined in the TestCase *obj*.
The decorator can also be applied to a decorator: if *obj* is a function,
return a new decorator which can be applied either to a method or to a
class, in which case it will decorate all the tests.
"""
if isinstance(obj, types.FunctionType):
def decorator(func_or_cls):
if isinstance(func_or_cls, types.FunctionType):
return obj(func_or_cls)
else:
decorate_all_tests(func_or_cls, obj)
return func_or_cls
return decorator
for n in dir(obj):
if n.startswith('test'):
for d in decorators:
setattr(obj, n, d(getattr(obj, n)))
@decorate_all_tests
def skip_if_no_uuid(f):
"""Decorator to skip a test if uuid is not supported by PG."""
@wraps(f)
def skip_if_no_uuid_(self):
try:
cur = self.conn.cursor()
cur.execute("select typname from pg_type where typname = 'uuid'")
has = cur.fetchone()
finally:
self.conn.rollback()
if has:
return f(self)
else:
return self.skipTest("uuid type not available on the server")
return skip_if_no_uuid_
@decorate_all_tests
def skip_if_tpc_disabled(f):
"""Skip a test if the server has tpc support disabled."""
@wraps(f)
def skip_if_tpc_disabled_(self):
cnn = self.connect()
skip_if_crdb("2-phase commit", cnn)
cur = cnn.cursor()
try:
cur.execute("SHOW max_prepared_transactions;")
except psycopg2.ProgrammingError:
return self.skipTest(
"server too old: two phase transactions not supported.")
else:
mtp = int(cur.fetchone()[0])
cnn.close()
if not mtp:
return self.skipTest(
"server not configured for two phase transactions. "
"set max_prepared_transactions to > 0 to run the test")
return f(self)
return skip_if_tpc_disabled_
def skip_before_postgres(*ver):
"""Skip a test on PostgreSQL before a certain version."""
reason = None
if isinstance(ver[-1], str):
ver, reason = ver[:-1], ver[-1]
ver = ver + (0,) * (3 - len(ver))
@decorate_all_tests
def skip_before_postgres_(f):
@wraps(f)
def skip_before_postgres__(self):
if self.conn.info.server_version < int("%d%02d%02d" % ver):
return self.skipTest(
reason or "skipped because PostgreSQL %s"
% self.conn.info.server_version)
else:
return f(self)
return skip_before_postgres__
return skip_before_postgres_
def skip_after_postgres(*ver):
"""Skip a test on PostgreSQL after (including) a certain version."""
ver = ver + (0,) * (3 - len(ver))
@decorate_all_tests
def skip_after_postgres_(f):
@wraps(f)
def skip_after_postgres__(self):
if self.conn.info.server_version >= int("%d%02d%02d" % ver):
return self.skipTest("skipped because PostgreSQL %s"
% self.conn.info.server_version)
else:
return f(self)
return skip_after_postgres__
return skip_after_postgres_
def libpq_version():
v = psycopg2.__libpq_version__
if v >= 90100:
v = min(v, psycopg2.extensions.libpq_version())
return v
def skip_before_libpq(*ver):
"""Skip a test if libpq we're linked to is older than a certain version."""
ver = ver + (0,) * (3 - len(ver))
def skip_before_libpq_(cls):
v = libpq_version()
decorator = unittest.skipIf(
v < int("%d%02d%02d" % ver),
f"skipped because libpq {v}",
)
return decorator(cls)
return skip_before_libpq_
def skip_after_libpq(*ver):
"""Skip a test if libpq we're linked to is newer than a certain version."""
ver = ver + (0,) * (3 - len(ver))
def skip_after_libpq_(cls):
v = libpq_version()
decorator = unittest.skipIf(
v >= int("%d%02d%02d" % ver),
f"skipped because libpq {v}",
)
return decorator(cls)
return skip_after_libpq_
def skip_before_python(*ver):
"""Skip a test on Python before a certain version."""
def skip_before_python_(cls):
decorator = unittest.skipIf(
sys.version_info[:len(ver)] < ver,
f"skipped because Python {'.'.join(map(str, sys.version_info[:len(ver)]))}",
)
return decorator(cls)
return skip_before_python_
def skip_from_python(*ver):
"""Skip a test on Python after (including) a certain version."""
def skip_from_python_(cls):
decorator = unittest.skipIf(
sys.version_info[:len(ver)] >= ver,
f"skipped because Python {'.'.join(map(str, sys.version_info[:len(ver)]))}",
)
return decorator(cls)
return skip_from_python_
@decorate_all_tests
def skip_if_no_superuser(f):
"""Skip a test if the database user running the test is not a superuser"""
@wraps(f)
def skip_if_no_superuser_(self):
try:
return f(self)
except psycopg2.errors.InsufficientPrivilege:
self.skipTest("skipped because not superuser")
return skip_if_no_superuser_
def skip_if_green(reason):
def skip_if_green_(cls):
decorator = unittest.skipIf(green, reason)
return decorator(cls)
return skip_if_green_
skip_copy_if_green = skip_if_green("copy in async mode currently not supported")
def skip_if_no_getrefcount(cls):
decorator = unittest.skipUnless(
hasattr(sys, 'getrefcount'),
'no sys.getrefcount()',
)
return decorator(cls)
def skip_if_windows(cls):
"""Skip a test if run on windows"""
decorator = unittest.skipIf(
platform.system() == 'Windows',
"Not supported on Windows",
)
return decorator(cls)
def crdb_version(conn, __crdb_version=[]):
"""
Return the CockroachDB version if that's the db being tested, else None.
Return the number as an integer similar to PQserverVersion: return
v20.1.3 as 200103.
Assume all the connections are on the same db: return a cached result on
following calls.
"""
if __crdb_version:
return __crdb_version[0]
sver = conn.info.parameter_status("crdb_version")
if sver is None:
__crdb_version.append(None)
else:
m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
if not m:
raise ValueError(
f"can't parse CockroachDB version from {sver}")
ver = int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))
__crdb_version.append(ver)
return __crdb_version[0]
def skip_if_crdb(reason, conn=None, version=None):
"""Skip a test or test class if we are testing against CockroachDB.
Can be used as a decorator for tests function or classes:
@skip_if_crdb("my reason")
class SomeUnitTest(UnitTest):
# ...
Or as a normal function if the *conn* argument is passed.
If *version* is specified it should be a string such as ">= 20.1", "< 20",
"== 20.1.3": the test will be skipped only if the version matches.
"""
if not isinstance(reason, str):
raise TypeError(f"reason should be a string, got {reason!r} instead")
if conn is not None:
ver = crdb_version(conn)
if ver is not None and _crdb_match_version(ver, version):
if reason in crdb_reasons:
reason = (
"%s (https://github.com/cockroachdb/cockroach/issues/%s)"
% (reason, crdb_reasons[reason]))
raise unittest.SkipTest(
f"not supported on CockroachDB {ver}: {reason}")
@decorate_all_tests
def skip_if_crdb_(f):
@wraps(f)
def skip_if_crdb__(self, *args, **kwargs):
skip_if_crdb(reason, conn=self.connect(), version=version)
return f(self, *args, **kwargs)
return skip_if_crdb__
return skip_if_crdb_
# mapping from reason description to ticket number
crdb_reasons = {
"2-phase commit": 22329,
"backend pid": 35897,
"cancel": 41335,
"cast adds tz": 51692,
"cidr": 18846,
"composite": 27792,
"copy": 41608,
"deferrable": 48307,
"encoding": 35882,
"hstore": 41284,
"infinity date": 41564,
"interval style": 35807,
"large objects": 243,
"named cursor": 41412,
"nested array": 32552,
"notify": 41522,
"password_encryption": 42519,
"range": 41282,
"stored procedure": 1751,
}
def _crdb_match_version(version, pattern):
if pattern is None:
return True
m = re.match(r'^(>|>=|<|<=|==|!=)\s*(\d+)(?:\.(\d+))?(?:\.(\d+))?$', pattern)
if m is None:
raise ValueError(
"bad crdb version pattern %r: should be 'OP MAJOR[.MINOR[.BUGFIX]]'"
% pattern)
ops = {'>': 'gt', '>=': 'ge', '<': 'lt', '<=': 'le', '==': 'eq', '!=': 'ne'}
op = getattr(operator, ops[m.group(1)])
ref = int(m.group(2)) * 10000 + int(m.group(3) or 0) * 100 + int(m.group(4) or 0)
return op(version, ref)
class raises_typeerror:
def __enter__(self):
pass
def __exit__(self, type, exc, tb):
assert type is TypeError
return True
def slow(f):
"""Decorator to mark slow tests we may want to skip
Note: in order to find slow tests you can run:
make check 2>&1 | ts -i "%.s" | sort -n
"""
@wraps(f)
def slow_(self):
if os.environ.get('PSYCOPG2_TEST_FAST', '0') != '0':
return self.skipTest("slow test")
return f(self)
return slow_
def restore_types(f):
"""Decorator to restore the adaptation system after running a test"""
@wraps(f)
def restore_types_(self):
types = psycopg2.extensions.string_types.copy()
adapters = psycopg2.extensions.adapters.copy()
try:
return f(self)
finally:
psycopg2.extensions.string_types.clear()
psycopg2.extensions.string_types.update(types)
psycopg2.extensions.adapters.clear()
psycopg2.extensions.adapters.update(adapters)
return restore_types_