root/pg8000/pg8000-v1.00/pg8000.py

Revision 816, 43.7 kB (checked in by mfenniak, 2 years ago)

Add support for NULL DB parameters and values

Line 
1 # vim: sw=4:expandtab:foldmethod=marker
2 #
3 # Copyright (c) 2007, Mathieu Fenniak
4 # All rights reserved.
5 #
6 # Redistribution and use in source and binary forms, with or without
7 # modification, are permitted provided that the following conditions are
8 # met:
9 #
10 # * Redistributions of source code must retain the above copyright notice,
11 # this list of conditions and the following disclaimer.
12 # * Redistributions in binary form must reproduce the above copyright notice,
13 # this list of conditions and the following disclaimer in the documentation
14 # and/or other materials provided with the distribution.
15 # * The name of the author may not be used to endorse or promote products
16 # derived from this software without specific prior written permission.
17 #
18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
22 # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28 # POSSIBILITY OF SUCH DAMAGE.
29
30 __author__ = "Mathieu Fenniak"
31
32 import socket
33 import struct
34 import datetime
35 import md5
36 import decimal
37
# Exception hierarchy.  Class names and nesting mirror the DB-API 2.0
# (PEP 249) hierarchy: Warning and Error derive from StandardError, and the
# database-specific errors derive from DatabaseError.
class Warning(StandardError):
    """DB-API Warning category."""

class Error(StandardError):
    """Common base class of all the error classes below."""

class InterfaceError(Error):
    """Client-side interface problems, e.g. communication or authentication
    failures, or a row that cannot be returned as a dict."""

class DatabaseError(Error):
    """Base class of the database-related error categories."""

class DataError(DatabaseError):
    """DB-API DataError category."""

class OperationalError(DatabaseError):
    """DB-API OperationalError category."""

class IntegrityError(DatabaseError):
    """DB-API IntegrityError category."""

class InternalError(DatabaseError):
    """Client and server out of step, e.g. an unexpected protocol message or
    connection state."""

class ProgrammingError(DatabaseError):
    """Server-reported errors (ErrorResponse) and API misuse such as reading
    from a cursor that was never executed."""

class NotSupportedError(DatabaseError):
    """Unsupported features, e.g. an unknown authentication method."""
67
68
##
# Iterator adaptor: each step calls func(obj) and yields the result, stopping
# as soon as func returns None.
class DataIterator(object):
    def __init__(self, obj, func):
        self.obj = obj
        self.func = func

    def __iter__(self):
        return self

    def next(self):
        retval = self.func(self.obj)
        # None is the end-of-data sentinel.  Compare by identity so that
        # falsy rows (e.g. an empty tuple) are still returned.
        if retval is None:
            raise StopIteration()
        return retval

    # Python 3 name of the iterator protocol method; harmless on Python 2.
    __next__ = next
82
##
# This class represents a prepared statement.  A prepared statement is
# pre-parsed on the server, which reduces the need to parse the query every
# time it is run.  The statement can have parameters in the form of $1, $2, $3,
# etc.  When parameters are used, the types of the parameters need to be
# specified when creating the prepared statement.
# <p>
# Stability: Added in v1.00, stability guaranteed for v1.xx.
#
# @param connection     An instance of {@link Connection Connection}.
#
# @param statement      The SQL statement to be represented, often containing
# parameters in the form of $1, $2, $3, etc.
#
# @param types          Python type objects for each parameter in the SQL
# statement.  For example, int, float, str.
class PreparedStatement(object):

    ##
    # Determines the number of rows to read from the database server at once.
    # Reading more rows increases performance at the cost of memory.  The
    # default value is 100 rows.  The effect of this parameter is transparent:
    # the library automatically reads more rows whenever the cache is empty.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.  It is
    # possible that implementation changes in the future could cause this
    # parameter to be ignored.
    row_cache_size = 100

    def __init__(self, connection, statement, *types):
        self.c = connection.c
        # Server-side names embed the ids of the protocol connection and of
        # this object, so live statements/portals get distinct names.
        self._portal_name = "pg8000_portal_%s_%s" % (id(self.c), id(self))
        self._statement_name = "pg8000_statement_%s_%s" % (id(self.c), id(self))
        self._row_desc = None
        self._cached_rows = []
        self._command_complete = True
        # (row description or None, parameter format codes) as returned by
        # the protocol connection's parse(); handed to bind() on execute.
        self._parse_row_desc = self.c.parse(self._statement_name, statement, types)

    def __del__(self):
        # This __del__ should work with garbage collection / non-instant
        # cleanup.  It only really needs to be called right away if the same
        # object id (and therefore the same statement name) might be reused
        # soon, and clearly that wouldn't happen in a GC situation.
        self.c.close_statement(self._statement_name)

    ##
    # Run the SQL prepared statement with the given parameters.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def execute(self, *args):
        if not self._command_complete:
            # The previous execution was not read to the end: discard its
            # buffered rows and close its portal before binding a new one.
            self._cached_rows = []
            self.c.close_portal(self._portal_name)
        self._command_complete = False
        self._row_desc = self.c.bind(self._portal_name, self._statement_name, args, self._parse_row_desc)
        if self._row_desc:
            # We execute our cursor right away to fill up our cache.  This
            # prevents the cursor from being destroyed, apparently, by a rogue
            # Sync between Bind and Execute.  Since it is quite likely that
            # data will be read from us right away anyways, this seems a safe
            # move for now.
            self._fill_cache()

    # Fetch up to row_cache_size rows from the open portal into the cache and
    # record whether the server reported the end of the result set.
    def _fill_cache(self):
        if self._cached_rows:
            raise InternalError("attempt to fill cache that isn't empty")
        end_of_data, rows = self.c.fetch_rows(self._portal_name, self.row_cache_size, self._row_desc)
        self._cached_rows = rows
        if end_of_data:
            self._command_complete = True

    # Return the next row as a tuple, refilling the cache when it runs dry;
    # returns None once the result set is exhausted.
    def _fetch(self):
        if not self._cached_rows:
            if self._command_complete:
                return None
            self._fill_cache()
            if self._command_complete and not self._cached_rows:
                # fill cache tells us the command is complete, but yet we have
                # no rows after filling our cache.  This is a special case when
                # a query returns no rows.
                return None
        return tuple(self._cached_rows.pop(0))

    ##
    # Read a row from the database server, and return it in a dictionary
    # indexed by column name/alias.  This method will raise an error if two
    # columns have the same name.  Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_dict(self):
        row = self._fetch()
        if row is None:
            return None
        retval = {}
        for i, field in enumerate(self._row_desc.fields):
            col_name = field['name']
            if col_name in retval:
                raise InterfaceError("cannot return dict of row when two columns have the same name (%r)" % (col_name,))
            retval[col_name] = row[i]
        return retval

    ##
    # Read a row from the database server, and return it as a tuple of values.
    # Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_tuple(self):
        return self._fetch()

    ##
    # Return an iterator for the output of this statement.  The iterator will
    # return a tuple for each row, in the same manner as {@link
    # #PreparedStatement.read_tuple read_tuple}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_tuple(self):
        return DataIterator(self, PreparedStatement.read_tuple)

    ##
    # Return an iterator for the output of this statement.  The iterator will
    # return a dict for each row, in the same manner as {@link
    # #PreparedStatement.read_dict read_dict}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_dict(self):
        return DataIterator(self, PreparedStatement.read_dict)
216
##
# The Cursor class allows multiple queries to be performed concurrently with a
# single PostgreSQL connection.  The Cursor object is implemented internally by
# using a {@link PreparedStatement PreparedStatement} object, so if you plan to
# use a statement multiple times, you might as well create a PreparedStatement
# and save a small amount of reparsing time.
# <p>
# Stability: Added in v1.00, stability guaranteed for v1.xx.
#
# @param connection     An instance of {@link Connection Connection}.
class Cursor(object):
    def __init__(self, connection):
        self.connection = connection
        self._stmt = None

    ##
    # Run an SQL statement using this cursor.  The SQL statement can have
    # parameters in the form of $1, $2, $3, etc., which will be filled in by
    # the additional arguments passed to this function.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    # @param query      The SQL statement to execute.
    def execute(self, query, *args):
        # Parameter types are taken from the Python types of the arguments.
        self._stmt = PreparedStatement(self.connection, query, *[type(x) for x in args])
        self._stmt.execute(*args)

    # Return the statement of the most recent execute(); raise
    # ProgrammingError if this cursor has never been executed.
    def _executed_stmt(self):
        if self._stmt is None:
            raise ProgrammingError("attempting to read from unexecuted cursor")
        return self._stmt

    ##
    # Read a row from the database server, and return it in a dictionary
    # indexed by column name/alias.  This method will raise an error if two
    # columns have the same name.  Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_dict(self):
        return self._executed_stmt().read_dict()

    ##
    # Read a row from the database server, and return it as a tuple of values.
    # Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_tuple(self):
        return self._executed_stmt().read_tuple()

    ##
    # Return an iterator for the output of this statement.  The iterator will
    # return a tuple for each row, in the same manner as {@link
    # #PreparedStatement.read_tuple read_tuple}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_tuple(self):
        return self._executed_stmt().iterate_tuple()

    ##
    # Return an iterator for the output of this statement.  The iterator will
    # return a dict for each row, in the same manner as {@link
    # #PreparedStatement.read_dict read_dict}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_dict(self):
        return self._executed_stmt().iterate_dict()
285
286 ##
287 # This class represents a connection to a PostgreSQL database.
288 # <p>
289 # The database connection is derived from the {@link #Cursor Cursor} class,
290 # which provides a default cursor for running queries.  It also provides
291 # transaction control via the 'begin', 'commit', and 'rollback' methods.
292 # Without beginning a transaction explicitly, all statements will autocommit to
293 # the database.
294 # <p>
295 # Stability: Added in v1.00, stability guaranteed for v1.xx.
296 #
297 # @param host   The hostname of the PostgreSQL server to connect with.  Only
298 # TCP/IP connections are presently supported, so this parameter is mandatory.
299 #
300 # @param user   The username to connect to the PostgreSQL server with.  This
301 # parameter is mandatory.
302 #
303 # @keyparam port   The TCP/IP port of the PostgreSQL server instance.  This
304 # parameter defaults to 5432, the registered and common port of PostgreSQL
305 # TCP/IP servers.
306 #
307 # @keyparam database   The name of the database instance to connect with.  This
308 # parameter is optional, if omitted the PostgreSQL server will assume the
309 # database name is the same as the username.
310 #
311 # @keyparam password   The user password to connect to the server with.  This
312 # parameter is optional.  If omitted, and the database server requests password
313 # based authentication, the connection will fail.  On the other hand, if this
314 # parameter is provided and the database does not request password
315 # authentication, then the password will not be used.
316 #
317 # @keyparam socket_timeout  Socket connect timeout measured in seconds.
318 # Defaults to 60 seconds.
319 class Connection(Cursor):
320     def __init__(self, host, user, port=5432, database=None, password=None, socket_timeout=60):
321         self._row_desc = None
322         try:
323             self.c = Protocol.Connection(host, port, socket_timeout=socket_timeout)
324             self.c.connect()
325             self.c.authenticate(user, password=password, database=database)
326         except socket.error, e:
327             raise InterfaceError("communication error", e)
328         Cursor.__init__(self, self)
329         self._begin = PreparedStatement(self, "BEGIN TRANSACTION")
330         self._commit = PreparedStatement(self, "COMMIT TRANSACTION")
331         self._rollback = PreparedStatement(self, "ROLLBACK TRANSACTION")
332
333     ##
334     # Begins a new transaction.
335     # <p>
336     # Stability: Added in v1.00, stability guaranteed for v1.xx.
337     def begin(self):
338         self._begin.execute()
339
340     ##
341     # Commits the running transaction.
342     # <p>
343     # Stability: Added in v1.00, stability guaranteed for v1.xx.
344     def commit(self):
345         self._commit.execute()
346
347     ##
348     # Rolls back the running transaction.
349     # <p>
350     # Stability: Added in v1.00, stability guaranteed for v1.xx.
351     def rollback(self):
352         self._rollback.execute()
353
354
355 class Protocol(object):
356     class StartupMessage(object):
357         def __init__(self, user, database=None):
358             self.user = user
359             self.database = database
360
361         def serialize(self):
362             protocol = 196608
363             val = struct.pack("!i", protocol)
364             val += "user\x00" + self.user + "\x00"
365             if self.database:
366                 val += "database\x00" + self.database + "\x00"
367             val += "\x00"
368             val = struct.pack("!i", len(val) + 4) + val
369             return val
370
371     class Query(object):
372         def __init__(self, qs):
373             self.qs = qs
374
375         def serialize(self):
376             val = self.qs + "\x00"
377             val = struct.pack("!i", len(val) + 4) + val
378             val = "Q" + val
379             return val
380
381     class Parse(object):
382         def __init__(self, ps, qs, type_oids):
383             self.ps = ps
384             self.qs = qs
385             self.type_oids = type_oids
386
387         def serialize(self):
388             val = self.ps + "\x00" + self.qs + "\x00"
389             val = val + struct.pack("!h", len(self.type_oids))
390             for oid in self.type_oids:
391                 val = val + struct.pack("!i", oid)
392             val = struct.pack("!i", len(val) + 4) + val
393             val = "P" + val
394             return val
395
396     class Bind(object):
397         def __init__(self, portal, ps, in_fc, params, out_fc, client_encoding):
398             self.portal = portal
399             self.ps = ps
400             self.in_fc = in_fc
401             self.params = []
402             for i in range(len(params)):
403                 if len(self.in_fc) == 0:
404                     fc = 0
405                 elif len(self.in_fc) == 1:
406                     fc = self.in_fc[0]
407                 else:
408                     fc = self.in_fc[i]
409                 self.params.append(Types.pg_value(params[i], fc, client_encoding = client_encoding))
410             self.out_fc = out_fc
411
412         def serialize(self):
413             val = self.portal + "\x00" + self.ps + "\x00"
414             val = val + struct.pack("!h", len(self.in_fc))
415             for fc in self.in_fc:
416                 val = val + struct.pack("!h", fc)
417             val = val + struct.pack("!h", len(self.params))
418             for param in self.params:
419                 if param == None:
420                     # special case, NULL value
421                     val = val + struct.pack("!i", -1)
422                 else:
423                     val = val + struct.pack("!i", len(param)) + param
424             val = val + struct.pack("!h", len(self.out_fc))
425             for fc in self.out_fc:
426                 val = val + struct.pack("!h", fc)
427             val = struct.pack("!i", len(val) + 4) + val
428             val = "B" + val
429             return val
430
431     class Close(object):
432         def __init__(self, typ, name):
433             if len(typ) != 1:
434                 raise InternalError("Close typ must be 1 char")
435             self.typ = typ
436             self.name = name
437
438         def serialize(self):
439             val = self.typ + self.name + "\x00"
440             val = struct.pack("!i", len(val) + 4) + val
441             val = "C" + val
442             return val
443
444     class ClosePortal(Close):
445         def __init__(self, name):
446             Protocol.Close.__init__(self, "P", name)
447
448     class ClosePreparedStatement(Close):
449         def __init__(self, name):
450             Protocol.Close.__init__(self, "S", name)
451
452     class Describe(object):
453         def __init__(self, typ, name):
454             if len(typ) != 1:
455                 raise InternalError("Describe typ must be 1 char")
456             self.typ = typ
457             self.name = name
458
459         def serialize(self):
460             val = self.typ + self.name + "\x00"
461             val = struct.pack("!i", len(val) + 4) + val
462             val = "D" + val
463             return val
464
465     class DescribePortal(Describe):
466         def __init__(self, name):
467             Protocol.Describe.__init__(self, "P", name)
468
469     class DescribePreparedStatement(Describe):
470         def __init__(self, name):
471             Protocol.Describe.__init__(self, "S", name)
472
473     class Flush(object):
474         def serialize(self):
475             return 'H\x00\x00\x00\x04'
476
477     class Sync(object):
478         def serialize(self):
479             return 'S\x00\x00\x00\x04'
480
481     class PasswordMessage(object):
482         def __init__(self, pwd):
483             self.pwd = pwd
484
485         def serialize(self):
486             val = self.pwd + "\x00"
487             val = struct.pack("!i", len(val) + 4) + val
488             val = "p" + val
489             return val
490
491     class Execute(object):
492         def __init__(self, portal, row_count):
493             self.portal = portal
494             self.row_count = row_count
495
496         def serialize(self):
497             val = self.portal + "\x00" + struct.pack("!i", self.row_count)
498             val = struct.pack("!i", len(val) + 4) + val
499             val = "E" + val
500             return val
501
502     class AuthenticationRequest(object):
503         def __init__(self, data):
504             pass
505
506         def createFromData(data):
507             ident = struct.unpack("!i", data[:4])[0]
508             klass = Protocol.authentication_codes.get(ident, None)
509             if klass != None:
510                 return klass(data[4:])
511             else:
512                 raise NotSupportedError("authentication method %r not supported" % (ident,))
513         createFromData = staticmethod(createFromData)
514
515         def ok(self, conn, user, **kwargs):
516             raise InternalError("ok method should be overridden on AuthenticationRequest instance")
517
518     class AuthenticationOk(AuthenticationRequest):
519         def ok(self, conn, user, **kwargs):
520             return True
521
522     class AuthenticationMD5Password(AuthenticationRequest):
523         def __init__(self, data):
524             self.salt = "".join(struct.unpack("4c", data))
525
526         def ok(self, conn, user, password=None, **kwargs):
527             if password == None:
528                 raise InterfaceError("server requesting MD5 password authentication, but no password was provided")
529             pwd = "md5" + md5.new(md5.new(password + user).hexdigest() + self.salt).hexdigest()
530             conn._send(Protocol.PasswordMessage(pwd))
531             msg = conn._read_message()
532             if isinstance(msg, Protocol.AuthenticationRequest):
533                 return msg.ok(conn, user)
534             elif isinstance(msg, Protocol.ErrorResponse):
535                 if msg.code == "28000":
536                     raise InterfaceError("md5 password authentication failed")
537                 else:
538                     raise InternalError("server returned unexpected error %r" % msg)
539             else:
540                 raise InternalError("server returned unexpected response %r" % msg)
541
542     authentication_codes = {
543         0: AuthenticationOk,
544         5: AuthenticationMD5Password,
545     }
546
547     class ParameterStatus(object):
548         def __init__(self, key, value):
549             self.key = key
550             self.value = value
551
552         def createFromData(data):
553             key = data[:data.find("\x00")]
554             value = data[data.find("\x00")+1:-1]
555             return Protocol.ParameterStatus(key, value)
556         createFromData = staticmethod(createFromData)
557
558     class BackendKeyData(object):
559         def __init__(self, process_id, secret_key):
560             self.process_id = process_id
561             self.secret_key = secret_key
562
563         def createFromData(data):
564             process_id, secret_key = struct.unpack("!2i", data)
565             return Protocol.BackendKeyData(process_id, secret_key)
566         createFromData = staticmethod(createFromData)
567
568     class NoData(object):
569         def createFromData(data):
570             return Protocol.NoData()
571         createFromData = staticmethod(createFromData)
572
573     class ParseComplete(object):
574         def createFromData(data):
575             return Protocol.ParseComplete()
576         createFromData = staticmethod(createFromData)
577
578     class BindComplete(object):
579         def createFromData(data):
580             return Protocol.BindComplete()
581         createFromData = staticmethod(createFromData)
582
583     class CloseComplete(object):
584         def createFromData(data):
585             return Protocol.CloseComplete()
586         createFromData = staticmethod(createFromData)
587
588     class PortalSuspended(object):
589         def createFromData(data):
590             return Protocol.PortalSuspended()
591         createFromData = staticmethod(createFromData)
592
593     class ReadyForQuery(object):
594         def __init__(self, status):
595             self.status = status
596
597         def __repr__(self):
598             return "<ReadyForQuery %s>" % \
599                     {"I": "Idle", "T": "Idle in Transaction", "E": "Idle in Failed Transaction"}[self.status]
600
601         def createFromData(data):
602             return Protocol.ReadyForQuery(data)
603         createFromData = staticmethod(createFromData)
604
605     class NoticeResponse(object):
606         def __init__(self):
607             pass
608         def createFromData(data):
609             # we could read the notice here, but we don't care yet.
610             return Protocol.NoticeResponse()
611         createFromData = staticmethod(createFromData)
612
613     class ErrorResponse(object):
614         def __init__(self, severity, code, msg):
615             self.severity = severity
616             self.code = code
617             self.msg = msg
618
619         def __repr__(self):
620             return "<ErrorResponse %s %s %r>" % (self.severity, self.code, self.msg)
621
622         def createException(self):
623             return ProgrammingError(self.severity, self.code, self.msg)
624
625         def createFromData(data):
626             args = {}
627             for s in data.split("\x00"):
628                 if not s:
629                     continue
630                 elif s[0] == "S":
631                     args["severity"] = s[1:]
632                 elif s[0] == "C":
633                     args["code"] = s[1:]
634                 elif s[0] == "M":
635                     args["msg"] = s[1:]
636             return Protocol.ErrorResponse(**args)
637         createFromData = staticmethod(createFromData)
638
639     class ParameterDescription(object):
640         def __init__(self, type_oids):
641             self.type_oids = type_oids
642         def createFromData(data):
643             count = struct.unpack("!h", data[:2])[0]
644             type_oids = struct.unpack("!" + "i"*count, data[2:])
645             return Protocol.ParameterDescription(type_oids)
646         createFromData = staticmethod(createFromData)
647
648     class RowDescription(object):
649         def __init__(self, fields):
650             self.fields = fields
651
652         def createFromData(data):
653             count = struct.unpack("!h", data[:2])[0]
654             data = data[2:]
655             fields = []
656             for i in range(count):
657                 null = data.find("\x00")
658                 field = {"name": data[:null]}
659                 data = data[null+1:]
660                 field["table_oid"], field["column_attrnum"], field["type_oid"], field["type_size"], field["type_modifier"], field["format"] = struct.unpack("!ihihih", data[:18])
661                 data = data[18:]
662                 fields.append(field)
663             return Protocol.RowDescription(fields)
664         createFromData = staticmethod(createFromData)
665
666     class CommandComplete(object):
667         def __init__(self, tag):
668             self.tag = tag
669
670         def createFromData(data):
671             return Protocol.CommandComplete(data[:-1])
672         createFromData = staticmethod(createFromData)
673
674     class DataRow(object):
675         def __init__(self, fields):
676             self.fields = fields
677
678         def createFromData(data):
679             count = struct.unpack("!h", data[:2])[0]
680             data = data[2:]
681             fields = []
682             for i in range(count):
683                 val_len = struct.unpack("!i", data[:4])[0]
684                 data = data[4:]
685                 if val_len == -1:
686                     fields.append(None)
687                 else:
688                     fields.append(data[:val_len])
689                     data = data[val_len:]
690             return Protocol.DataRow(fields)
691         createFromData = staticmethod(createFromData)
692
693     class Connection(object):
694         def __init__(self, host=None, port=5432, socket_timeout=60):
695             self._state = "unconnected"
696             self._client_encoding = "ascii"
697             self._host = host
698             self._port = port
699             self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
700             self._sock.settimeout(socket_timeout)
701             self._backend_key_data = None
702
703         def verifyState(self, state):
704             if self._state != state:
705                 raise InternalError("connection state must be %s, is %s" % (state, self._state))
706
707         def _send(self, msg):
708             #print repr(msg)
709             data = msg.serialize()
710             self._sock.send(data)
711
712         def _read_message(self):
713             bytes = self._sock.recv(5)
714             assert len(bytes) == 5
715             message_code = bytes[0]
716             data_len = struct.unpack("!i", bytes[1:])[0] - 4
717             if data_len == 0:
718                 bytes = ""
719             else:
720                 bytes = self._sock.recv(data_len)
721             msg = Protocol.message_types[message_code].createFromData(bytes)
722             if isinstance(msg, Protocol.NoticeResponse):
723                 # ignore NoticeResponse
724                 return self._read_message()
725             else:
726                 return msg
727
728         def connect(self):
729             self.verifyState("unconnected")
730             self._sock.connect((self._host, self._port))
731             self._state = "noauth"
732
733         def authenticate(self, user, **kwargs):
734             self.verifyState("noauth")
735             self._send(Protocol.StartupMessage(user, database=kwargs.get("database",None)))
736             msg = self._read_message()
737             if isinstance(msg, Protocol.AuthenticationRequest):
738                 if msg.ok(self, user, **kwargs):
739                     self._state = "auth"
740                     while 1:
741                         msg = self._read_message()
742                         if isinstance(msg, Protocol.ReadyForQuery):
743                             # done reading messages
744                             self._state = "ready"
745                             break
746                         elif isinstance(msg, Protocol.ParameterStatus):
747                             if msg.key == "client_encoding":
748                                 self._client_encoding = msg.value
749                         elif isinstance(msg, Protocol.BackendKeyData):
750                             self._backend_key_data = msg
751                         elif isinstance(msg, Protocol.ErrorResponse):
752                             raise msg.createException()
753                         else:
754                             raise InternalError("unexpected msg %r" % msg)
755                 else:
756                     raise InterfaceError("authentication method %s failed" % msg.__class__.__name__)
757             else:
758                 raise InternalError("StartupMessage was responded to with non-AuthenticationRequest msg %r" % msg)
759
760         def parse(self, statement, qs, types):
761             self.verifyState("ready")
762             type_info = [Types.pg_type_info(x) for x in types]
763             param_types, param_fc = [x[0] for x in type_info], [x[1] for x in type_info] # zip(*type_info) -- fails on empty arr
764             self._send(Protocol.Parse(statement, qs, param_types))
765             self._send(Protocol.DescribePreparedStatement(statement))
766             self._send(Protocol.Flush())
767             while 1:
768                 msg = self._read_message()
769                 if isinstance(msg, Protocol.ParseComplete):
770                     # ok, good.
771                     pass
772                 elif isinstance(msg, Protocol.ParameterDescription):
773                     # well, we don't really care -- we're going to send whatever
774                     # we want and let the database deal with it.  But thanks
775                     # anyways!
776                     pass
777                 elif isinstance(msg, Protocol.NoData):
778                     # We're not waiting for a row description.  Return
779                     # something destinctive to let bind know that there is no
780                     # output.
781                     return (None, param_fc)
782                 elif isinstance(msg, Protocol.RowDescription):
783                     return (msg, param_fc)
784                 elif isinstance(msg, Protocol.ErrorResponse):
785                     raise msg.createException()
786                 else:
787                     raise InternalError("Unexpected response msg %r" % (msg))
788
789         def bind(self, portal, statement, params, parse_data):
790             self.verifyState("ready")
791             row_desc, param_fc = parse_data
792             if row_desc == None:
793                 # no data coming out
794                 output_fc = ()
795             else:
796                 # We've got row_desc that allows us to identify what we're going to
797                 # get back from this statement.
798                 output_fc = [Types.py_type_info(f) for f in row_desc.fields]
799             self._send(Protocol.Bind(portal, statement, param_fc, params, output_fc, self._client_encoding))
800             # We need to describe the portal after bind, since the return
801             # format codes will be different (hopefully, always what we
802             # requested).
803             self._send(Protocol.DescribePortal(portal))
804             self._send(Protocol.Flush())
805             while 1:
806                 msg = self._read_message()
807                 if isinstance(msg, Protocol.BindComplete):
808                     # good news everybody!
809                     pass
810                 elif isinstance(msg, Protocol.NoData):
811                     # No data means we should execute this command right away.
812                     self._send(Protocol.Execute(portal, 0))
813                     self._send(Protocol.Sync())
814                     while 1:
815                         msg = self._read_message()
816                         if isinstance(msg, Protocol.CommandComplete):
817                             # more good news!
818                             pass
819                         elif isinstance(msg, Protocol.ReadyForQuery):
820                             # ready to move on with life...
821                             break
822                         elif isinstance(msg, Protocol.ErrorResponse):
823                             raise msg.createException()
824                         else:
825                             raise InternalError("unexpected response")
826                     return None
827                 elif isinstance(msg, Protocol.RowDescription):
828                     # Return the new row desc, since it will have the format
829                     # types we asked for
830                     return msg
831                 elif isinstance(msg, Protocol.ErrorResponse):
832                     raise msg.createException()
833                 else:
834                     raise InternalError("Unexpected response msg %r" % (msg))
835
836         def fetch_rows(self, portal, row_count, row_desc):
837             self.verifyState("ready")
838             self._send(Protocol.Execute(portal, row_count))
839             self._send(Protocol.Flush())
840             rows = []
841             end_of_data = False
842             while 1:
843                 msg = self._read_message()
844                 if isinstance(msg, Protocol.DataRow):
845                     rows.append(
846                             [Types.py_value(msg.fields[i], row_desc.fields[i], client_encoding=self._client_encoding)
847                                 for i in range(len(msg.fields))]
848                             )
849                 elif isinstance(msg, Protocol.PortalSuspended):
850                     # got all the rows we asked for, but not all that exist
851                     break
852                 elif isinstance(msg, Protocol.CommandComplete):
853                     self._send(Protocol.ClosePortal(portal))
854                     self._send(Protocol.Sync())
855                     while 1:
856                         msg = self._read_message()
857                         if isinstance(msg, Protocol.ReadyForQuery):
858                             # ready to move on with life...
859                             self._state = "ready"
860                             break
861                         elif isinstance(msg, Protocol.CloseComplete):
862                             # ok, great!
863                             pass
864                         elif isinstance(msg, Protocol.ErrorResponse):
865                             raise msg.createException()
866                         else:
867                             raise InternalError("unexpected response msg %r" % msg)
868                     end_of_data = True
869                     break
870                 elif isinstance(msg, Protocol.ErrorResponse):
871                     raise msg.createException()
872                 else:
873                     raise InternalError("Unexpected response msg %r" % msg)
874             return end_of_data, rows
875
876         def close_statement(self, statement):
877             self._send(Protocol.ClosePreparedStatement(statement))
878             self._send(Protocol.Sync())
879             while 1:
880                 msg = self._read_message()
881                 if isinstance(msg, Protocol.CloseComplete):
882                     # thanks!
883                     pass
884                 elif isinstance(msg, Protocol.ReadyForQuery):
885                     return
886                 elif isinstance(msg, Protocol.ErrorResponse):
887                     raise msg.createException()
888                 else:
889                     raise InternalError("Unexpected response msg %r" % msg)
890
891         def close_portal(self, portal):
892             self._send(Protocol.ClosePortal(portal))
893             self._send(Protocol.Sync())
894             while 1:
895                 msg = self._read_message()
896                 if isinstance(msg, Protocol.CloseComplete):
897                     # thanks!
898                     pass
899                 elif isinstance(msg, Protocol.ReadyForQuery):
900                     return
901                 elif isinstance(msg, Protocol.ErrorResponse):
902                     raise msg.createException()
903                 else:
904                     raise InternalError("Unexpected response msg %r" % msg)
905
906         def query(self, qs):
907             self.verifyState("ready")
908             self._send(Protocol.Query(qs))
909             msg = self._read_message()
910             if isinstance(msg, Protocol.RowDescription):
911                 self._state = "in_query"
912                 return msg
913             elif isinstance(msg, Protocol.ErrorResponse):
914                 raise msg.createException()
915             else:
916                 raise InternalError("RowDescription expected, other message recv'd")
917
918         def getrow(self):
919             self.verifyState("in_query")
920             msg = self._read_message()
921             if isinstance(msg, Protocol.DataRow):
922                 return msg
923             elif isinstance(msg, Protocol.CommandComplete):
924                 self.status = "query_complete"
925                 self._waitForReady()
926                 return None
927
928     message_types = {
929         "N": NoticeResponse,
930         "R": AuthenticationRequest,
931         "S": ParameterStatus,
932         "K": BackendKeyData,
933         "Z": ReadyForQuery,
934         "T": RowDescription,
935         "E": ErrorResponse,
936         "D": DataRow,
937         "C": CommandComplete,
938         "1": ParseComplete,
939         "2": BindComplete,
940         "3": CloseComplete,
941         "s": PortalSuspended,
942         "n": NoData,
943         "t": ParameterDescription,
944         }
945
946 class Types(object):
947     def pg_type_info(typ):
948         data = Types.py_types.get(typ)
949         if data == None:
950             raise NotSupportedError("type %r not mapped to pg type" % typ)
951         type_oid = data.get("tid")
952         if type_oid == None:
953             raise InternalError("type %r has no type_oid" % typ)
954         prefer = data.get("prefer")
955         if prefer != None:
956             if prefer == "bin":
957                 if data.get("bin_out") == None:
958                     raise InternalError("bin format prefered but not avail for type %r" % typ)
959                 format = 1
960             elif prefer == "txt":
961                 if data.get("txt_out") == None:
962                     raise InternalError("txt format prefered but not avail for type %r" % typ)
963                 format = 0
964             else:
965                 raise InternalError("prefer flag not recognized for type %r" % typ)
966         else:
967             # by default, prefer bin, but go with whatever exists
968             if data.get("bin_out"):
969                 format = 1
970             elif data.get("txt_out"):
971                 format = 0
972             else:
973                 raise InternalError("no conversion fuction for type %r" % typ)
974         return type_oid, format
975     pg_type_info = staticmethod(pg_type_info)
976
977     def pg_value(v, fc, **kwargs):
978         typ = type(v)
979         data = Types.py_types.get(typ)
980         if data == None:
981             raise NotSupportedError("type %r not mapped to pg type" % typ)
982         elif data.get("tid") == -1:
983             # special case: NULL values
984             return None
985         if fc == 0:
986             func = data.get("txt_out")
987         elif fc == 1:
988             func = data.get("bin_out")
989         else:
990             raise InternalError("unrecognized format code %r" % fc)
991         if func == None:
992             raise NotSupportedError("type %r, format code %r not supported" % (typ, fc))
993         return func(v, **kwargs)
994     pg_value = staticmethod(pg_value)
995
996     def py_type_info(description):
997         type_oid = description['type_oid']
998         data = Types.pg_types.get(type_oid)
999         if data == None:
1000             raise NotSupportedError("type oid %r not mapped to py type" % type_oid)
1001         prefer = data.get("prefer")
1002         if prefer != None:
1003             if prefer == "bin":
1004                 if data.get("bin_in") == None:
1005                     raise InternalError("bin format prefered but not avail for type oid %r" % type_oid)
1006                 format = 1
1007             elif prefer == "txt":
1008                 if data.get("txt_in") == None:
1009                     raise InternalError("txt format prefered but not avail for type oid %r" % type_oid)
1010                 format = 0
1011             else:
1012                 raise InternalError("prefer flag not recognized for type oid %r" % type_oid)
1013         else:
1014             # by default, prefer bin, but go with whatever exists
1015             if data.get("bin_in"):
1016                 format = 1
1017             elif data.get("txt_in"):