root/pg8000/pg8000-v1.01/pg8000.py

Revision 827, 49.3 kB (checked in by mfenniak, 2 years ago)

Add float and decimal output types

<
Line 
1 # vim: sw=4:expandtab:foldmethod=marker
2 #
3 # Copyright (c) 2007, Mathieu Fenniak
4 # All rights reserved.
5 #
6 # Redistribution and use in source and binary forms, with or without
7 # modification, are permitted provided that the following conditions are
8 # met:
9 #
10 # * Redistributions of source code must retain the above copyright notice,
11 # this list of conditions and the following disclaimer.
12 # * Redistributions in binary form must reproduce the above copyright notice,
13 # this list of conditions and the following disclaimer in the documentation
14 # and/or other materials provided with the distribution.
15 # * The name of the author may not be used to endorse or promote products
16 # derived from this software without specific prior written permission.
17 #
18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
22 # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28 # POSSIBILITY OF SUCH DAMAGE.
29
30 __author__ = "Mathieu Fenniak"
31
32 import socket
33 import struct
34 import datetime
35 import md5
36 import decimal
37 import threading
38
39 class Warning(StandardError):
40     pass
41
42 class Error(StandardError):
43     pass
44
45 class InterfaceError(Error):
46     pass
47
48 class DatabaseError(Error):
49     pass
50
51 class DataError(DatabaseError):
52     pass
53
54 class OperationalError(DatabaseError):
55     pass
56
57 class IntegrityError(DatabaseError):
58     pass
59
60 class InternalError(DatabaseError):
61     pass
62
63 class ProgrammingError(DatabaseError):
64     pass
65
66 class NotSupportedError(DatabaseError):
67     pass
68
69
class DataIterator(object):
    # Adapts a "read one item" function into an iterator.  Each step calls
    # func(obj); iteration stops the first time func returns None.
    #
    # @param obj   Object passed as the only argument to func.
    # @param func  Callable taking obj and returning the next item, or None
    #              when no items remain.
    def __init__(self, obj, func):
        self.obj = obj
        self.func = func

    def __iter__(self):
        return self

    def next(self):
        retval = self.func(self.obj)
        # Identity test: only None ends iteration.  A row object whose
        # __eq__ claims equality with None must not be mistaken for the end.
        if retval is None:
            raise StopIteration()
        return retval

    # Python 3 name of the iterator-protocol method.
    __next__ = next
83
class DBAPI(object):
    # Thin DB-API 2.0 (PEP 249) adapter over this module's native Connection
    # and Cursor classes.  Entry point: DBAPI.connect(...).

    # Exception names that PEP 249 requires the module interface to expose.
    Warning = Warning
    Error = Error
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    DataError = DataError
    OperationalError = OperationalError
    IntegrityError = IntegrityError
    InternalError = InternalError   # required by PEP 249; previously missing
    ProgrammingError = ProgrammingError
    NotSupportedError = NotSupportedError

    apilevel = "2.0"
    threadsafety = 3
    # Statements use PostgreSQL's native $1, $2, ... placeholders, which is
    # none of the paramstyle values defined by PEP 249.
    paramstyle = 'none-of-the-above'

    ##
    # DB-API cursor wrapping a native {@link Cursor Cursor}.
    class CursorWrapper(object):
        def __init__(self, conn):
            self.cursor = Cursor(conn)
            # Default number of rows returned by fetchmany().
            self.arraysize = 1

        # Execute one statement; args is the sequence of $n parameter values.
        def execute(self, operation, args=()):
            self.cursor.execute(operation, *args)

        # Execute the statement once per parameter sequence.
        def executemany(self, operation, parameter_sets):
            for parameters in parameter_sets:
                self.execute(operation, parameters)

        # Next row as a tuple, or None once the result set is exhausted.
        def fetchone(self):
            return self.cursor.read_tuple()

        # Up to `size` rows (default: arraysize).  Returns fewer rows -- never
        # padding with None -- when the result set runs out.
        def fetchmany(self, size=None):
            if size is None:
                size = self.arraysize
            rows = []
            while len(rows) < size:
                row = self.fetchone()
                if row is None:
                    break
                rows.append(row)
            return rows

        # All remaining rows as a tuple of tuples.
        def fetchall(self):
            return tuple(self.cursor.iterate_tuple())

        def close(self):
            self.cursor = None

        # Required by PEP 249; intentionally no-ops.
        def setinputsizes(self, sizes):
            pass

        def setoutputsize(self, size, column=None):
            pass

    ##
    # DB-API connection.  A transaction is begun immediately and begun again
    # after every commit/rollback.
    class ConnectionWrapper(object):
        def __init__(self, **kwargs):
            self.conn = Connection(**kwargs)
            self.conn.begin()

        def cursor(self):
            # Must be qualified: names from an enclosing class body are not in
            # scope inside its methods.
            return DBAPI.CursorWrapper(self.conn)

        def commit(self):
            # There's a threading bug here.  If a query is sent after the
            # commit, but before the begin, it will be executed immediately
            # without a surrounding transaction.  Like all threading bugs -- it
            # sounds unlikely, until it happens every time in one
            # application...  however, to fix this, we need to lock the
            # database connection entirely, so that no cursors can execute
            # statements on other threads.  Support for that type of lock will
            # be done later.
            self.conn.commit()
            self.conn.begin()

        def rollback(self):
            # see bug description in commit.
            self.conn.rollback()
            self.conn.begin()

        def close(self):
            self.conn = None

    ##
    # Open a DB-API connection.  The arguments are those of
    # {@link Connection Connection}.
    def connect(user, host=None, unix_sock=None, port=5432, database=None, password=None, socket_timeout=60):
        return DBAPI.ConnectionWrapper(user=user, host=host, unix_sock=unix_sock,
                port=port, database=database, password=password,
                socket_timeout=socket_timeout)
    # A bare function in a class body would become an unbound method; make
    # it callable as DBAPI.connect(...).
    connect = staticmethod(connect)
166
167
168 ##
169 # This class represents a prepared statement.  A prepared statement is
170 # pre-parsed on the server, which reduces the need to parse the query every
171 # time it is run.  The statement can have parameters in the form of $1, $2, $3,
172 # etc.  When parameters are used, the types of the parameters need to be
173 # specified when creating the prepared statement.
174 # <p>
175 # As of v1.01, instances of this class are thread-safe.  This means that a
176 # single PreparedStatement can be accessed by multiple threads without the
177 # internal consistency of the statement being altered.  However, the
178 # responsibility is on the client application to ensure that one thread reading
179 # from a statement isn't affected by another thread starting a new query with
180 # the same statement.
181 # <p>
182 # Stability: Added in v1.00, stability guaranteed for v1.xx.
183 #
184 # @param connection     An instance of {@link Connection Connection}.
185 #
186 # @param statement      The SQL statement to be represented, often containing
187 # parameters in the form of $1, $2, $3, etc.
188 #
189 # @param types          Python type objects for each parameter in the SQL
190 # statement.  For example, int, float, str.
class PreparedStatement(object):

    ##
    # Determines the number of rows to read from the database server at once.
    # Reading more rows increases performance at the cost of memory.  The
    # default value is 100 rows.  The effect of this parameter is transparent:
    # the library automatically reads more rows whenever the cache is empty.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.  It is
    # possible that implementation changes in the future could cause this
    # parameter to be ignored.
    row_cache_size = 100

    def __init__(self, connection, statement, *types):
        self.c = connection.c
        # Server-side names derived from the ids of the protocol connection
        # and of this object.
        self._portal_name = "pg8000_portal_%s_%s" % (id(self.c), id(self))
        self._statement_name = "pg8000_statement_%s_%s" % (id(self.c), id(self))
        self._row_desc = None
        self._cached_rows = []
        self._command_complete = True
        self._parse_row_desc = self.c.parse(self._statement_name, statement, types)
        self._lock = threading.RLock()

    def __del__(self):
        # This __del__ should work with garbage collection / non-instant
        # cleanup.  It only really needs to be called right away if the same
        # object id (and therefore the same statement name) might be reused
        # soon, and clearly that wouldn't happen in a GC situation.
        #
        # _parse_row_desc exists only once the Parse call in __init__ has
        # succeeded.  If __init__ failed earlier there is no server-side
        # statement to close, and raising here would only print an
        # ignored-exception warning.
        if hasattr(self, "_parse_row_desc"):
            self.c.close_statement(self._statement_name)

    ##
    # Run the SQL prepared statement with the given parameters.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def execute(self, *args):
        self._lock.acquire()
        try:
            if not self._command_complete:
                # cleanup last execute
                self._cached_rows = []
                self.c.close_portal(self._portal_name)
            self._command_complete = False
            self._row_desc = self.c.bind(self._portal_name, self._statement_name, args, self._parse_row_desc)
            if self._row_desc:
                # We execute our cursor right away to fill up our cache.  This
                # prevents the cursor from being destroyed, apparently, by a rogue
                # Sync between Bind and Execute.  Since it is quite likely that
                # data will be read from us right away anyways, this seems a safe
                # move for now.
                self._fill_cache()
        finally:
            self._lock.release()

    # Fetch up to row_cache_size rows from the portal into the (empty) cache
    # and record whether the server reported the end of the data.
    def _fill_cache(self):
        self._lock.acquire()
        try:
            if self._cached_rows:
                raise InternalError("attempt to fill cache that isn't empty")
            end_of_data, rows = self.c.fetch_rows(self._portal_name, self.row_cache_size, self._row_desc)
            self._cached_rows = rows
            if end_of_data:
                self._command_complete = True
        finally:
            self._lock.release()

    # Pop the next row from the cache as a tuple, refilling the cache from
    # the server when needed; None once no rows remain.
    def _fetch(self):
        self._lock.acquire()
        try:
            if not self._cached_rows:
                if self._command_complete:
                    return None
                self._fill_cache()
                if self._command_complete and not self._cached_rows:
                    # fill cache tells us the command is complete, but yet we have
                    # no rows after filling our cache.  This is a special case when
                    # a query returns no rows.
                    return None
            row = self._cached_rows[0]
            del self._cached_rows[0]
            return tuple(row)
        finally:
            self._lock.release()

    ##
    # Read a row from the database server, and return it in a dictionary
    # indexed by column name/alias.  This method will raise an error if two
    # columns have the same name.  Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_dict(self):
        row = self._fetch()
        if row is None:
            return None
        retval = {}
        for i, field in enumerate(self._row_desc.fields):
            col_name = field['name']
            if col_name in retval:
                raise InterfaceError("cannot return dict of row when two columns have the same name (%r)" % (col_name,))
            retval[col_name] = row[i]
        return retval

    ##
    # Read a row from the database server, and return it as a tuple of values.
    # Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_tuple(self):
        return self._fetch()

    ##
    # Return an iterator for the output of this statement.  The iterator will
    # return a tuple for each row, in the same manner as {@link
    # #PreparedStatement.read_tuple read_tuple}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_tuple(self):
        return DataIterator(self, PreparedStatement.read_tuple)

    ##
    # Return an iterator for the output of this statement.  The iterator will
    # return a dict for each row, in the same manner as {@link
    # #PreparedStatement.read_dict read_dict}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_dict(self):
        return DataIterator(self, PreparedStatement.read_dict)
321
322 ##
323 # The Cursor class allows multiple queries to be performed concurrently with a
324 # single PostgreSQL connection.  The Cursor object is implemented internally by
325 # using a {@link PreparedStatement PreparedStatement} object, so if you plan to
326 # use a statement multiple times, you might as well create a PreparedStatement
327 # and save a small amount of reparsing time.
328 # <p>
329 # As of v1.01, instances of this class are thread-safe.  See {@link
330 # PreparedStatement PreparedStatement} for more information.
331 # <p>
332 # Stability: Added in v1.00, stability guaranteed for v1.xx.
333 #
334 # @param connection     An instance of {@link Connection Connection}.
class Cursor(object):
    def __init__(self, connection):
        self.connection = connection
        # PreparedStatement created by the most recent execute(), if any.
        self._stmt = None

    ##
    # Run an SQL statement using this cursor.  The SQL statement can have
    # parameters in the form of $1, $2, $3, etc., which will be filled in by
    # the additional arguments passed to this function.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    # @param query      The SQL statement to execute.
    def execute(self, query, *args):
        # Parameter types are inferred from the Python types of the arguments.
        self._stmt = PreparedStatement(self.connection, query, *[type(x) for x in args])
        self._stmt.execute(*args)

    # Return the statement from the last execute(), or raise
    # ProgrammingError if execute() has not been called yet.
    def _require_stmt(self):
        if self._stmt is None:
            raise ProgrammingError("attempting to read from unexecuted cursor")
        return self._stmt

    ##
    # Read a row from the database server, and return it in a dictionary
    # indexed by column name/alias.  This method will raise an error if two
    # columns have the same name.  Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_dict(self):
        return self._require_stmt().read_dict()

    ##
    # Read a row from the database server, and return it as a tuple of values.
    # Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_tuple(self):
        return self._require_stmt().read_tuple()

    ##
    # Return an iterator for the output of this statement.  The iterator will
    # return a tuple for each row, in the same manner as {@link
    # #PreparedStatement.read_tuple read_tuple}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_tuple(self):
        return self._require_stmt().iterate_tuple()

    ##
    # Return an iterator for the output of this statement.  The iterator will
    # return a dict for each row, in the same manner as {@link
    # #PreparedStatement.read_dict read_dict}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_dict(self):
        return self._require_stmt().iterate_dict()
393
394 ##
395 # This class represents a connection to a PostgreSQL database.
396 # <p>
397 # The database connection is derived from the {@link #Cursor Cursor} class,
398 # which provides a default cursor for running queries.  It also provides
399 # transaction control via the 'begin', 'commit', and 'rollback' methods.
400 # Without beginning a transaction explicitly, all statements will autocommit to
401 # the database.
402 # <p>
403 # As of v1.01, instances of this class are thread-safe.  See {@link
404 # PreparedStatement PreparedStatement} for more information.
405 # <p>
406 # Stability: Added in v1.00, stability guaranteed for v1.xx.
407 #
408 # @param user   The username to connect to the PostgreSQL server with.  This
409 # parameter is required.
410 #
411 # @keyparam host   The hostname of the PostgreSQL server to connect with.
412 # Providing this parameter is necessary for TCP/IP connections.  One of either
413 # host, or unix_sock, must be provided.
414 #
415 # @keyparam unix_sock   The path to the UNIX socket to access the database
416 # through, for example, '/tmp/.s.PGSQL.5432'.  One of either unix_sock or host
417 # must be provided.  The port parameter will have no effect if unix_sock is
418 # provided.
419 #
420 # @keyparam port   The TCP/IP port of the PostgreSQL server instance.  This
421 # parameter defaults to 5432, the registered and common port of PostgreSQL
422 # TCP/IP servers.
423 #
424 # @keyparam database   The name of the database instance to connect with.  This
425 # parameter is optional, if omitted the PostgreSQL server will assume the
426 # database name is the same as the username.
427 #
428 # @keyparam password   The user password to connect to the server with.  This
429 # parameter is optional.  If omitted, and the database server requests password
430 # based authentication, the connection will fail.  On the other hand, if this
431 # parameter is provided and the database does not request password
432 # authentication, then the password will not be used.
433 #
434 # @keyparam socket_timeout  Socket connect timeout measured in seconds.
435 # Defaults to 60 seconds.
436 class Connection(Cursor):
437     def __init__(self, user, host=None, unix_sock=None, port=5432, database=None, password=None, socket_timeout=60):
438         self._row_desc = None
439         try:
440             self.c = Protocol.Connection(unix_sock=unix_sock, host=host, port=port, socket_timeout=socket_timeout)
441             #self.c.connect()
442             self.c.authenticate(user, password=password, database=database)
443         except socket.error, e:
444             raise InterfaceError("communication error", e)
445         Cursor.__init__(self, self)
446         self._begin = PreparedStatement(self, "BEGIN TRANSACTION")
447         self._commit = PreparedStatement(self, "COMMIT TRANSACTION")
448         self._rollback = PreparedStatement(self, "ROLLBACK TRANSACTION")
449
450     ##
451     # Begins a new transaction.
452     # <p>
453     # Stability: Added in v1.00, stability guaranteed for v1.xx.
454     def begin(self):
455         self._begin.execute()
456
457     ##
458     # Commits the running transaction.
459     # <p>
460     # Stability: Added in v1.00, stability guaranteed for v1.xx.
461     def commit(self):
462         self._commit.execute()
463
464     ##
465     # Rolls back the running transaction.
466     # <p>
467     # Stability: Added in v1.00, stability guaranteed for v1.xx.
468     def rollback(self):
469         self._rollback.execute()
470
471
472 class Protocol(object):
473     class StartupMessage(object):
474         def __init__(self, user, database=None):
475             self.user = user
476             self.database = database
477
478         def serialize(self):
479             protocol = 196608
480             val = struct.pack("!i", protocol)
481             val += "user\x00" + self.user + "\x00"
482             if self.database:
483                 val += "database\x00" + self.database + "\x00"
484             val += "\x00"
485             val = struct.pack("!i", len(val) + 4) + val
486             return val
487
488     class Query(object):
489         def __init__(self, qs):
490             self.qs = qs
491
492         def serialize(self):
493             val = self.qs + "\x00"
494             val = struct.pack("!i", len(val) + 4) + val
495             val = "Q" + val
496             return val
497
498     class Parse(object):
499         def __init__(self, ps, qs, type_oids):
500             self.ps = ps
501             self.qs = qs
502             self.type_oids = type_oids
503
504         def serialize(self):
505             val = self.ps + "\x00" + self.qs + "\x00"
506             val = val + struct.pack("!h", len(self.type_oids))
507             for oid in self.type_oids:
508                 val = val + struct.pack("!i", oid)
509             val = struct.pack("!i", len(val) + 4) + val
510             val = "P" + val
511             return val
512
513     class Bind(object):
514         def __init__(self, portal, ps, in_fc, params, out_fc, client_encoding):
515             self.portal = portal
516             self.ps = ps
517             self.in_fc = in_fc
518             self.params = []
519             for i in range(len(params)):
520                 if len(self.in_fc) == 0:
521                     fc = 0
522                 elif len(self.in_fc) == 1:
523                     fc = self.in_fc[0]
524                 else:
525                     fc = self.in_fc[i]
526                 self.params.append(Types.pg_value(params[i], fc, client_encoding = client_encoding))
527             self.out_fc = out_fc
528
529         def serialize(self):
530             val = self.portal + "\x00" + self.ps + "\x00"
531             val = val + struct.pack("!h", len(self.in_fc))
532             for fc in self.in_fc:
533                 val = val + struct.pack("!h", fc)
534             val = val + struct.pack("!h", len(self.params))
535             for param in self.params:
536                 if param == None:
537                     # special case, NULL value
538                     val = val + struct.pack("!i", -1)
539                 else:
540                     val = val + struct.pack("!i", len(param)) + param
541             val = val + struct.pack("!h", len(self.out_fc))
542             for fc in self.out_fc:
543                 val = val + struct.pack("!h", fc)
544             val = struct.pack("!i", len(val) + 4) + val
545             val = "B" + val
546             return val
547
548     class Close(object):
549         def __init__(self, typ, name):
550             if len(typ) != 1:
551                 raise InternalError("Close typ must be 1 char")
552             self.typ = typ
553             self.name = name
554
555         def serialize(self):
556             val = self.typ + self.name + "\x00"
557             val = struct.pack("!i", len(val) + 4) + val
558             val = "C" + val
559             return val
560
561     class ClosePortal(Close):
562         def __init__(self, name):
563             Protocol.Close.__init__(self, "P", name)
564
565     class ClosePreparedStatement(Close):
566         def __init__(self, name):
567             Protocol.Close.__init__(self, "S", name)
568
569     class Describe(object):
570         def __init__(self, typ, name):
571             if len(typ) != 1:
572                 raise InternalError("Describe typ must be 1 char")
573             self.typ = typ
574             self.name = name
575
576         def serialize(self):
577             val = self.typ + self.name + "\x00"
578             val = struct.pack("!i", len(val) + 4) + val
579             val = "D" + val
580             return val
581
582     class DescribePortal(Describe):
583         def __init__(self, name):
584             Protocol.Describe.__init__(self, "P", name)
585
586     class DescribePreparedStatement(Describe):
587         def __init__(self, name):
588             Protocol.Describe.__init__(self, "S", name)
589
590     class Flush(object):
591         def serialize(self):
592             return 'H\x00\x00\x00\x04'
593
594     class Sync(object):
595         def serialize(self):
596             return 'S\x00\x00\x00\x04'
597
598     class PasswordMessage(object):
599         def __init__(self, pwd):
600             self.pwd = pwd
601
602         def serialize(self):
603             val = self.pwd + "\x00"
604             val = struct.pack("!i", len(val) + 4) + val
605             val = "p" + val
606             return val
607
608     class Execute(object):
609         def __init__(self, portal, row_count):
610             self.portal = portal
611             self.row_count = row_count
612
613         def serialize(self):
614             val = self.portal + "\x00" + struct.pack("!i", self.row_count)
615             val = struct.pack("!i", len(val) + 4) + val
616             val = "E" + val
617             return val
618
619     class AuthenticationRequest(object):
620         def __init__(self, data):
621             pass
622
623         def createFromData(data):
624             ident = struct.unpack("!i", data[:4])[0]
625             klass = Protocol.authentication_codes.get(ident, None)
626             if klass != None:
627                 return klass(data[4:])
628             else:
629                 raise NotSupportedError("authentication method %r not supported" % (ident,))
630         createFromData = staticmethod(createFromData)
631
632         def ok(self, conn, user, **kwargs):
633             raise InternalError("ok method should be overridden on AuthenticationRequest instance")
634
635     class AuthenticationOk(AuthenticationRequest):
636         def ok(self, conn, user, **kwargs):
637             return True
638
639     class AuthenticationMD5Password(AuthenticationRequest):
640         def __init__(self, data):
641             self.salt = "".join(struct.unpack("4c", data))
642
643         def ok(self, conn, user, password=None, **kwargs):
644             if password == None:
645                 raise InterfaceError("server requesting MD5 password authentication, but no password was provided")
646             pwd = "md5" + md5.new(md5.new(password + user).hexdigest() + self.salt).hexdigest()
647             conn._send(Protocol.PasswordMessage(pwd))
648             msg = conn._read_message()
649             if isinstance(msg, Protocol.AuthenticationRequest):
650                 return msg.ok(conn, user)
651             elif isinstance(msg, Protocol.ErrorResponse):
652                 if msg.code == "28000":
653                     raise InterfaceError("md5 password authentication failed")
654                 else:
655                     raise InternalError("server returned unexpected error %r" % msg)
656             else:
657                 raise InternalError("server returned unexpected response %r" % msg)
658
659     authentication_codes = {
660         0: AuthenticationOk,
661         5: AuthenticationMD5Password,
662     }
663
664     class ParameterStatus(object):
665         def __init__(self, key, value):
666             self.key = key
667             self.value = value
668
669         def createFromData(data):
670             key = data[:data.find("\x00")]
671             value = data[data.find("\x00")+1:-1]
672             return Protocol.ParameterStatus(key, value)
673         createFromData = staticmethod(createFromData)
674
675     class BackendKeyData(object):
676         def __init__(self, process_id, secret_key):
677             self.process_id = process_id
678             self.secret_key = secret_key
679
680         def createFromData(data):
681             process_id, secret_key = struct.unpack("!2i", data)
682             return Protocol.BackendKeyData(process_id, secret_key)
683         createFromData = staticmethod(createFromData)
684
685     class NoData(object):
686         def createFromData(data):
687             return Protocol.NoData()
688         createFromData = staticmethod(createFromData)
689
690     class ParseComplete(object):
691         def createFromData(data):
692             return Protocol.ParseComplete()
693         createFromData = staticmethod(createFromData)
694
695     class BindComplete(object):
696         def createFromData(data):
697             return Protocol.BindComplete()
698         createFromData = staticmethod(createFromData)
699
700     class CloseComplete(object):
701         def createFromData(data):
702             return Protocol.CloseComplete()
703         createFromData = staticmethod(createFromData)
704
705     class PortalSuspended(object):
706         def createFromData(data):
707             return Protocol.PortalSuspended()
708         createFromData = staticmethod(createFromData)
709
710     class ReadyForQuery(object):
711         def __init__(self, status):
712             self.status = status
713
714         def __repr__(self):
715             return "<ReadyForQuery %s>" % \
716                     {"I": "Idle", "T": "Idle in Transaction", "E": "Idle in Failed Transaction"}[self.status]
717
718         def createFromData(data):
719             return Protocol.ReadyForQuery(data)
720         createFromData = staticmethod(createFromData)
721
722     class NoticeResponse(object):
723         def __init__(self):
724             pass
725         def createFromData(data):
726             # we could read the notice here, but we don't care yet.
727             return Protocol.NoticeResponse()
728         createFromData = staticmethod(createFromData)
729
730     class ErrorResponse(object):
731         def __init__(self, severity, code, msg):
732             self.severity = severity
733             self.code = code
734             self.msg = msg
735
736         def __repr__(self):
737             return "<ErrorResponse %s %s %r>" % (self.severity, self.code, self.msg)
738
739         def createException(self):
740             return ProgrammingError(self.severity, self.code, self.msg)
741
742         def createFromData(data):
743             args = {}
744             for s in data.split("\x00"):
745                 if not s:
746                     continue
747                 elif s[0] == "S":
748                     args["severity"] = s[1:]
749                 elif s[0] == "C":
750                     args["code"] = s[1:]
751                 elif s[0] == "M":
752                     args["msg"] = s[1:]
753             return Protocol.ErrorResponse(**args)
754         createFromData = staticmethod(createFromData)
755
756     class ParameterDescription(object):
757         def __init__(self, type_oids):
758             self.type_oids = type_oids
759         def createFromData(data):
760             count = struct.unpack("!h", data[:2])[0]
761             type_oids = struct.unpack("!" + "i"*count, data[2:])
762             return Protocol.ParameterDescription(type_oids)
763         createFromData = staticmethod(createFromData)
764
765     class RowDescription(object):
766         def __init__(self, fields):
767             self.fields = fields
768
769         def createFromData(data):
770             count = struct.unpack("!h", data[:2])[0]
771             data = data[2:]
772             fields = []
773             for i in range(count):
774                 null = data.find("\x00")
775                 field = {"name": data[:null]}
776                 data = data[null+1:]
777                 field["table_oid"], field["column_attrnum"], field["type_oid"], field["type_size"], field["type_modifier"], field["format"] = struct.unpack("!ihihih", data[:18])
778                 data = data[18:]
779                 fields.append(field)
780             return Protocol.RowDescription(fields)
781         createFromData = staticmethod(createFromData)
782
783     class CommandComplete(object):
784         def __init__(self, tag):
785             self.tag = tag
786
787         def createFromData(data):
788             return Protocol.CommandComplete(data[:-1])
789         createFromData = staticmethod(createFromData)
790
791     class DataRow(object):
792         def __init__(self, fields):
793             self.fields = fields
794
795         def createFromData(data):
796             count = struct.unpack("!h", data[:2])[0]
797             data = data[2:]
798             fields = []
799             for i in range(count):
800                 val_len = struct.unpack("!i", data[:4])[0]
801                 data = data[4:]
802                 if val_len == -1:
803                     fields.append(None)
804                 else:
805                     fields.append(data[:val_len])
806                     data = data[val_len:]
807             return Protocol.DataRow(fields)
808         createFromData = staticmethod(createFromData)
809
810     class Connection(object):
811         def __init__(self, unix_sock=None, host=None, port=5432, socket_timeout=60):
812             self._client_encoding = "ascii"
813             if unix_sock == None and host != None:
814                 self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
815             elif unix_sock != None:
816                 self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
817             else:
818                 raise ProgrammingError("one of host or unix_sock must be provided")
819             self._sock.settimeout(socket_timeout)
820             if unix_sock == None and host != None:
821                 self._sock.connect((host, port))
822             elif unix_sock != None:
823                 self._sock.connect(unix_sock)
824             self._state = "noauth"
825             self._backend_key_data = None
826             self._sock_lock = threading.Lock()
827
828         def verifyState(self, state):
829             if self._state != state:
830                 raise InternalError("connection state must be %s, is %s" % (state, self._state))
831
832         def _send(self, msg):
833             data = msg.serialize()
834             self._sock.send(data)
835
836         def _read_message(self):
837             bytes = ""
838             while len(bytes) < 5:
839                 tmp = self._sock.recv(5 - len(bytes))
840                 bytes += tmp
841             if len(bytes) != 5:
842                 raise InternalError("unable to read 5 bytes from socket %r" % bytes)
843             message_code = bytes[0]
844             data_len = struct.unpack("!i", bytes[1:])[0] - 4
845             if data_len == 0:
846                 bytes = ""
847             else:
848                 bytes = ""
849                 while len(bytes) < data_len:
850                     tmp = self._sock.recv(data_len - len(bytes))
851                     bytes += tmp
852             assert len(bytes) == data_len
853             msg = Protocol.message_types[message_code].createFromData(bytes)
854             if isinstance(msg, Protocol.NoticeResponse):
855                 # ignore NoticeResponse
856                 return self._read_message()
857             else:
858                 return msg
859
860         def authenticate(self, user, **kwargs):
861             self.verifyState("noauth")
862             self._send(Protocol.StartupMessage(user, database=kwargs.get("database",None)))
863             msg = self._read_message()
864             if isinstance(msg, Protocol.AuthenticationRequest):
865                 if msg.ok(self, user, **kwargs):
866                     self._state = "auth"
867                     while 1:
868                         msg = self._read_message()
869                         if isinstance(msg, Protocol.ReadyForQuery):
870                             # done reading messages
871                             self._state = "ready"
872                             break
873                         elif isinstance(msg, Protocol.ParameterStatus):
874                             if msg.key == "client_encoding":
875                                 self._client_encoding = msg.value
876                         elif isinstance(msg, Protocol.BackendKeyData):
877                             self._backend_key_data = msg
878                         elif isinstance(msg, Protocol.ErrorResponse):
879                             raise msg.createException()
880                         else:
881                             raise InternalError("unexpected msg %r" % msg)
882                 else:
883                     raise InterfaceError("authentication method %s failed" % msg.__class__.__name__)
884             else:
885                 raise InternalError("StartupMessage was responded to with non-AuthenticationRequest msg %r" % msg)
886
887         def parse(self, statement, qs, types):
888             self.verifyState("ready")
889             self._sock_lock.acquire()
890             try:
891                 type_info = [Types.pg_type_info(x) for x in types]
892                 param_types, param_fc = [x[0] for x in type_info], [x[1] for x in type_info] # zip(*type_info) -- fails on empty arr
893                 self._send(Protocol.Parse(statement, qs, param_types))
894                 self._send(Protocol.DescribePreparedStatement(statement))
895                 self._send(Protocol.Flush())
896                 while 1:
897                     msg = self._read_message()
898                     if isinstance(msg, Protocol.ParseComplete):
899                         # ok, good.
900                         pass
901                     elif isinstance(msg, Protocol.ParameterDescription):
902                         # well, we don't really care -- we're going to send whatever
903                         # we want and let the database deal with it.  But thanks
904                         # anyways!
905                         pass
906                     elif isinstance(msg, Protocol.NoData):
907                         # We're not waiting for a row description.  Return
908                         # something destinctive to let bind know that there is no
909                         # output.
910                         return (None, param_fc)
911                     elif isinstance(msg, Protocol.RowDescription):
912                         return (msg, param_fc)
913                     elif isinstance(msg, Protocol.ErrorResponse):
914                         raise msg.createException()
915                     else:
916                         raise InternalError("Unexpected response msg %r" % (msg))
917             finally:
918                 self._sock_lock.release()
919
920         def bind(self, portal, statement, params, parse_data):
921             self.verifyState("ready")
922             self._sock_lock.acquire()
923             try:
924                 row_desc, param_fc = parse_data
925                 if row_desc == None:
926                     # no data coming out
927                     output_fc = ()
928                 else:
929                     # We've got row_desc that allows us to identify what we're going to
930                     # get back from this statement.
931                     output_fc = [Types.py_type_info(f) for f in row_desc.fields]
932                 self._send(Protocol.Bind(portal, statement, param_fc, params, output_fc, self._client_encoding))
933                 # We need to describe the portal after bind, since the return
934                 # format codes will be different (hopefully, always what we
935                 # requested).
936                 self._send(Protocol.DescribePortal(portal))
937                 self._send(Protocol.Flush())
938                 while 1:
939                     msg = self._read_message()
940                     if isinstance(msg, Protocol.BindComplete):
941                         # good news everybody!
942                         pass
943                     elif isinstance(msg, Protocol.NoData):
944                         # No data means we should execute this command right away.
945                         self._send(Protocol.Execute(portal, 0))
946                         self._send(Protocol.Sync())
947                         exc = None
948                         while 1:
949                             msg = self._read_message()
950                             if isinstance(msg, Protocol.CommandComplete):
951                                 # more good news!
952                                 pass
953                             elif isinstance(msg, Protocol.ReadyForQuery):
954                                 if exc != None:
955                                     raise exc
956                                 break
957                             elif isinstance(msg, Protocol.ErrorResponse):
958                                 exc = msg.createException()
959                             else:
960                                 raise InternalError("unexpected response")
961                         return None
962                     elif isinstance(msg, Protocol.RowDescription):
963                         # Return the new row desc, since it will have the format
964                         # types we asked for
965                         return msg
966                     elif isinstance(msg, Protocol.ErrorResponse):
967                         raise msg.createException()
968                     else:
969                         raise InternalError("Unexpected response msg %r" % (msg))
970             finally:
971                 self._sock_lock.release()
972
973         def fetch_rows(self, portal, row_count, row_desc):
974             self.verifyState("ready")
975             self._sock_lock.acquire()
976             try:
977                 self._send(Protocol.Execute(portal, row_count))
978                 self._send(Protocol.Flush())
979                 rows = []
980                 end_of_data = False
981                 while 1:
982                     msg = self._read_message()
983                     if isinstance(msg, Protocol.DataRow):
984                         rows.append(
985                                 [Types.py_value(msg.fields[i], row_desc.fields[i], client_encoding=self._client_encoding)
986                                     for i in range(len(msg.fields))]
987                                 )
988                     elif isinstance(msg, Protocol.PortalSuspended):
989                         # got all the rows we asked for, but not all that exist
990                         break
991                     elif isinstance(msg, Protocol.CommandComplete):
992                         self._send(Protocol.ClosePortal(portal))
993                         self._send(Protocol.Sync())
994                         while 1:
995                             msg = self._read_message()
996                             if isinstance(msg, Protocol.ReadyForQuery):
997                                 # ready to move on with life...
998                                 self._state = "ready"
999                                 break
1000                             elif isinstance(msg, Protocol.CloseComplete):
1001                                 # ok, great!
1002                                 pass
1003                             elif isinstance(msg, Protocol.ErrorResponse):
1004                                 raise msg.createException()
1005                             else:
1006                                 raise InternalError("unexpected response msg %r" % msg)
1007                         end_of_data = True
1008                         break
1009                     elif isinstance(msg, Protocol.ErrorResponse):
1010                         raise msg.createException()
1011                     else:
1012                         raise InternalError("Unexpected response msg %r" % msg)
1013                 return end_of_data, rows
1014             finally:
1015                 self._sock_lock.release()
1016
1017         def close_statement(self, statement):