root/pg8000/trunk/pg8000.py

Revision 852, 62.8 kB (checked in by mfenniak, 2 years ago)

add date and time data types, and timezonetz read. Add bool write -- somehow this was missed.

Line 
1 # vim: sw=4:expandtab:foldmethod=marker
2 #
3 # Copyright (c) 2007, Mathieu Fenniak
4 # All rights reserved.
5 #
6 # Redistribution and use in source and binary forms, with or without
7 # modification, are permitted provided that the following conditions are
8 # met:
9 #
10 # * Redistributions of source code must retain the above copyright notice,
11 # this list of conditions and the following disclaimer.
12 # * Redistributions in binary form must reproduce the above copyright notice,
13 # this list of conditions and the following disclaimer in the documentation
14 # and/or other materials provided with the distribution.
15 # * The name of the author may not be used to endorse or promote products
16 # derived from this software without specific prior written permission.
17 #
18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
22 # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28 # POSSIBILITY OF SUCH DAMAGE.
29
30 __author__ = "Mathieu Fenniak"
31
import datetime
import decimal
import md5
import os
import socket
import struct
import threading
import time
39
# Debug trace sink used by the DBAPI wrappers (they call debug_log.write()).
# The log used to be opened unconditionally at a developer-specific absolute
# path, which made importing this module fail on any machine without that
# directory.  Logging is now opt-in: set the PG8000_DEBUG_LOG environment
# variable to a writable file path to enable it; otherwise writes are
# discarded.
class _NullDebugLog(object):
    """File-like stand-in whose write() discards its argument."""

    def write(self, data):
        pass

_debug_log_path = os.environ.get("PG8000_DEBUG_LOG")
if _debug_log_path:
    debug_log = open(_debug_log_path, "w")
else:
    debug_log = _NullDebugLog()
41
class Warning(StandardError):
    """Warning category of this module, re-exported as ``DBAPI.Warning``."""
44
class Error(StandardError):
    """Root of this module's error hierarchy.

    InterfaceError and DatabaseError derive from it; re-exported as
    ``DBAPI.Error``.
    """
47
class InterfaceError(Error):
    """Error in the use of the driver interface rather than in the database.

    Within this module it is raised for operations on a closed cursor or
    connection, for socket communication failures while connecting, and when
    a row cannot be returned as a dict because two columns share a name.
    """
50
class DatabaseError(Error):
    """Base class of the database-related errors defined below it.

    DataError, OperationalError, IntegrityError, InternalError,
    ProgrammingError and NotSupportedError all derive from this class.
    """
53
class DataError(DatabaseError):
    """DatabaseError subclass re-exported as ``DBAPI.DataError``."""
56
class OperationalError(DatabaseError):
    """DatabaseError subclass re-exported as ``DBAPI.OperationalError``."""
59
class IntegrityError(DatabaseError):
    """DatabaseError subclass re-exported as ``DBAPI.IntegrityError``."""
62
class InternalError(DatabaseError):
    """Signals an inconsistency inside the driver itself.

    PreparedStatement raises it when asked to refill a row cache that is not
    empty.  Re-exported as ``DBAPI.InternalError``.
    """
65
class ProgrammingError(DatabaseError):
    """Signals incorrect use by the calling program.

    Raised by ``DBAPI.convert_paramstyle`` for malformed or unsupported
    parameter placeholders, and by Cursor when rows are read before any
    statement has been executed.  Re-exported as ``DBAPI.ProgrammingError``.
    """
68
class NotSupportedError(DatabaseError):
    """DatabaseError subclass re-exported as ``DBAPI.NotSupportedError``."""
71
class DataIterator(object):
    """Iterator that yields ``func(obj)`` repeatedly until it returns None.

    ``obj`` is the object rows are read from and ``func`` is a callable taking
    that object and returning the next item, or None once the data is
    exhausted.  A None result ends the iteration and is never yielded.
    """

    def __init__(self, obj, func):
        self.obj = obj
        self.func = func

    def __iter__(self):
        return self

    def next(self):
        retval = self.func(self.obj)
        # Identity test: only the None sentinel itself ends the iteration,
        # never a value that merely compares equal to None.
        if retval is None:
            raise StopIteration()
        return retval

    # Alias for interpreters that spell the iterator protocol method
    # __next__; it is simply unused where the protocol calls next().
    __next__ = next
85
class DBAPI(object):
    """Namespace class exposing this module under the Python DB-API 2.0 names.

    The exception classes, module globals (apilevel, threadsafety,
    paramstyle), connect() and the type constructors are all attributes of
    this class rather than module-level names.
    """

    # Module-level exception classes, re-exported under the same names.
    Warning = Warning
    Error = Error
    InterfaceError = InterfaceError
    InternalError = InternalError
    DatabaseError = DatabaseError
    DataError = DataError
    OperationalError = OperationalError
    IntegrityError = IntegrityError
    ProgrammingError = ProgrammingError
    NotSupportedError = NotSupportedError

    # DB-API 2.0 module globals.  paramstyle is read by CursorWrapper.execute
    # on every call, so reassigning it changes how later queries are parsed.
    apilevel = "2.0"
    threadsafety = 3
    paramstyle = 'format' # paramstyle can be changed to any DB-API paramstyle
101
102     def convert_paramstyle(src_style, query, args):
103         # I don't see any way to avoid scanning the query string char by char,
104         # so we might as well take that careful approach and create a
105         # state-based scanner.  We'll use int variables for the state.
106         #  0 -- outside quoted string
107         #  1 -- inside single-quote string '...'
108         #  2 -- inside quoted identifier   "..."
109         #  3 -- inside escaped single-quote string, E'...'
110         debug_log.write("convert_paramstyle(%r, %r, %r)\n" % (src_style, query, args))
111         state = 0
112         output_query = ""
113         output_args = []
114         if src_style == "numeric":
115             output_args = args
116         elif src_style in ("pyformat", "named"):
117             mapping_to_idx = {}
118         i = 0
119         while 1:
120             if i == len(query):
121                 break
122             c = query[i]
123             # print "begin loop", repr(i), repr(c), repr(state)
124             if state == 0:
125                 if c == "'":
126                     i += 1
127                     output_query += c
128                     state = 1
129                 elif c == '"':
130                     i += 1
131                     output_query += c
132                     state = 2
133                 elif c == 'E':
134                     # check for escaped single-quote string
135                     i += 1
136                     if i < len(query) and i > 1 and query[i] == "'":
137                         i += 1
138                         output_query += "E'"
139                         state = 3
140                     else:
141                         output_query += c
142                 elif src_style == "qmark" and c == "?":
143                     i += 1
144                     param_idx = len(output_args)
145                     if param_idx == len(args):
146                         raise ProgrammingError("too many parameter fields, not enough parameters")
147                     output_args.append(args[param_idx])
148                     output_query += "$" + str(param_idx + 1)
149                 elif src_style == "numeric" and c == ":":
150                     i += 1
151                     if i < len(query) and i > 1 and query[i].isdigit():
152                         output_query += "$" + query[i]
153                         i += 1
154                     else:
155                         raise ProgrammingError("numeric parameter : does not have numeric arg")
156                 elif src_style == "named" and c == ":":
157                     name = ""
158                     while 1:
159                         i += 1
160                         if i == len(query):
161                             break
162                         c = query[i]
163                         if c.isalnum():
164                             name += c
165                         else:
166                             break
167                     if name == "":
168                         raise ProgrammingError("empty name of named parameter")
169                     idx = mapping_to_idx.get(name)
170                     if idx == None:
171                         idx = len(output_args)
172                         output_args.append(args[name])
173                         idx += 1
174                         mapping_to_idx[name] = idx
175                     output_query += "$" + str(idx)
176                 elif src_style == "format" and c == "%":
177                     i += 1
178                     if i < len(query) and i > 1:
179                         if query[i] == "s":
180                             param_idx = len(output_args)
181                             if param_idx == len(args):
182                                 raise ProgrammingError("too many parameter fields, not enough parameters")
183                             output_args.append(args[param_idx])
184                             output_query += "$" + str(param_idx + 1)
185                         elif query[i] == "%":
186                             output_query += "%"
187                         else:
188                             raise ProgrammingError("Only %s and %% are supported")
189                         i += 1
190                     else:
191                         raise ProgrammingError("numeric parameter : does not have numeric arg")
192                 elif src_style == "pyformat" and c == "%":
193                     i += 1
194                     if i < len(query) and i > 1:
195                         if query[i] == "(":
196                             i += 1
197                             # begin mapping name
198                             end_idx = query.find(')', i)
199                             if end_idx == -1:
200                                 raise ProgrammingError("began pyformat dict read, but couldn't find end of name")
201                             else:
202                                 name = query[i:end_idx]
203                                 i = end_idx + 1
204                                 if i < len(query) and query[i] == "s":
205                                     i += 1
206                                     idx = mapping_to_idx.get(name)
207                                     if idx == None:
208                                         idx = len(output_args)
209                                         output_args.append(args[name])
210                                         idx += 1
211                                         mapping_to_idx[name] = idx
212                                     output_query += "$" + str(idx)
213                                 else:
214                                     raise ProgrammingError("format not specified or not supported (only %(...)s supported)")
215                         elif query[i] == "%":
216                             output_query += "%"
217                 else:
218                     i += 1
219                     output_query += c
220             elif state == 1:
221                 output_query += c
222                 i += 1
223                 if c == "'":
224                     # Could be a double ''
225                     if i < len(query) and query[i] == "'":
226                         # is a double quote.
227                         output_query += query[i]
228                         i += 1
229                     else:
230                         state = 0
231                 elif src_style in ("pyformat","format") and c == "%":
232                     # hm... we're only going to support an escaped percent sign
233                     if i < len(query):
234                         if query[i] == "%":
235                             # good.  We already output the first percent sign.
236                             i += 1
237                         else:
238                             raise ProgrammingError("'%" + query[i] + "' not supported in quoted string")
239             elif state == 2:
240                 output_query += c
241                 i += 1
242                 if c == '"':
243                     state = 0
244                 elif src_style in ("pyformat","format") and c == "%":
245                     # hm... we're only going to support an escaped percent sign
246                     if i < len(query):
247                         if query[i] == "%":
248                             # good.  We already output the first percent sign.
249                             i += 1
250                         else:
251                             raise ProgrammingError("'%" + query[i] + "' not supported in quoted string")
252             elif state == 3:
253                 output_query += c
254                 i += 1
255                 if c == "\\":
256                     # check for escaped single-quote
257                     if i < len(query) and query[i] == "'":
258                         output_query += "'"
259                         i += 1
260                 elif c == "'":
261                     state = 0
262                 elif src_style in ("pyformat","format") and c == "%":
263                     # hm... we're only going to support an escaped percent sign
264                     if i < len(query):
265                         if query[i] == "%":
266                             # good.  We already output the first percent sign.
267                             i += 1
268                         else:
269                             raise ProgrammingError("'%" + query[i] + "' not supported in quoted string")
270
271         return output_query, tuple(output_args)
272     convert_paramstyle = staticmethod(convert_paramstyle)
273
274
275     class CursorWrapper(object):
276         def __init__(self, conn):
277             self.cursor = Cursor(conn)
278             self.arraysize = 1
279
280         rowcount = property(lambda self: self._getRowCount())
281         def _getRowCount(self):
282             return -1
283
284         description = property(lambda self: self._getDescription())
285         def _getDescription(self):
286             if self.cursor.row_description == None:
287                 return None
288             columns = []
289             for col in self.cursor.row_description:
290                 columns.append((col["name"], col["type_oid"]))
291             return columns
292
293         def execute(self, operation, args=()):
294             debug_log.write("execute(%r, %r)\n" % (operation, args))
295             if self.cursor == None:
296                 raise InterfaceError("cursor is closed")
297             new_query, new_args = DBAPI.convert_paramstyle(DBAPI.paramstyle, operation, args)
298             try:
299                 self.cursor.execute(new_query, *new_args)
300             except:
301                 # any error will rollback the transaction to-date
302                 self.cursor.connection.rollback()
303                 raise
304
305         def executemany(self, operation, parameter_sets):
306             for parameters in parameter_sets:
307                 self.execute(operation, parameters)
308
309         def fetchone(self):
310             if self.cursor == None:
311                 raise InterfaceError("cursor is closed")
312             return self.cursor.read_tuple()
313
314         def fetchmany(self, size=None):
315             if size == None:
316                 size = self.arraysize
317             rows = []
318             for i in range(size):
319                 rows.append(self.fetchone())
320             return rows
321
322         def fetchall(self):
323             if self.cursor == None:
324                 raise InterfaceError("cursor is closed")
325             return tuple(self.cursor.iterate_tuple())
326
327         def close(self):
328             self.cursor = None
329
330         def setinputsizes(self, sizes):
331             pass
332
333         def setoutputsize(self, size, column=None):
334             pass
335
336     class ConnectionWrapper(object):
337         def __init__(self, **kwargs):
338             self.conn = Connection(**kwargs)
339             self.conn.begin()
340
341         def cursor(self):
342             return DBAPI.CursorWrapper(self.conn)
343
344         def commit(self):
345             # There's a threading bug here.  If a query is sent after the
346             # commit, but before the begin, it will be executed immediately
347             # without a surrounding transaction.  Like all threading bugs -- it
348             # sounds unlikely, until it happens every time in one
349             # application...  however, to fix this, we need to lock the
350             # database connection entirely, so that no cursors can execute
351             # statements on other threads.  Support for that type of lock will
352             # be done later.
353             if self.conn == None:
354                 raise InterfaceError("connection is closed")
355             self.conn.commit()
356             self.conn.begin()
357
358         def rollback(self):
359             # see bug description in commit.
360             if self.conn == None:
361                 raise InterfaceError("connection is closed")
362             self.conn.rollback()
363             self.conn.begin()
364
365         def close(self):
366             self.conn = None
367
368     def connect(user, host=None, unix_sock=None, port=5432, database=None, password=None, socket_timeout=60, ssl=False):
369         return DBAPI.ConnectionWrapper(user=user, host=host,
370                 unix_sock=unix_sock, port=port, database=database,
371                 password=password, socket_timeout=socket_timeout, ssl=ssl)
372     connect = staticmethod(connect)
373
374     def Date(year, month, day):
375         return datetime.date(year, month, day)
376     Date = staticmethod(Date)
377
378     def Time(hour, minute, second):
379         return datetime.time(hour, minute, second)
380     Time = staticmethod(Time)
381
382     def Timestamp(year, month, day, hour, minute, second):
383         return datetime.datetime(year, month, day, hour, minute, second)
384     Timestamp = staticmethod(Timestamp)
385
386     def DateFromTicks(ticks):
387         return DBAPI.Date(*time.localtime(ticks)[:3])
388     DateFromTicks = staticmethod(DateFromTicks)
389
390     def TimeFromTicks(ticks):
391         return DBAPI.Time(*time.localtime(ticks)[3:6])
392     TimeFromTicks = staticmethod(TimeFromTicks)
393
394     def TimestampFromTicks(ticks):
395         return DBAPI.Timestamp(*time.localtime(ticks)[:6])
396     TimestampFromTicks = staticmethod(TimestampFromTicks)
397
398     def Binary(value):
399         return Bytea(value)
400     Binary = staticmethod(Binary)
401
    # DB-API type objects.  Each one is a single PostgreSQL type oid.
    #
    # I have no idea what this would be used for by a client app.  Should it be
    # TEXT, VARCHAR, CHAR?  It will only compare against row_description's
    # type_code if it is this one type.  It is the TEXT type_oid for now.
    STRING = 25

    # bytea type_oid
    BINARY = 17

    # numeric type_oid
    NUMBER = 1700

    # timestamp type_oid
    DATETIME = 1114

    # oid type_oid
    ROWID = 26
418
419
420 ##
421 # This class represents a prepared statement.  A prepared statement is
422 # pre-parsed on the server, which reduces the need to parse the query every
423 # time it is run.  The statement can have parameters in the form of $1, $2, $3,
424 # etc.  When parameters are used, the types of the parameters need to be
425 # specified when creating the prepared statement.
426 # <p>
427 # As of v1.01, instances of this class are thread-safe.  This means that a
428 # single PreparedStatement can be accessed by multiple threads without the
429 # internal consistency of the statement being altered.  However, the
430 # responsibility is on the client application to ensure that one thread reading
431 # from a statement isn't affected by another thread starting a new query with
432 # the same statement.
433 # <p>
434 # Stability: Added in v1.00, stability guaranteed for v1.xx.
435 #
436 # @param connection     An instance of {@link Connection Connection}.
437 #
438 # @param statement      The SQL statement to be represented, often containing
439 # parameters in the form of $1, $2, $3, etc.
440 #
441 # @param types          Python type objects for each parameter in the SQL
442 # statement.  For example, int, float, str.
class PreparedStatement(object):

    ##
    # Determines the number of rows to read from the database server at once.
    # Reading more rows increases performance at the cost of memory.  The
    # default value is 100 rows.  The effect of this parameter is transparent.
    # That is, the library automatically reads more rows when the cache is
    # empty.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.  It is
    # possible that implementation changes in the future could cause this
    # parameter to be ignored.
    row_cache_size = 100

    def __init__(self, connection, statement, *types):
        self.c = connection.c
        # Portal and statement names are derived from object ids so that they
        # are unique per protocol connection and per statement object.
        ids = (id(self.c), id(self))
        self._portal_name = "pg8000_portal_%s_%s" % ids
        self._statement_name = "pg8000_statement_%s_%s" % ids
        self._lock = threading.RLock()
        self._row_desc = None
        self._cached_rows = []
        self._command_complete = True
        self._parse_row_desc = self.c.parse(self._statement_name, statement, types)

    def __del__(self):
        # Closing the server-side statement does not have to happen the
        # instant the object becomes unreachable, so delayed (garbage
        # collected) cleanup is fine.  Promptness only matters if the same
        # object id, and therefore the same statement name, could be reused
        # soon, which cannot happen while this object is still awaiting
        # collection.
        self.c.close_statement(self._statement_name)

    def _getRowDescription(self):
        desc = self._row_desc
        if desc is None:
            return None
        return desc.fields
    row_description = property(lambda self: self._getRowDescription())

    ##
    # Run the SQL prepared statement with the given parameters.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def execute(self, *args):
        self._lock.acquire()
        try:
            if not self._command_complete:
                # The previous execution was not read to its end: discard
                # its cached rows and close its portal.
                self._cached_rows = []
                self.c.close_portal(self._portal_name)
            self._command_complete = False
            self._row_desc = self.c.bind(self._portal_name, self._statement_name, args, self._parse_row_desc)
            if self._row_desc:
                # Fetch the first batch of rows immediately.  Otherwise the
                # portal is apparently destroyed by a stray Sync sent between
                # Bind and Execute.  Rows are very likely to be read right
                # away anyway, so this seems a safe move for now.
                self._fill_cache()
        finally:
            self._lock.release()

    def _fill_cache(self):
        # Read the next batch of up to row_cache_size rows into the cache.
        self._lock.acquire()
        try:
            if self._cached_rows:
                raise InternalError("attempt to fill cache that isn't empty")
            end_of_data, batch = self.c.fetch_rows(self._portal_name, self.row_cache_size, self._row_desc)
            self._cached_rows = batch
            if end_of_data:
                self._command_complete = True
        finally:
            self._lock.release()

    def _fetch(self):
        # Return the next row as a tuple, or None once no rows are left.
        self._lock.acquire()
        try:
            if not self._cached_rows:
                if self._command_complete:
                    return None
                self._fill_cache()
                if self._command_complete and not self._cached_rows:
                    # The refill reported completion without delivering any
                    # row.  This is the special case of a query that returns
                    # no rows.
                    return None
            return tuple(self._cached_rows.pop(0))
        finally:
            self._lock.release()

    ##
    # Read a row from the database server, and return it in a dictionary
    # indexed by column name/alias.  This method will raise an error if two
    # columns have the same name.  Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_dict(self):
        row = self._fetch()
        if row is None:
            return None
        by_name = {}
        for position, field in enumerate(self._row_desc.fields):
            col_name = field['name']
            if col_name in by_name:
                raise InterfaceError("cannot return dict of row when two columns have the same name (%r)" % (col_name,))
            by_name[col_name] = row[position]
        return by_name

    ##
    # Read a row from the database server, and return it as a tuple of values.
    # Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_tuple(self):
        return self._fetch()

    ##
    # Return an iterator for the output of this statement.  The iterator will
    # return a tuple for each row, in the same manner as {@link
    # #PreparedStatement.read_tuple read_tuple}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_tuple(self):
        return DataIterator(self, PreparedStatement.read_tuple)

    ##
    # Return an iterator for the output of this statement.  The iterator will
    # return a dict for each row, in the same manner as {@link
    # #PreparedStatement.read_dict read_dict}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_dict(self):
        return DataIterator(self, PreparedStatement.read_dict)
579
580 ##
581 # The Cursor class allows multiple queries to be performed concurrently with a
582 # single PostgreSQL connection.  The Cursor object is implemented internally by
583 # using a {@link PreparedStatement PreparedStatement} object, so if you plan to
584 # use a statement multiple times, you might as well create a PreparedStatement
585 # and save a small amount of reparsing time.
586 # <p>
587 # As of v1.01, instances of this class are thread-safe.  See {@link
588 # PreparedStatement PreparedStatement} for more information.
589 # <p>
590 # Stability: Added in v1.00, stability guaranteed for v1.xx.
591 #
592 # @param connection     An instance of {@link Connection Connection}.
class Cursor(object):
    def __init__(self, connection):
        self._stmt = None
        self.connection = connection

    def _getRowDescription(self):
        stmt = self._stmt
        if stmt is None:
            return None
        return stmt.row_description
    row_description = property(lambda self: self._getRowDescription())

    # Return the statement created by the last execute(), or raise when
    # nothing has been executed on this cursor yet.
    def _executed_statement(self):
        if self._stmt is None:
            raise ProgrammingError("attempting to read from unexecuted cursor")
        return self._stmt

    ##
    # Run an SQL statement using this cursor.  The SQL statement can have
    # parameters in the form of $1, $2, $3, etc., which will be filled in by
    # the additional arguments passed to this function.  The parameter types
    # given to the prepared statement are the Python types of those arguments.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    # @param query      The SQL statement to execute.
    def execute(self, query, *args):
        arg_types = [type(value) for value in args]
        statement = PreparedStatement(self.connection, query, *arg_types)
        self._stmt = statement
        statement.execute(*args)

    ##
    # Read a row from the database server, and return it in a dictionary
    # indexed by column name/alias.  This method will raise an error if two
    # columns have the same name.  Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_dict(self):
        return self._executed_statement().read_dict()

    ##
    # Read a row from the database server, and return it as a tuple of values.
    # Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_tuple(self):
        return self._executed_statement().read_tuple()

    ##
    # Return an iterator for the output of this statement.  The iterator will
    # return a tuple for each row, in the same manner as {@link
    # #PreparedStatement.read_tuple read_tuple}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_tuple(self):
        return self._executed_statement().iterate_tuple()

    ##
    # Return an iterator for the output of this statement.  The iterator will
    # return a dict for each row, in the same manner as {@link
    # #PreparedStatement.read_dict read_dict}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_dict(self):
        return self._executed_statement().iterate_dict()
657
658 ##
659 # This class represents a connection to a PostgreSQL database.
660 # <p>
661 # The database connection is derived from the {@link #Cursor Cursor} class,
662 # which provides a default cursor for running queries.  It also provides
663 # transaction control via the 'begin', 'commit', and 'rollback' methods.
664 # Without beginning a transaction explicitly, all statements will autocommit to
665 # the database.
666 # <p>
667 # As of v1.01, instances of this class are thread-safe.  See {@link
668 # PreparedStatement PreparedStatement} for more information.
669 # <p>
670 # Stability: Added in v1.00, stability guaranteed for v1.xx.
671 #
672 # @param user   The username to connect to the PostgreSQL server with.  This
673 # parameter is required.
674 #
675 # @keyparam host   The hostname of the PostgreSQL server to connect with.
676 # Providing this parameter is necessary for TCP/IP connections.  One of either
677 # host, or unix_sock, must be provided.
678 #
679 # @keyparam unix_sock   The path to the UNIX socket to access the database
680 # through, for example, '/tmp/.s.PGSQL.5432'.  One of either unix_sock or host
# must be provided.  The port parameter will have no effect if unix_sock is
682 # provided.
683 #
684 # @keyparam port   The TCP/IP port of the PostgreSQL server instance.  This
685 # parameter defaults to 5432, the registered and common port of PostgreSQL
686 # TCP/IP servers.
687 #
688 # @keyparam database   The name of the database instance to connect with.  This
689 # parameter is optional, if omitted the PostgreSQL server will assume the
690 # database name is the same as the username.
691 #
692 # @keyparam password   The user password to connect to the server with.  This
693 # parameter is optional.  If omitted, and the database server requests password
694 # based authentication, the connection will fail.  On the other hand, if this
695 # parameter is provided and the database does not request password
696 # authentication, then the password will not be used.
697 #
698 # @keyparam socket_timeout  Socket connect timeout measured in seconds.
699 # Defaults to 60 seconds.
700 #
701 # @keyparam ssl     Use SSL encryption for TCP/IP socket.  Defaults to False.
702 class Connection(Cursor):
703     def __init__(self, user, host=None, unix_sock=None, port=5432, database=None, password=None, socket_timeout=60, ssl=False):
704         self._row_desc = None
705         try:
706             self.c = Protocol.Connection(unix_sock=unix_sock, host=host, port=port, socket_timeout=socket_timeout, ssl=ssl)
707             #self.c.connect()
708             self.c.authenticate(user, password=password, database=database)
709         except socket.error, e:
710             raise InterfaceError("communication error", e)
711         Cursor.__init__(self, self)
712         self._begin = PreparedStatement(self, "BEGIN TRANSACTION")
713         self._commit = PreparedStatement(self, "COMMIT TRANSACTION")
714         self._rollback = PreparedStatement(self, "ROLLBACK TRANSACTION")
715
716     ##
717     # Begins a new transaction.
718     # <p>
719     # Stability: Added in v1.00, stability guaranteed for v1.xx.
720     def begin(self):
721         self._begin.execute()
722
723     ##
724     # Commits the running transaction.
725     # <p>
726     # Stability: Added in v1.00, stability guaranteed for v1.xx.
727     def commit(self):
728         self._commit.execute()
729
730     ##
731     # Rolls back the running transaction.
732     # <p>
733     # Stability: Added in v1.00, stability guaranteed for v1.xx.
734     def rollback(self):
735         self._rollback.execute()
736
737
738 class Protocol(object):
739
class SSLRequest(object):
    # Stateless message asking the server to negotiate SSL on the connection.
    def __init__(self):
        pass

    ##
    # Returns the wire form of the request: two network-order 32-bit
    # integers, the total message length (8) followed by the request code
    # 80877103.  There is no leading type byte.
    def serialize(self):
        return struct.pack("!ii", 8, 80877103)
746
class StartupMessage(object):
    ##
    # @param user      User name sent as the "user" startup parameter.
    # @param database  Optional database name; when false-y (None or empty)
    #                  no "database" parameter is sent.
    def __init__(self, user, database=None):
        self.user = user
        self.database = database

    ##
    # Returns the startup packet: a network-order 32-bit length (counting
    # itself), the protocol number 196608, NUL-terminated key/value pairs,
    # and a final NUL that closes the parameter list.
    def serialize(self):
        protocol = 196608
        pieces = [struct.pack("!i", protocol), "user\x00", self.user, "\x00"]
        if self.database:
            pieces.extend(["database\x00", self.database, "\x00"])
        pieces.append("\x00")
        body = "".join(pieces)
        return struct.pack("!i", len(body) + 4) + body
761
class Query(object):
    ##
    # @param qs  The query string to send.
    def __init__(self, qs):
        self.qs = qs

    ##
    # Returns the type byte "Q", a network-order 32-bit length (payload plus
    # the four length bytes) and the NUL-terminated query string.
    def serialize(self):
        payload = self.qs + "\x00"
        length = struct.pack("!i", len(payload) + 4)
        return "Q" + length + payload
771
class Parse(object):
    ##
    # @param ps         Prepared statement name.
    # @param qs         Query string to parse.
    # @param type_oids  Sequence of integer type OIDs, one per parameter.
    def __init__(self, ps, qs, type_oids):
        self.ps = ps
        self.qs = qs
        self.type_oids = type_oids

    ##
    # Returns the type byte "P", a network-order 32-bit length, the two
    # NUL-terminated names, a 16-bit OID count and each OID as 32 bits.
    def serialize(self):
        pieces = [self.ps, "\x00", self.qs, "\x00",
                  struct.pack("!h", len(self.type_oids))]
        for oid in self.type_oids:
            # Parse message doesn't seem to handle the -1 type_oid for NULL
            # values that other messages handle.  So we'll provide type_oid
            # 705, the PG "unknown" type.
            if oid == -1:
                oid = 705
            pieces.append(struct.pack("!i", oid))
        body = "".join(pieces)
        return "P" + struct.pack("!i", len(body) + 4) + body
790
class Bind(object):
    ##
    # Converts every parameter up front with Types.pg_value and stores the
    # results in self.params.
    #
    # @param portal  Portal name written first in the message.
    # @param ps      Prepared statement name written second.
    # @param in_fc   Sequence of 16-bit codes for the parameters.  If empty,
    #                0 is used for every parameter; if it has one entry, that
    #                entry is used for every parameter; otherwise entry i is
    #                used for parameter i.
    # @param params  The parameter values.
    # @param out_fc  Sequence of 16-bit codes written after the parameters.
    # @param client_encoding  Passed through to Types.pg_value.
    def __init__(self, portal, ps, in_fc, params, out_fc, client_encoding):
        self.portal = portal
        self.ps = ps
        self.in_fc = in_fc
        self.params = []
        for i, param in enumerate(params):
            if len(self.in_fc) == 0:
                fc = 0
            elif len(self.in_fc) == 1:
                fc = self.in_fc[0]
            else:
                fc = self.in_fc[i]
            self.params.append(Types.pg_value(param, fc, client_encoding = client_encoding))
        self.out_fc = out_fc

    ##
    # Returns the type byte "B", a network-order 32-bit length, the two
    # NUL-terminated names, the counted in_fc codes, the counted parameters
    # (each a 32-bit length plus data, or length -1 for None) and the
    # counted out_fc codes.
    def serialize(self):
        pack = struct.pack
        pieces = [self.portal, "\x00", self.ps, "\x00",
                  pack("!h", len(self.in_fc))]
        pieces.extend([pack("!h", fc) for fc in self.in_fc])
        pieces.append(pack("!h", len(self.params)))
        for param in self.params:
            if param is None:
                # special case, NULL value: length -1 and no data bytes
                pieces.append(pack("!i", -1))
            else:
                pieces.append(pack("!i", len(param)))
                pieces.append(param)
        pieces.append(pack("!h", len(self.out_fc)))
        pieces.extend([pack("!h", fc) for fc in self.out_fc])
        # Join once instead of rebuilding the string on every append.
        body = "".join(pieces)
        return "B" + pack("!i", len(body) + 4) + body
825
class Close(object):
    ##
    # @param typ   Single-character kind of object being closed; any other
    #              length raises InternalError.
    # @param name  Name of the object to close.
    def __init__(self, typ, name):
        if len(typ) != 1:
            raise InternalError("Close typ must be 1 char")
        self.typ = typ
        self.name = name

    ##
    # Returns the type byte "C", a network-order 32-bit length, the kind
    # character and the NUL-terminated name.
    def serialize(self):
        payload = self.typ + self.name + "\x00"
        length = struct.pack("!i", len(payload) + 4)
        return "C" + length + payload
838
class ClosePortal(Close):
    ##
    # Close message fixed to kind "P".
    #
    # @param name  Name of the portal to close.
    def __init__(self, name):
        Protocol.Close.__init__(self, typ="P", name=name)
842
class ClosePreparedStatement(Close):
    ##
    # Close message fixed to kind "S".
    #
    # @param name  Name of the prepared statement to close.
    def __init__(self, name):
        Protocol.Close.__init__(self, typ="S", name=name)
846
class Describe(object):
    ##
    # @param typ   Single-character kind of object being described; any
    #              other length raises InternalError.
    # @param name  Name of the object to describe.
    def __init__(self, typ, name):
        if len(typ) != 1:
            raise InternalError("Describe typ must be 1 char")
        self.typ = typ
        self.name = name

    ##
    # Returns the type byte "D", a network-order 32-bit length, the kind
    # character and the NUL-terminated name.
    def serialize(self):
        payload = self.typ + self.name + "\x00"
        length = struct.pack("!i", len(payload) + 4)
        return "D" + length + payload
859
class DescribePortal(Describe):
    ##
    # Describe message fixed to kind "P".
    #
    # @param name  Name of the portal to describe.
    def __init__(self, name):
        Protocol.Describe.__init__(self, typ="P", name=name)
863
class DescribePreparedStatement(Describe):
    ##
    # Describe message fixed to kind "S".
    #
    # @param name  Name of the prepared statement to describe.
    def __init__(self, name):
        Protocol.Describe.__init__(self, typ="S", name=name)
867
class Flush(object):
    ##
    # Returns the fixed five-byte message: the type byte "H" followed by a
    # 32-bit length of 4 (the message has no payload).
    def serialize(self):
        return "H\x00\x00\x00\x04"
871
class Sync(object):
    ##
    # Returns the fixed five-byte message: the type byte "S" followed by a
    # 32-bit length of 4 (the message has no payload).
    def serialize(self):
        return "S\x00\x00\x00\x04"
875
class PasswordMessage(object):
    ##
    # @param pwd  The password string to send, already in the form the
    #             server expects.
    def __init__(self, pwd):
        self.pwd = pwd

    ##
    # Returns the type byte "p", a network-order 32-bit length and the
    # NUL-terminated password string.
    def serialize(self):
        payload = self.pwd + "\x00"
        length = struct.pack("!i", len(payload) + 4)
        return "p" + length + payload
885
class Execute(object):
    ##
    # @param portal     Name of the portal to execute.
    # @param row_count  Integer written after the portal name as 32 bits.
    def __init__(self, portal, row_count):
        self.portal = portal
        self.row_count = row_count

    ##
    # Returns the type byte "E", a network-order 32-bit length, the
    # NUL-terminated portal name and the 32-bit row count.
    def serialize(self):
        payload = self.portal + "\x00" + struct.pack("!i", self.row_count)
        length = struct.pack("!i", len(payload) + 4)
        return "E" + length + payload
896
class AuthenticationRequest(object):
    # Base class for server authentication messages.  The base constructor
    # ignores its payload.
    def __init__(self, data):
        pass

    ##
    # Reads a network-order 32-bit identifier from the first four bytes of
    # data, looks it up in Protocol.authentication_codes and instantiates
    # the matching class with the remaining bytes.
    #
    # @raise NotSupportedError  If the identifier has no registered class.
    @staticmethod
    def createFromData(data):
        ident = struct.unpack("!i", data[:4])[0]
        klass = Protocol.authentication_codes.get(ident, None)
        if klass is None:
            raise NotSupportedError("authentication method %r not supported" % (ident,))
        return klass(data[4:])

    ##
    # Placeholder that subclasses replace; the base version always raises
    # InternalError.
    def ok(self, conn, user, **kwargs):
        raise InternalError("ok method should be overridden on AuthenticationRequest instance")
912
class AuthenticationOk(AuthenticationRequest):
    ##
    # Nothing further to negotiate: always reports success.  conn, user and
    # any keyword arguments are accepted and ignored.
    def ok(self, conn, user, **kwargs):
        return True
916
class AuthenticationMD5Password(AuthenticationRequest):
    ##
    # @param data  The four salt characters; struct.unpack("4c", ...)
    #              rejects any other length.
    def __init__(self, data):
        self.salt = "".join(struct.unpack("4c", data))

    ##
    # Sends "md5" + md5(md5(password + user) + salt) as a PasswordMessage and
    # interprets the reply.  Another AuthenticationRequest is delegated to
    # via its own ok(); anything else raises.
    #
    # @raise InterfaceError  No password was given, or the server answered
    #                        with error code "28000".
    # @raise InternalError   Any other ErrorResponse or unexpected message.
    def ok(self, conn, user, password=None, **kwargs):
        if password is None:
            raise InterfaceError("server requesting MD5 password authentication, but no password was provided")
        inner_digest = md5.new(password + user).hexdigest()
        pwd = "md5" + md5.new(inner_digest + self.salt).hexdigest()
        conn._send(Protocol.PasswordMessage(pwd))
        msg = conn._read_message()
        if isinstance(msg, Protocol.AuthenticationRequest):
            return msg.ok(conn, user)
        if not isinstance(msg, Protocol.ErrorResponse):
            raise InternalError("server returned unexpected response %r" % msg)
        if msg.code == "28000":
            raise InterfaceError("md5 password authentication failed")
        raise InternalError("server returned unexpected error %r" % msg)
936
# Lookup table consulted by AuthenticationRequest.createFromData: maps the
# integer identifier at the start of an authentication message to the class
# that handles it.  Identifiers missing here are reported as
# NotSupportedError.
authentication_codes = dict([
    (0, AuthenticationOk),
    (5, AuthenticationMD5Password),
])
941
class ParameterStatus(object):
    ##
    # @param key    Parameter name.
    # @param value  Parameter value.
    def __init__(self, key, value):
        self.key = key
        self.value = value

    ##
    # Splits data at its first NUL: the key is everything before it, the
    # value is everything after it except the final character of data.
    @staticmethod
    def createFromData(data):
        sep = data.find("\x00")
        return Protocol.ParameterStatus(data[:sep], data[sep + 1:-1])
952
class BackendKeyData(object):
    ##
    # @param process_id  First integer of the message.
    # @param secret_key  Second integer of the message.
    def __init__(self, process_id, secret_key):
        self.process_id = process_id
        self.secret_key = secret_key

    ##
    # Unpacks data as exactly two network-order 32-bit integers and wraps
    # them in a BackendKeyData.
    @staticmethod
    def createFromData(data):
        fields = struct.unpack("!2i", data)
        return Protocol.BackendKeyData(fields[0], fields[1])
962
class NoData(object):
    ##
    # Returns a new NoData marker; the payload is ignored.
    @staticmethod
    def createFromData(data):
        return Protocol.NoData()
967
class ParseComplete(object):
    ##
    # Returns a new ParseComplete marker; the payload is ignored.
    @staticmethod
    def createFromData(data):
        return Protocol.ParseComplete()
972
class BindComplete(object):
    ##
    # Returns a new BindComplete marker; the payload is ignored.
    @staticmethod
    def createFromData(data):
        return Protocol.BindComplete()
977
class CloseComplete(object):
    ##
    # Returns a new CloseComplete marker; the payload is ignored.
    @staticmethod
    def createFromData(data):
        return Protocol.CloseComplete()
982
class PortalSuspended(object):
    ##
    # Returns a new PortalSuspended marker; the payload is ignored.
    @staticmethod
    def createFromData(data):
        return Protocol.PortalSuspended()
987
class ReadyForQuery(object):
    # Human-readable names for the status codes that __repr__ knows about.
    _STATUS_NAMES = {
        "I": "Idle",
        "T": "Idle in Transaction",
        "E": "Idle in Failed Transaction",
    }

    ##
    # @param status  Status value carried by the message; stored unchanged.
    def __init__(self, status):
        self.status = status

    ##
    # Describes the status by name.  A status outside the known set is shown
    # as "Unknown (<repr>)" rather than raising KeyError, so that logging or
    # debugging an unexpected message never fails inside repr().
    def __repr__(self):
        name = self._STATUS_NAMES.get(self.status, "Unknown (%r)" % (self.status,))
        return "<ReadyForQuery %s>" % name

    ##
    # Wraps the whole payload as the status value.
    def createFromData(data):
        return Protocol.ReadyForQuery(data)
    createFromData = staticmethod(createFromData)
999
class NoticeResponse(object):
    def __init__(self):
        pass

    ##
    # Returns an empty NoticeResponse.  The notice text in data could be
    # parsed here, but nothing uses it yet, so it is discarded.
    @staticmethod
    def createFromData(data):
        return Protocol.NoticeResponse()
1007
1008     class ErrorResponse(object):
1009         def __init__(self, severity, code, msg):
1010             self.severity = severity
1011             self.code = code
1012             self.msg = msg
1013
1014         def __repr__(self):
1015             return "<ErrorResponse %s %s %r>" % (self.severity, self.code, self.msg)
1016
1017         def createException(self):
1018             return ProgrammingError(self.severity, self.code, self.msg)
1019
1020         def createFromData(data):
1021             args = {}
1022             for s in data.split("\x00"):
1023                 if not s:
1024                     continue
1025                 elif s[0] == "S":
1026                     args["severity"] = s[1:]
1027                 elif s[0] == "C":
1028                     args["code"] = s[1:]
1029         &nbs