Changeset 812

Show
Ignore:
Timestamp:
03/08/07 22:49:41 (2 years ago)
Author:
mfenniak
Message:

Complete prepared statement work

Files:

Legend:

Unmodified
Added
Removed
Modified
Copied
Moved
  • pg8000/trunk/pg8000.py

    r811 r812  
    6868 
    6969class DataIterator(object): 
    70     def __init__(self, connection): 
    71         self.connection = connection 
    72         if self.connection.iterate_dicts: 
    73             self.method = PreparedStatement.read_dict 
    74         else: 
    75             self.method = PreparedStatement.read_tuple 
     70    def __init__(self, obj, func): 
     71        self.obj = obj 
     72        self.func = func 
    7673 
    7774    def __iter__(self): 
     
    7976 
    8077    def next(self): 
    81         retval = self.method(self.connection
     78        retval = self.func(self.obj
    8279        if retval == None: 
    8380            raise StopIteration() 
    8481        return retval 
    8582 
     83## 
     84# This class represents a prepared statement.  A prepared statement is 
     85# pre-parsed on the server, which reduces the need to parse the query every 
     86# time it is run.  The statement can have parameters in the form of $1, $2, $3, 
     87# etc.  When parameters are used, the types of the parameters need to be 
     88# specified when creating the prepared statement. 
     89# <p> 
     90# Stability: Added in v1.00, stability guaranteed for v1.xx. 
     91# 
     92# @param connection     An instance of {@link Connection Connection}. 
     93# 
     94# @param statement      The SQL statement to be represented, often containing 
     95# parameters in the form of $1, $2, $3, etc. 
     96# 
     97# @param types          Python type objects for each parameter in the SQL 
     98# statement.  For example, int, float, str. 
    8699class PreparedStatement(object): 
    87     ## 
    88     # A configuration variable that determines whether iterating over the 
    89     # connection will return tuples of queried rows (False), or dictionaries 
    90     # indexed by column name/alias (True).  By default, this variable value is 
    91     # copied from the connection's iterate_dicts value. 
    92     # <p> 
    93     # Stability: Added in v1.00, stability guaranteed for v1.xx. 
    94     iterate_dicts = False 
    95  
     100 
     101    ## 
     102    # Determines the number of rows to read from the database server at once. 
     103    # Reading more rows increases performance at the cost of memory.  The 
     104    # default value is 100 rows.  The affect of this parameter is transparent. 
     105    # That is, the library reads more rows when the cache is empty 
     106    # automatically. 
     107    # <p> 
     108    # Stability: Added in v1.00, stability guaranteed for v1.xx.  It is 
     109    # possible that implementation changes in the future could cause this 
     110    # parameter to be ignored.O 
    96111    row_cache_size = 100 
    97112 
     
    105120        self._parse_row_desc = self.c.parse(self._statement_name, statement, types) 
    106121 
     122    def __del__(self): 
     123        # This __del__ should work with garbage collection / non-instant 
     124        # cleanup.  It only really needs to be called right away if the same 
     125        # object id (and therefore the same statement name) might be reused 
     126        # soon, and clearly that wouldn't happen in a GC situation. 
     127        self.c.close_statement(self._statement_name) 
     128 
     129    ## 
     130    # Run the SQL prepared statement with the given parameters. 
     131    # <p> 
     132    # Stability: Added in v1.00, stability guaranteed for v1.xx. 
    107133    def execute(self, *args): 
    108134        if not self._command_complete: 
    109135            # cleanup last execute 
    110136            self._cached_rows = [] 
    111             self.c.close(self._portal_name) 
     137            self.c.close_portal(self._portal_name) 
    112138        self._command_complete = False 
    113139        self._row_desc = self.c.bind(self._portal_name, self._statement_name, args, self._parse_row_desc) 
     140        if self._row_desc: 
     141            # We execute our cursor right away to fill up our cache.  This 
     142            # prevents the cursor from being destroyed, apparently, by a rogue 
     143            # Sync between Bind and Execute.  Since it is quite likely that 
     144            # data will be read from us right away anyways, this seems a safe 
     145            # move for now. 
     146            self._fill_cache() 
     147 
     148    def _fill_cache(self): 
     149        if self._cached_rows: 
     150            raise InternalError("attempt to fill cache that isn't empty") 
     151        end_of_data, rows = self.c.fetch_rows(self._portal_name, self.row_cache_size, self._row_desc) 
     152        self._cached_rows = rows 
     153        if end_of_data: 
     154            self._command_complete = True 
    114155 
    115156    def _fetch(self): 
     
    117158            if self._command_complete: 
    118159                return None 
    119             end_of_data, rows = self.c.fetch_rows(self._portal_name, self.row_cache_size, self._row_desc) 
    120             self._cached_rows = rows 
    121             if end_of_data: 
    122                 self._command_complete = True 
    123                 if not rows: 
    124                     # special case - an empty query, hit end_of_data and no 
    125                     # rows at the same time 
    126                     return None 
     160            self._fill_cache() 
     161            if self._command_complete and not self._cached_rows: 
     162                # fill cache tells us the command is complete, but yet we have 
     163                # no rows after filling our cache.  This is a special case when 
     164                # a query returns no rows. 
     165                return None 
    127166        row = self._cached_rows[0] 
    128167        del self._cached_rows[0] 
     
    132171    # Read a row from the database server, and return it in a dictionary 
    133172    # indexed by column name/alias.  This method will raise an error if two 
    134     # columns have the same name. 
     173    # columns have the same name.  Returns None after the last row. 
    135174    # <p> 
    136175    # Stability: Added in v1.00, stability guaranteed for v1.xx. 
     
    149188    ## 
    150189    # Read a row from the database server, and return it as a tuple of values. 
     190    # Returns None after the last row. 
    151191    # <p> 
    152192    # Stability: Added in v1.00, stability guaranteed for v1.xx. 
     
    158198 
    159199    ## 
    160     # Iterate over query results.  The behaviour of iterating over this object 
    161     # is dependent upon the value of the {@link #Connection.iterate_dicts 
    162     # iterate_dicts} variable. 
    163     # <p> 
    164     # Stability: Added in v1.00, stability guaranteed for v1.xx. 
    165     def __iter__(self): 
    166         return DataIterator(self) 
    167  
    168  
     200    # Return an iterator for the output of this statement.  The iterator will 
     201    # return a tuple for each row, in the same manner as {@link 
     202    # #PreparedStatement.read_tuple read_tuple}. 
     203    # <p> 
     204    # Stability: Added in v1.00, stability guaranteed for v1.xx. 
     205    def iterate_tuple(self): 
     206        return DataIterator(self, PreparedStatement.read_tuple) 
     207 
     208    ## 
     209    # Return an iterator for the output of this statement.  The iterator will 
     210    # return a dict for each row, in the same manner as {@link 
     211    # #PreparedStatement.read_dict read_dict}. 
     212    # <p> 
     213    # Stability: Added in v1.00, stability guaranteed for v1.xx. 
     214    def iterate_dict(self): 
     215        return DataIterator(self, PreparedStatement.read_dict) 
     216 
     217## 
     218# The Cursor class allows multiple queries to be performed concurrently with a 
     219# single PostgreSQL connection.  The Cursor object is implemented internally by 
     220# using a {@link PreparedStatement PreparedStatement} object, so if you plan to 
     221# use a statement multiple times, you might as well create a PreparedStatement 
     222# and save a small amount of reparsing time. 
     223# <p> 
     224# Stability: Added in v1.00, stability guaranteed for v1.xx. 
     225
     226# @param connection     An instance of {@link Connection Connection}. 
    169227class Cursor(object): 
    170228    def __init__(self, connection): 
     
    172230        self._stmt = None 
    173231 
     232    ## 
     233    # Run an SQL statement using this cursor.  The SQL statement can have 
     234    # parameters in the form of $1, $2, $3, etc., which will be filled in by 
     235    # the additional arguments passed to this function. 
     236    # <p> 
     237    # Stability: Added in v1.00, stability guaranteed for v1.xx. 
     238    # @param query      The SQL statement to execute. 
    174239    def execute(self, query, *args): 
    175240        self._stmt = PreparedStatement(self.connection, query, *[type(x) for x in args]) 
    176241        self._stmt.execute(*args) 
    177242 
     243    ## 
     244    # Read a row from the database server, and return it in a dictionary 
     245    # indexed by column name/alias.  This method will raise an error if two 
     246    # columns have the same name.  Returns None after the last row. 
     247    # <p> 
     248    # Stability: Added in v1.00, stability guaranteed for v1.xx. 
    178249    def read_dict(self): 
    179250        if self._stmt == None: 
    180             return None 
     251            raise ProgrammingError("attempting to read from unexecuted cursor") 
    181252        return self._stmt.read_dict() 
    182253 
     254    ## 
     255    # Read a row from the database server, and return it as a tuple of values. 
     256    # Returns None after the last row. 
     257    # <p> 
     258    # Stability: Added in v1.00, stability guaranteed for v1.xx. 
    183259    def read_tuple(self): 
    184260        if self._stmt == None: 
    185             return None 
     261            raise ProgrammingError("attempting to read from unexecuted cursor") 
    186262        return self._stmt.read_tuple() 
    187263 
     264    ## 
     265    # Return an iterator for the output of this statement.  The iterator will 
     266    # return a tuple for each row, in the same manner as {@link 
     267    # #PreparedStatement.read_tuple read_tuple}. 
     268    # <p> 
     269    # Stability: Added in v1.00, stability guaranteed for v1.xx. 
     270    def iterate_tuple(self): 
     271        if self._stmt == None: 
     272            raise ProgrammingError("attempting to read from unexecuted cursor") 
     273        return self._stmt.iterate_tuple() 
     274 
     275    ## 
     276    # Return an iterator for the output of this statement.  The iterator will 
     277    # return a dict for each row, in the same manner as {@link 
     278    # #PreparedStatement.read_dict read_dict}. 
     279    # <p> 
     280    # Stability: Added in v1.00, stability guaranteed for v1.xx. 
     281    def iterate_dict(self): 
     282        if self._stmt == None: 
     283            raise ProgrammingError("attempting to read from unexecuted cursor") 
     284        return self._stmt.iterate_dict() 
    188285 
    189286## 
    190287# This class represents a connection to a PostgreSQL database. 
    191288# <p> 
    192 # The database connection is derived from the {@link #Cursor Cursor} class, and 
    193 # provides access to the database's unnamed cursor through the standard Cursor 
    194 # methods.  It also provides transaction control via the 'begin', 'commit', and 
    195 # 'rollback' methods.  Without beginning a transaction explicitly, all 
    196 # statements will autocommit to the database. 
     289# The database connection is derived from the {@link #Cursor Cursor} class, 
     290# which provides a default cursor for running queries.  It also provides 
     291# transaction control via the 'begin', 'commit', and 'rollback' methods. 
     292# Without beginning a transaction explicitly, all statements will autocommit to 
     293# the database. 
    197294# <p> 
    198295# Stability: Added in v1.00, stability guaranteed for v1.xx. 
     
    221318# Defaults to 60 seconds. 
    222319class Connection(Cursor): 
    223  
    224     ## 
    225     # A configuration variable that determines whether iterating over the 
    226     # connection will return tuples of queried rows (False), or dictionaries 
    227     # indexed by column name/alias (True).  By default, this variable is set to 
    228     # False. 
    229     # <p> 
    230     # Stability: Added in v1.00, stability guaranteed for v1.xx. 
    231     iterate_dicts = False 
    232  
    233320    def __init__(self, host, user, port=5432, database=None, password=None, socket_timeout=60): 
    234321        self._row_desc = None 
     
    240327            raise InterfaceError("communication error", e) 
    241328        Cursor.__init__(self, self) 
    242  
     329        self._begin = PreparedStatement(self, "BEGIN TRANSACTION") 
     330        self._commit = PreparedStatement(self, "COMMIT TRANSACTION") 
     331        self._rollback = PreparedStatement(self, "ROLLBACK TRANSACTION") 
     332 
     333    ## 
     334    # Begins a new transaction. 
     335    # <p> 
     336    # Stability: Added in v1.00, stability guaranteed for v1.xx. 
    243337    def begin(self): 
    244         raise NotSupportedError("uncoded") 
    245  
     338        self._begin.execute() 
     339 
     340    ## 
     341    # Commits the running transaction. 
     342    # <p> 
     343    # Stability: Added in v1.00, stability guaranteed for v1.xx. 
    246344    def commit(self): 
    247         raise NotSupportedError("uncoded") 
    248  
     345        self._commit.execute() 
     346 
     347    ## 
     348    # Rolls back the running transaction. 
     349    # <p> 
     350    # Stability: Added in v1.00, stability guaranteed for v1.xx. 
    249351    def rollback(self): 
    250         raise NotSupportedError("uncoded"
     352        self._rollback.execute(
    251353 
    252354 
     
    600702 
    601703        def _send(self, msg): 
    602             self._sock.send(msg.serialize()) 
     704            #print repr(msg) 
     705            data = msg.serialize() 
     706            self._sock.send(data) 
    603707 
    604708        def _read_message(self): 
     
    607711            message_code = bytes[0] 
    608712            data_len = struct.unpack("!i", bytes[1:])[0] - 4 
    609             bytes = self._sock.recv(data_len) 
     713            if data_len == 0: 
     714                bytes = "" 
     715            else: 
     716                bytes = self._sock.recv(data_len) 
    610717            msg = Protocol.message_types[message_code].createFromData(bytes) 
    611718            if isinstance(msg, Protocol.NoticeResponse): 
     
    687794                output_fc = [Types.py_type_info(f) for f in row_desc.fields] 
    688795            self._send(Protocol.Bind(portal, statement, param_fc, params, output_fc, self._client_encoding)) 
    689             # I don't know why we need to send DescribePortal again, but without it, 
    690             # we don't receive our BindComplete.  It's like Flush fails to work. 
     796            # We need to describe the portal after bind, since the return 
     797            # format codes will be different (hopefully, always what we 
     798            # requested). 
    691799            self._send(Protocol.DescribePortal(portal)) 
    692800            self._send(Protocol.Flush()) 
     
    762870            return end_of_data, rows 
    763871 
    764         def close(self, portal): 
     872        def close_statement(self, statement): 
     873            self._send(Protocol.ClosePreparedStatement(statement)) 
     874            self._send(Protocol.Sync()) 
     875            while 1: 
     876                msg = self._read_message() 
     877                if isinstance(msg, Protocol.CloseComplete): 
     878                    # thanks! 
     879                    pass 
     880                elif isinstance(msg, Protocol.ReadyForQuery): 
     881                    return 
     882                elif isinstance(msg, Protocol.ErrorResponse): 
     883                    raise msg.createException() 
     884                else: 
     885                    raise InternalError("Unexpected response msg %r" % msg) 
     886 
     887        def close_portal(self, portal): 
    765888            self._send(Protocol.ClosePortal(portal)) 
    766889            self._send(Protocol.Sync()) 
  • pg8000/trunk/pg8000-test.py

    r811 r812  
    88#db = pg8000.Connection(host='joy.fenniak.net', user='Mathieu Fenniak', database="software", password="hello", socket_timeout=5) 
    99db = pg8000.Connection(host='localhost', user='mfenniak') 
    10 db.iterate_dicts = True 
    1110 
    12 #s1 = pg8000.PreparedStatement(db, "INSERT INTO t1 (f1, f2, f3) VALUES ($1, $2, $3)", int, int, str) 
    13 s1 = pg8000.PreparedStatement(db, "SELECT * FROM t1 WHERE f1 = $1", int) 
    14 s1.execute(5) 
    15 for row in s1: 
    16     print repr(row) 
    17 s1.execute(2) 
    18 for row in s1: 
    19     print repr(row) 
     11db.execute("DROP TABLE t1") 
     12db.execute("CREATE TABLE t1 (f1 int primary key, f2 int not null, f3 varchar(50) not null)") 
    2013 
    21  
    22 import sys 
    23 sys.exit(0) 
    24  
    25 cur1 = pg8000.Cursor(db) 
    26  
    27 cur1.execute("DROP TABLE t1") 
    28 cur1.execute("CREATE TABLE t1 (f1 int primary key, f2 int not null, f3 varchar(50) not null)") 
    29 cur1.execute("INSERT INTO t1 (f1, f2, f3) VALUES ($1, $2, $3)", 1, 1, "hello") 
    30 cur1.execute("INSERT INTO t1 (f1, f2, f3) VALUES ($1, $2, $3)", 2, 10, u"he\u0173llo") 
    31 cur1.execute("INSERT INTO t1 (f1, f2, f3) VALUES ($1, $2, $3)", 3, 100, "hello") 
    32 cur1.execute("INSERT INTO t1 (f1, f2, f3) VALUES ($1, $2, $3)", 4, 1000, "hello") 
    33 cur1.execute("INSERT INTO t1 (f1, f2, f3) VALUES ($1, $2, $3)", 5, 10000, "hello") 
     14s1 = pg8000.PreparedStatement(db, "INSERT INTO t1 (f1, f2, f3) VALUES ($1, $2, $3)", int, int, str) 
     15s1.execute(1, 1, "hello") 
     16s1.execute(2, 10, "he\u0173llo") 
     17s1.execute(3, 100, "hello") 
     18s1.execute(4, 1000, "hello") 
     19s1.execute(5, 10000, "hello") 
     20s1.execute(6, 100000, "hello") 
    3421 
    3522print "begin query..." 
     23db.execute("SELECT * FROM t1") 
     24for row in db.iterate_dict(): 
     25    print repr(row) 
     26print "end query..." 
     27 
     28print "begin query..." 
     29cur1 = pg8000.Cursor(db) 
    3630cur1.execute("SELECT * FROM t1") 
     31s1 = pg8000.PreparedStatement(db, "SELECT * FROM t1 WHERE f1 > $1", int) 
    3732i = 0 
    38 for row1 in cur1
     33for row1 in cur1.iterate_dict()
    3934    i = i + 1 
    4035    print i, repr(row1) 
    41     db.execute("SELECT * FROM t1 WHERE f1 > $1", row1['f1']) 
    42     for row2 in db
     36    s1.execute(row1['f1']) 
     37    for row2 in s1.iterate_dict()
    4338        print "\t", repr(row2) 
    4439print "end query..." 
     
    4742 
    4843cur1.execute("SELECT $1", 5) 
    49 assert tuple(cur1) == ({"?column?": 5},) 
     44assert tuple(cur1.iterate_dict()) == ({"?column?": 5},) 
    5045 
    5146cur1.execute("SELECT 5000::smallint") 
    52 assert tuple(cur1) == ({"int2": 5000},) 
     47assert tuple(cur1.iterate_dict()) == ({"int2": 5000},) 
    5348 
    5449cur1.execute("SELECT 5000::integer") 
    55 assert tuple(cur1) == ({"int4": 5000},) 
     50assert tuple(cur1.iterate_dict()) == ({"int4": 5000},) 
    5651 
    5752cur1.execute("SELECT 50000000000000::bigint") 
    58 assert tuple(cur1) == ({"int8": 50000000000000},) 
     53assert tuple(cur1.iterate_dict()) == ({"int8": 50000000000000},) 
    5954 
    6055cur1.execute("SELECT 5000.023232::decimal") 
    61 assert tuple(cur1) == ({"numeric": decimal.Decimal("5000.023232")},) 
     56assert tuple(cur1.iterate_dict()) == ({"numeric": decimal.Decimal("5000.023232")},) 
    6257 
    6358cur1.execute("SELECT 1.1::real") 
    64 assert tuple(cur1) == ({"float4": 1.1000000238418579},) 
     59assert tuple(cur1.iterate_dict()) == ({"float4": 1.1000000238418579},) 
    6560 
    6661cur1.execute("SELECT 1.1::double precision") 
    67 assert tuple(cur1) == ({"float8": 1.1000000000000001},) 
     62assert tuple(cur1.iterate_dict()) == ({"float8": 1.1000000000000001},) 
    6863 
    6964cur1.execute("SELECT 'hello'::varchar(50)") 
    70 assert tuple(cur1) == ({"varchar": u"hello"},) 
     65assert tuple(cur1.iterate_dict()) == ({"varchar": u"hello"},) 
    7166 
    7267cur1.execute("SELECT 'hello'::char(20)") 
    73 assert tuple(cur1) == ({"bpchar": u"hello               "},) 
     68assert tuple(cur1.iterate_dict()) == ({"bpchar": u"hello               "},) 
    7469 
    7570cur1.execute("SELECT 'hello'::text") 
    76 assert tuple(cur1) == ({"text": u"hello"},) 
     71assert tuple(cur1.iterate_dict()) == ({"text": u"hello"},) 
    7772 
    7873#cur1.execute("SELECT 'hell\007o'::bytea") 
    79 #assert tuple(cur1) == ({"bytea": "hello"},) 
     74#assert tuple(cur1.iterate_dict()) == ({"bytea": "hello"},) 
    8075 
    8176cur1.execute("SELECT '2001-02-03 04:05:06.17'::timestamp") 
    82 assert tuple(cur1) == ({'timestamp': datetime.datetime(2001, 2, 3, 4, 5, 6, 170000)},) 
     77assert tuple(cur1.iterate_dict()) == ({'timestamp': datetime.datetime(2001, 2, 3, 4, 5, 6, 170000)},) 
    8378 
    8479#cur1.execute("SELECT '2001-02-03 04:05:06.17'::timestamp with time zone") 
    85 #assert tuple(cur1) == ({'timestamp': datetime.datetime(2001, 2, 3, 4, 5, 6, 170000, pg8000.Types.FixedOffsetTz("-07"))},) 
     80#assert tuple(cur1.iterate_dict()) == ({'timestamp': datetime.datetime(2001, 2, 3, 4, 5, 6, 170000, pg8000.Types.FixedOffsetTz("-07"))},) 
    8681 
    8782cur1.execute("SELECT '1 month'::interval") 
    88 assert tuple(cur1) == ({'interval': '1 mon'},) 
    89 #print repr(tuple(cur1)) 
     83assert tuple(cur1.iterate_dict()) == ({'interval': '1 mon'},) 
     84#print repr(tuple(cur1.iterate_dict())) 
    9085 
    9186print "Type checks complete."