root/pg8000/trunk/pg8000_test.py

Revision 852, 18.4 kB (checked in by mfenniak, 2 years ago)

Add date and time data types, and timestamptz read. Add bool write -- somehow this was missed.

  • Property svn:executable is set
Line 
1 #!/usr/bin/env python
2
3 import datetime
4 import decimal
5 import threading
6 import unittest
7 import pg8000
8 import struct
9
# Connection settings for a remote test server on the host "joy".
db_joy_connect = dict(
    host="joy",
    user="pg8000-test",
    database="pg8000-test",
    password="pg8000-test",
    socket_timeout=5,
    ssl=False,
)

# Connection settings for a local server reached through its Unix-domain
# socket.
db_local_connect = dict(unix_sock="/tmp/.s.PGSQL.5432", user="mfenniak")

# The settings this run actually uses; point this at db_joy_connect to test
# against the remote server instead.
db_connect = db_local_connect

# Shared connections, opened at import time and used by the test classes:
#   db   - pg8000's native Connection (QueryTests, TypeTests)
#   db2  - a connection made through pg8000.DBAPI (DBAPITests)
db = pg8000.Connection(**db_connect)
dbapi = pg8000.DBAPI
db2 = dbapi.connect(**db_connect)
26
27 # Tests relating to the basic operation of the database driver, driven by the
28 # pg8000 custom interface.
29 class QueryTests(unittest.TestCase):
30     def setUp(self):
31         try:
32             db.execute("DROP TABLE t1")
33         except pg8000.DatabaseError, e:
34             # the only acceptable error is:
35             self.assert_(e.args[1] == '42P01', # table does not exist
36                     "incorrect error for drop table")
37         db.execute("CREATE TEMPORARY TABLE t1 (f1 int primary key, f2 int not null, f3 varchar(50) null)")
38
39     def TestParallelQueries(self):
40         db.execute("INSERT INTO t1 (f1, f2, f3) VALUES ($1, $2, $3)", 1, 1, None)
41         db.execute("INSERT INTO t1 (f1, f2, f3) VALUES ($1, $2, $3)", 2, 10, None)
42         db.execute("INSERT INTO t1 (f1, f2, f3) VALUES ($1, $2, $3)", 3, 100, None)
43         db.execute("INSERT INTO t1 (f1, f2, f3) VALUES ($1, $2, $3)", 4, 1000, None)
44         db.execute("INSERT INTO t1 (f1, f2, f3) VALUES ($1, $2, $3)", 5, 10000, None)
45         c1 = pg8000.Cursor(db)
46         c2 = pg8000.Cursor(db)
47         c1.execute("SELECT f1, f2, f3 FROM t1")
48         for row in c1.iterate_tuple():
49             f1, f2, f3 = row
50             c2.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > $1", f1)
51             for row in c2.iterate_tuple():
52                 f1, f2, f3 = row
53
54     def TestNoDataErrorRecovery(self):
55         for i in range(1, 4):
56             try:
57                 db.execute("DROP TABLE t1")
58             except pg8000.DatabaseError, e:
59                 # the only acceptable error is:
60                 self.assert_(e.args[1] == '42P01', # table does not exist
61                         "incorrect error for drop table")
62
63     def TestMultithreadedStatement(self):
64         # Note: Multithreading with a prepared statement is not highly
65         # recommended due to low performance.
66         s1 = pg8000.PreparedStatement(db, "INSERT INTO t1 (f1, f2, f3) VALUES ($1, $2, $3)", int, int, str)
67         def test(left, right):
68             for i in range(left, right):
69                 s1.execute(i, id(threading.currentThread()), None)
70         t1 = threading.Thread(target=test, args=(1, 25))
71         t2 = threading.Thread(target=test, args=(25, 50))
72         t3 = threading.Thread(target=test, args=(50, 75))
73         t1.start(); t2.start(); t3.start()
74         t1.join(); t2.join(); t3.join()
75
76     def TestMultithreadedCursor(self):
77         # Note: Multithreading with a cursor is not highly recommended due to
78         # low performance.
79         cur = pg8000.Cursor(db)
80         def test(left, right):
81             for i in range(left, right):
82                 cur.execute("INSERT INTO t1 (f1, f2, f3) VALUES ($1, $2, $3)", i, id(threading.currentThread()), None)
83         t1 = threading.Thread(target=test, args=(1, 25))
84         t2 = threading.Thread(target=test, args=(25, 50))
85         t3 = threading.Thread(target=test, args=(50, 75))
86         t1.start(); t2.start(); t3.start()
87         t1.join(); t2.join(); t3.join()
88
class ParamstyleTests(unittest.TestCase):
    """Tests for pg8000.DBAPI.convert_paramstyle, which rewrites a query
    written in a DB-API paramstyle into PostgreSQL's $N placeholders and
    returns the matching positional argument tuple."""

    def _check(self, style, query, args, expected_query, expected_args):
        """Convert (*query*, *args*) from *style* and compare the result
        with (*expected_query*, *expected_args*)."""
        new_query, new_args = pg8000.DBAPI.convert_paramstyle(style, query, args)
        # Use unittest assertions: bare ``assert`` statements are removed
        # under ``python -O``, which would make these tests pass vacuously.
        self.assertEqual(new_query, expected_query)
        self.assertEqual(new_args, expected_args)

    def TestQmark(self):
        # "?" inside quoted identifiers, string literals and E'' strings
        # must be left untouched.
        self._check("qmark",
                "SELECT ?, ?, \"field_?\" FROM t WHERE a='say ''what?''' AND b=? AND c=E'?\\'test\\'?'",
                (1, 2, 3),
                "SELECT $1, $2, \"field_?\" FROM t WHERE a='say ''what?''' AND b=$3 AND c=E'?\\'test\\'?'",
                (1, 2, 3))

    def TestQmark2(self):
        self._check("qmark",
                "SELECT ?, ?, * FROM t WHERE a=? AND b='are you ''sure?'",
                (1, 2, 3),
                "SELECT $1, $2, * FROM t WHERE a=$3 AND b='are you ''sure?'",
                (1, 2, 3))

    def TestNumeric(self):
        # :N maps straight to $N; the argument tuple is passed through as-is.
        self._check("numeric",
                "SELECT :2, :1, * FROM t WHERE a=:3",
                (1, 2, 3),
                "SELECT $2, $1, * FROM t WHERE a=$3",
                (1, 2, 3))

    def TestNamed(self):
        # Names are numbered by first appearance; a repeated name reuses its $N.
        self._check("named",
                "SELECT :f2, :f1 FROM t WHERE a=:f2",
                {"f2": 1, "f1": 2},
                "SELECT $1, $2 FROM t WHERE a=$1",
                (1, 2))

    def TestFormat(self):
        # "%%" is an escaped literal percent sign and becomes a single "%".
        self._check("format",
                "SELECT %s, %s, \"f1_%%\", E'txt_%%' FROM t WHERE a=%s AND b='75%%'",
                (1, 2, 3),
                "SELECT $1, $2, \"f1_%\", E'txt_%' FROM t WHERE a=$3 AND b='75%'",
                (1, 2, 3))

    def TestPyformat(self):
        # The unused key "f3" is dropped from the resulting arguments.
        self._check("pyformat",
                "SELECT %(f2)s, %(f1)s, \"f1_%%\", E'txt_%%' FROM t WHERE a=%(f2)s AND b='75%%'",
                {"f2": 1, "f1": 2, "f3": 3},
                "SELECT $1, $2, \"f1_%\", E'txt_%' FROM t WHERE a=$1 AND b='75%'",
                (1, 2))
119
120
121 class DBAPITests(unittest.TestCase):
122     def setUp(self):
123         c = db2.cursor()
124         try:
125             c.execute("DROP TABLE t1")
126         except pg8000.DatabaseError, e:
127             # the only acceptable error is:
128             self.assert_(e.args[1] == '42P01', # table does not exist
129                     "incorrect error for drop table")
130         c.execute("CREATE TEMPORARY TABLE t1 (f1 int primary key, f2 int not null, f3 varchar(50) null)")
131         c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, None))
132         c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (2, 10, None))
133         c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (3, 100, None))
134         c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (4, 1000, None))
135         c.execute("INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (5, 10000, None))
136
137     def TestParallelQueries(self):
138         c1 = db2.cursor()
139         c2 = db2.cursor()
140         c1.execute("SELECT f1, f2, f3 FROM t1")
141         while 1:
142             row = c1.fetchone()
143             if row == None:
144                 break
145             f1, f2, f3 = row
146             c2.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (f1,))
147             while 1:
148                 row = c2.fetchone()
149                 if row == None:
150                     break
151                 f1, f2, f3 = row
152
153     def TestQmark(self):
154         orig_paramstyle = dbapi.paramstyle
155         try:
156             dbapi.paramstyle = "qmark"
157             c1 = db2.cursor()
158             c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > ?", (3,))
159             while 1:
160                 row = c1.fetchone()
161                 if row == None:
162                     break
163                 f1, f2, f3 = row
164         finally:
165             dbapi.paramstyle = orig_paramstyle
166
167     def TestNumeric(self):
168         orig_paramstyle = dbapi.paramstyle
169         try:
170             dbapi.paramstyle = "numeric"
171             c1 = db2.cursor()
172             c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > :1", (3,))
173             while 1:
174                 row = c1.fetchone()
175                 if row == None:
176                     break
177                 f1, f2, f3 = row
178         finally:
179             dbapi.paramstyle = orig_paramstyle
180
181     def TestNamed(self):
182         orig_paramstyle = dbapi.paramstyle
183         try:
184             dbapi.paramstyle = "named"
185             c1 = db2.cursor()
186             c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > :f1", {"f1": 3})
187             while 1:
188                 row = c1.fetchone()
189                 if row == None:
190                     break
191                 f1, f2, f3 = row
192         finally:
193             dbapi.paramstyle = orig_paramstyle
194
195     def TestFormat(self):
196         orig_paramstyle = dbapi.paramstyle
197         try:
198             dbapi.paramstyle = "format"
199             c1 = db2.cursor()
200             c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (3,))
201             while 1:
202                 row = c1.fetchone()
203                 if row == None:
204                     break
205                 f1, f2, f3 = row
206         finally:
207             dbapi.paramstyle = orig_paramstyle
208    
209     def TestPyformat(self):
210         orig_paramstyle = dbapi.paramstyle
211         try:
212             dbapi.paramstyle = "pyformat"
213             c1 = db2.cursor()
214             c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %(f1)s", {"f1": 3})
215             while 1:
216                 row = c1.fetchone()
217                 if row == None:
218                     break
219                 f1, f2, f3 = row
220         finally:
221             dbapi.paramstyle = orig_paramstyle
222
223     def TestArraysize(self):
224         c1 = db2.cursor()
225         c1.arraysize = 3
226         c1.execute("SELECT * FROM t1")
227         retval = c1.fetchmany()
228         self.assert_(len(retval) == c1.arraysize,
229                 "fetchmany returned wrong number of rows")
230
231     def TestDate(self):
232         val = dbapi.Date(2001, 2, 3)
233         self.assert_(val == datetime.date(2001, 2, 3),
234                 "Date constructor value match failed")
235
236     def TestTime(self):
237         val = dbapi.Time(4, 5, 6)
238         self.assert_(val == datetime.time(4, 5, 6),
239                 "Time constructor value match failed")
240
241     def TestTimestamp(self):
242         val = dbapi.Timestamp(2001, 2, 3, 4, 5, 6)
243         self.assert_(val == datetime.datetime(2001, 2, 3, 4, 5, 6),
244                 "Timestamp constructor value match failed")
245
246     def TestDateFromTicks(self):
247         val = dbapi.DateFromTicks(1173804319)
248         self.assert_(val == datetime.date(2007, 3, 13),
249                 "DateFromTicks constructor value match failed")
250
251     def TestTimeFromTicks(self):
252         val = dbapi.TimeFromTicks(1173804319)
253         self.assert_(val == datetime.time(10, 45, 19),
254                 "TimeFromTicks constructor value match failed")
255
256     def TestTimestampFromTicks(self):
257         val = dbapi.TimestampFromTicks(1173804319)
258         self.assert_(val == datetime.datetime(2007, 3, 13, 10, 45, 19),
259                 "TimestampFromTicks constructor value match failed")
260
261     def TestBinary(self):
262         v = dbapi.Binary("\x00\x01\x02\x03\x02\x01\x00")
263         self.assert_(v == "\x00\x01\x02\x03\x02\x01\x00",
264                 "Binary value match failed")
265         self.assert_(isinstance(v, pg8000.Bytea),
266                 "Binary type match failed")
267
268
269 # Tests relating to type conversion.
270 class TypeTests(unittest.TestCase):
271     def TestTimeRoundtrip(self):
272         db.execute("SELECT $1 as f1", datetime.time(4, 5, 6))
273         retval = tuple(db.iterate_dict())
274         self.assert_(retval == ({"f1": datetime.time(4, 5, 6)},),
275                 "retrieved value match failed")
276
277     def TestDateRoundtrip(self):
278         db.execute("SELECT $1 as f1", datetime.date(2001, 2, 3))
279         retval = tuple(db.iterate_dict())
280         self.assert_(retval == ({"f1": datetime.date(2001, 2, 3)},),
281                 "retrieved value match failed")
282
283     def TestBoolRoundtrip(self):
284         db.execute("SELECT $1 as f1", True)
285         retval = tuple(db.iterate_dict())
286         self.assert_(retval == ({"f1": True},),
287                 "retrieved value match failed")
288
289     def TestNullRoundtrip(self):
290         # We can't just "SELECT $1" and set None as the parameter, since it has
291         # no type.  That would result in a PG error, "could not determine data
292         # type of parameter $1".  So we create a temporary table, insert null
293         # values, and read them back.
294         db.execute("CREATE TEMPORARY TABLE TestNullWrite (f1 int4, f2 timestamp, f3 varchar)")
295         db.execute("INSERT INTO TestNullWrite VALUES ($1, $2, $3)",
296                 None, None, None)
297         db.execute("SELECT * FROM TestNullWrite")
298         retval = tuple(db.iterate_dict())
299         self.assert_(retval == ({"f1": None, "f2": None, "f3": None},),
300                 "retrieved value match failed")
301
302     def TestNullSelectFailure(self):
303         # See comment in TestNullRoundtrip.  This test is here to ensure that
304         # this behaviour is documented and doesn't mysteriously change.
305         self.assertRaises(pg8000.ProgrammingError, db.execute,
306                 "SELECT $1 as f1", None)
307
308     def TestDecimalRoundtrip(self):
309         db.execute("SELECT $1 as f1", decimal.Decimal('1.1'))
310         retval = tuple(db.iterate_dict())
311         self.assert_(retval == ({"f1": decimal.Decimal('1.1')},),
312                 "retrieved value match failed")
313
314     def TestFloatRoundtrip(self):
315         # This test ensures that the binary float value doesn't change in a
316         # roundtrip to the server.  That could happen if the value was
317         # converted to text and got rounded by a decimal place somewhere.
318         val = 1.756e-12
319         bin_orig = struct.pack("!d", val)
320         db.execute("SELECT $1 as f1", val)
321         retval = tuple(db.iterate_dict())
322         bin_new = struct.pack("!d", retval[0]['f1'])
323         self.assert_(bin_new == bin_orig,
324                 "retrieved value match failed")
325
326     def TestStrRoundtrip(self):
327         db.execute("SELECT $1 as f1", "hello world")
328         retval = tuple(db.iterate_dict())
329         self.assert_(retval == ({"f1": u"hello world"},),
330                 "retrieved value match failed")
331
332     def TestUnicodeRoundtrip(self):
333         db.execute("SELECT $1 as f1", u"hello \u0173 world")
334         retval = tuple(db.iterate_dict())
335         self.assert_(retval == ({"f1": u"hello \u0173 world"},),
336                 "retrieved value match failed")
337
338     def TestLongRoundtrip(self):
339         db.execute("SELECT $1 as f1", 50000000000000L)
340         retval = tuple(db.iterate_dict())
341         self.assert_(retval == ({"f1": 50000000000000L},),
342                 "retrieved value match failed")
343
344     def TestIntRoundtrip(self):
345         db.execute("SELECT $1 as f1", 100)
346         retval = tuple(db.iterate_dict())
347         self.assert_(retval == ({"f1": 100},),
348                 "retrieved value match failed")
349
350     def TestByteaRoundtrip(self):
351         db.execute("SELECT $1 as f1", pg8000.Bytea("\x00\x01\x02\x03\x02\x01\x00"))
352         retval = tuple(db.iterate_dict())
353         self.assert_(retval == ({"f1": "\x00\x01\x02\x03\x02\x01\x00"},),
354                 "retrieved value match failed")
355
356     def TestTimestampRoundtrip(self):
357         v = datetime.datetime(2001, 2, 3, 4, 5, 6, 170000)
358         db.execute("SELECT $1 as f1", v)
359         retval = tuple(db.iterate_dict())
360         self.assert_(retval == ({"f1": v},),
361                 "retrieved value match failed")
362
363     def TestTimestampTzOut(self):
364         db.execute("SELECT '2001-02-03 04:05:06.17 Canada/Mountain'::timestamp with time zone")
365         retval = tuple(db.iterate_dict())
366         dt = retval[0]['timestamptz']
367         self.assert_(dt.tzinfo.hrs == -7,
368                 "timezone hrs != -7")
369         self.assert_(
370                 datetime.datetime(dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second, dt.microsecond) ==
371                 datetime.datetime(2001, 2, 3, 4, 5, 6, 170000),
372                 "retrieved value match failed")
373
374     def TestNameOut(self):
375         # select a field that is of "name" type:
376         db.execute("SELECT usename FROM pg_user")
377         retval = tuple(db.iterate_dict())
378         # It is sufficient that no errors were encountered.
379
380     def TestOidOut(self):
381         db.execute("SELECT oid FROM pg_type")
382         retval = tuple(db.iterate_dict())
383         # It is sufficient that no errors were encountered.
384
385     def TestBooleanOut(self):
386         db.execute("SELECT 't'::bool")
387         retval = tuple(db.iterate_dict())
388         self.assert_(retval == ({"bool": True},),
389                 "retrieved value match failed")
390
391     def TestNumericOut(self):
392         db.execute("SELECT 5000::numeric")
393         retval = tuple(db.iterate_dict())
394         self.assert_(retval == ({"numeric": decimal.Decimal("5000")},),
395                 "retrieved value match failed")
396
397     def TestInt2Out(self):
398         db.execute("SELECT 5000::smallint")
399         retval = tuple(db.iterate_dict())
400         self.assert_(retval == ({"int2": 5000},),
401                 "retrieved value match failed")
402
403     def TestInt4Out(self):
404         db.execute("SELECT 5000::integer")
405         retval = tuple(db.iterate_dict())
406         self.assert_(retval == ({"int4": 5000},),
407                 "retrieved value match failed")
408
409     def TestInt8Out(self):
410         db.execute("SELECT 50000000000000::bigint")
411         retval = tuple(db.iterate_dict())
412         self.assert_(retval == ({"int8": 50000000000000},),
413                 "retrieved value match failed")
414
415     def TestFloat4Out(self):
416         db.execute("SELECT 1.1::real")
417         retval = tuple(db.iterate_dict())
418         self.assert_(retval == ({"float4": 1.1000000238418579},),
419                 "retrieved value match failed")
420
421     def TestFloat8Out(self):
422         db.execute("SELECT 1.1::double precision")
423         retval = tuple(db.iterate_dict())
424         self.assert_(retval == ({"float8": 1.1000000000000001},),
425                 "retrieved value match failed")
426
427     def TestVarcharOut(self):
428         db.execute("SELECT 'hello'::varchar(20)")
429         retval = tuple(db.iterate_dict())
430         self.assert_(retval == ({"varchar": u"hello"},),
431                 "retrieved value match failed")
432
433     def TestCharOut(self):
434         db.execute("SELECT 'hello'::char(20)")
435         retval = tuple(db.iterate_dict())
436         self.assert_(retval == ({"bpchar": u"hello               "},),
437                 "retrieved value match failed")
438
439     def TestTextOut(self):
440         db.execute("SELECT 'hello'::text")
441         retval = tuple(db.iterate_dict())
442         self.assert_(retval == ({"text": u"hello"},),
443                 "retrieved value match failed")
444
445     def TestIntervalOut(self):
446         db.execute("SELECT '1 month'::interval")
447         retval = tuple(db.iterate_dict())
448         self.assert_(retval == ({"interval": "1 mon"},),
449                 "retrieved value match failed")
450
451     def TestTimestampOut(self):
452         db.execute("SELECT '2001-02-03 04:05:06.17'::timestamp")
453         retval = tuple(db.iterate_dict())
454         self.assert_(retval == ({"timestamp": datetime.datetime(2001, 2, 3, 4, 5, 6, 170000)},),
455                 "retrieved value match failed")
456
457
def suite():
    """Return one TestSuite holding every ``Test*`` method of the
    paramstyle, DB-API, query and type test classes, in that order."""
    test_classes = (ParamstyleTests, DBAPITests, QueryTests, TypeTests)
    return unittest.TestSuite(
        [unittest.makeSuite(test_class, "Test") for test_class in test_classes])
465
if __name__ == "__main__":
    import sys
    runner = unittest.TextTestRunner()
    result = runner.run(suite())
    # Report failures through the exit status; previously the result was
    # discarded and the script exited 0 even when tests failed.
    sys.exit(int(not result.wasSuccessful()))
469
Note: See TracBrowser for help on using the browser.