| 1 |
# vim: sw=4:expandtab:foldmethod=marker |
|---|
| 2 |
# |
|---|
| 3 |
# Copyright (c) 2007, Mathieu Fenniak |
|---|
| 4 |
# All rights reserved. |
|---|
| 5 |
# |
|---|
| 6 |
# Redistribution and use in source and binary forms, with or without |
|---|
| 7 |
# modification, are permitted provided that the following conditions are |
|---|
| 8 |
# met: |
|---|
| 9 |
# |
|---|
| 10 |
# * Redistributions of source code must retain the above copyright notice, |
|---|
| 11 |
# this list of conditions and the following disclaimer. |
|---|
| 12 |
# * Redistributions in binary form must reproduce the above copyright notice, |
|---|
| 13 |
# this list of conditions and the following disclaimer in the documentation |
|---|
| 14 |
# and/or other materials provided with the distribution. |
|---|
| 15 |
# * The name of the author may not be used to endorse or promote products |
|---|
| 16 |
# derived from this software without specific prior written permission. |
|---|
| 17 |
# |
|---|
| 18 |
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
|---|
| 19 |
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
|---|
| 20 |
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE |
|---|
| 21 |
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE |
|---|
| 22 |
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR |
|---|
| 23 |
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF |
|---|
| 24 |
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS |
|---|
| 25 |
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN |
|---|
| 26 |
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) |
|---|
| 27 |
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE |
|---|
| 28 |
# POSSIBILITY OF SUCH DAMAGE. |
|---|
| 29 |
|
|---|
| 30 |
__author__ = "Mathieu Fenniak" |
|---|
| 31 |
|
|---|
| 32 |
import socket |
|---|
| 33 |
import struct |
|---|
| 34 |
import datetime |
|---|
| 35 |
import md5 |
|---|
| 36 |
import decimal |
|---|
| 37 |
|
|---|
| 38 |
class Warning(StandardError): |
|---|
| 39 |
pass |
|---|
| 40 |
|
|---|
| 41 |
class Error(StandardError): |
|---|
| 42 |
pass |
|---|
| 43 |
|
|---|
| 44 |
class InterfaceError(Error): |
|---|
| 45 |
pass |
|---|
| 46 |
|
|---|
| 47 |
class DatabaseError(Error): |
|---|
| 48 |
pass |
|---|
| 49 |
|
|---|
| 50 |
class DataError(DatabaseError): |
|---|
| 51 |
pass |
|---|
| 52 |
|
|---|
| 53 |
class OperationalError(DatabaseError): |
|---|
| 54 |
pass |
|---|
| 55 |
|
|---|
| 56 |
class IntegrityError(DatabaseError): |
|---|
| 57 |
pass |
|---|
| 58 |
|
|---|
| 59 |
class InternalError(DatabaseError): |
|---|
| 60 |
pass |
|---|
| 61 |
|
|---|
| 62 |
class ProgrammingError(DatabaseError): |
|---|
| 63 |
pass |
|---|
| 64 |
|
|---|
| 65 |
class NotSupportedError(DatabaseError): |
|---|
| 66 |
pass |
|---|
| 67 |
|
|---|
| 68 |
|
|---|
| 69 |
class DataIterator(object):
    """Adapt a "read one row" function into an iterator.

    Each step calls ``func(obj)``; a return value of ``None`` ends the
    iteration.  Any other value, including falsy ones such as ``0`` or an
    empty tuple, is yielded unchanged.
    """

    def __init__(self, obj, func):
        self.obj = obj
        self.func = func

    def __iter__(self):
        return self

    def next(self):
        retval = self.func(self.obj)
        # Identity check: only the None sentinel terminates iteration.
        if retval is None:
            raise StopIteration()
        return retval

    # Python 3 names the iterator protocol method ``__next__``; alias it so
    # the class works as an iterator on both Python 2 and Python 3.
    __next__ = next
|---|
| 82 |
|
|---|
| 83 |
##
# This class represents a prepared statement.  A prepared statement is
# pre-parsed on the server, which reduces the need to parse the query every
# time it is run.  The statement can have parameters in the form of $1, $2, $3,
# etc.  When parameters are used, the types of the parameters need to be
# specified when creating the prepared statement.
# <p>
# Stability: Added in v1.00, stability guaranteed for v1.xx.
#
# @param connection An instance of {@link Connection Connection}.
#
# @param statement The SQL statement to be represented, often containing
# parameters in the form of $1, $2, $3, etc.
#
# @param types Python type objects for each parameter in the SQL
# statement.  For example, int, float, str.
class PreparedStatement(object):

    ##
    # Determines the number of rows to read from the database server at once.
    # Reading more rows increases performance at the cost of memory.  The
    # default value is 100 rows.  The effect of this parameter is transparent:
    # the library automatically reads more rows whenever its cache is empty.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.  It is
    # possible that implementation changes in the future could cause this
    # parameter to be ignored.
    row_cache_size = 100

    def __init__(self, connection, statement, *types):
        self.c = connection.c
        # Derived from id()s, so these names are only unique among objects
        # that are alive at the same time.
        self._portal_name = "pg8000_portal_%s_%s" % (id(self.c), id(self))
        self._statement_name = "pg8000_statement_%s_%s" % (id(self.c), id(self))
        self._row_desc = None
        # Rows already fetched from the server but not yet returned.
        self._cached_rows = []
        # False while the current portal may still have rows to fetch.
        self._command_complete = True
        self._parse_row_desc = self.c.parse(self._statement_name, statement, types)

    def __del__(self):
        # This __del__ should work with garbage collection / non-instant
        # cleanup.  It only really needs to be called right away if the same
        # object id (and therefore the same statement name) might be reused
        # soon, and clearly that wouldn't happen in a GC situation.
        #
        # If __init__ failed before self.c was assigned there is nothing to
        # close; avoid an AttributeError inside the finalizer.
        c = getattr(self, "c", None)
        if c is not None:
            c.close_statement(self._statement_name)

    ##
    # Run the SQL prepared statement with the given parameters.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def execute(self, *args):
        if not self._command_complete:
            # A previous execution still has unread rows: discard them and
            # close its portal before binding a new one.
            self._cached_rows = []
            self.c.close_portal(self._portal_name)
        self._command_complete = False
        self._row_desc = self.c.bind(self._portal_name, self._statement_name, args, self._parse_row_desc)
        if self._row_desc:
            # We execute our cursor right away to fill up our cache.  This
            # prevents the cursor from being destroyed, apparently, by a rogue
            # Sync between Bind and Execute.  Since it is quite likely that
            # data will be read from us right away anyways, this seems a safe
            # move for now.
            self._fill_cache()
        else:
            # No row description: the statement produces no rows, so there is
            # nothing left to fetch.  Without this, a later read would call
            # fetch_rows() with a None row description, and the next execute()
            # would try to close a portal that has nothing pending.
            self._command_complete = True

    # Replace the (empty) row cache with up to row_cache_size rows from the
    # current portal, and note when the server reports the end of the data.
    def _fill_cache(self):
        if self._cached_rows:
            raise InternalError("attempt to fill cache that isn't empty")
        end_of_data, rows = self.c.fetch_rows(self._portal_name, self.row_cache_size, self._row_desc)
        self._cached_rows = rows
        if end_of_data:
            self._command_complete = True

    # Return the next row as a tuple, refilling the cache as needed, or None
    # once every row has been consumed.
    def _fetch(self):
        if not self._cached_rows:
            if self._command_complete:
                return None
            self._fill_cache()
            if self._command_complete and not self._cached_rows:
                # fill cache tells us the command is complete, but yet we have
                # no rows after filling our cache.  This is a special case when
                # a query returns no rows.
                return None
        return tuple(self._cached_rows.pop(0))

    ##
    # Read a row from the database server, and return it in a dictionary
    # indexed by column name/alias.  This method will raise an error if two
    # columns have the same name.  Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_dict(self):
        row = self._fetch()
        if row is None:
            return None
        retval = {}
        for i, field in enumerate(self._row_desc.fields):
            col_name = field['name']
            if col_name in retval:
                raise InterfaceError("cannot return dict of row when two columns have the same name (%r)" % (col_name,))
            retval[col_name] = row[i]
        return retval

    ##
    # Read a row from the database server, and return it as a tuple of values.
    # Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_tuple(self):
        return self._fetch()

    ##
    # Return an iterator for the output of this statement.  The iterator will
    # return a tuple for each row, in the same manner as {@link
    # #PreparedStatement.read_tuple read_tuple}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_tuple(self):
        return DataIterator(self, PreparedStatement.read_tuple)

    ##
    # Return an iterator for the output of this statement.  The iterator will
    # return a dict for each row, in the same manner as {@link
    # #PreparedStatement.read_dict read_dict}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_dict(self):
        return DataIterator(self, PreparedStatement.read_dict)
|---|
| 216 |
|
|---|
| 217 |
##
# A Cursor runs queries over a PostgreSQL connection; several cursors let
# multiple queries be in progress concurrently on one connection.  Each call
# to execute() builds a fresh {@link PreparedStatement PreparedStatement}, so
# for a statement you run many times it is slightly cheaper to create a
# PreparedStatement yourself and reuse it.
# <p>
# Stability: Added in v1.00, stability guaranteed for v1.xx.
#
# @param connection An instance of {@link Connection Connection}.
class Cursor(object):
    def __init__(self, connection):
        self.connection = connection
        # Statement created by the most recent execute(); None until then.
        self._stmt = None

    # Return the statement created by the last execute(), or raise
    # ProgrammingError if execute() has not been called yet.
    def _executed_statement(self):
        stmt = self._stmt
        if stmt is None:
            raise ProgrammingError("attempting to read from unexecuted cursor")
        return stmt

    ##
    # Run an SQL statement on this cursor.  Parameters are written as $1, $2,
    # $3, etc. in the query and supplied as extra positional arguments; the
    # Python type of each argument is used as that parameter's type.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    # @param query The SQL statement to execute.
    def execute(self, query, *args):
        param_types = [type(arg) for arg in args]
        self._stmt = PreparedStatement(self.connection, query, *param_types)
        self._stmt.execute(*args)

    ##
    # Fetch the next row as a dict keyed by column name/alias.  Raises an
    # error when two columns share a name.  Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_dict(self):
        return self._executed_statement().read_dict()

    ##
    # Fetch the next row as a tuple of values.  Returns None after the last
    # row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_tuple(self):
        return self._executed_statement().read_tuple()

    ##
    # Iterate over the remaining rows, producing one tuple per row just like
    # {@link #PreparedStatement.read_tuple read_tuple}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_tuple(self):
        return self._executed_statement().iterate_tuple()

    ##
    # Iterate over the remaining rows, producing one dict per row just like
    # {@link #PreparedStatement.read_dict read_dict}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_dict(self):
        return self._executed_statement().iterate_dict()
|---|
| 285 |
|
|---|
| 286 |
## |
|---|
| 287 |
# This class represents a connection to a PostgreSQL database. |
|---|
| 288 |
# <p> |
|---|
| 289 |
# The database connection is derived from the {@link #Cursor Cursor} class, |
|---|
| 290 |
# which provides a default cursor for running queries. It also provides |
|---|
| 291 |
# transaction control via the 'begin', 'commit', and 'rollback' methods. |
|---|
| 292 |
# Without beginning a transaction explicitly, all statements will autocommit to |
|---|
| 293 |
# the database. |
|---|
| 294 |
# <p> |
|---|
| 295 |
# Stability: Added in v1.00, stability guaranteed for v1.xx. |
|---|
| 296 |
# |
|---|
| 297 |
# @param host The hostname of the PostgreSQL server to connect with. Only |
|---|
| 298 |
# TCP/IP connections are presently supported, so this parameter is mandatory. |
|---|
| 299 |
# |
|---|
| 300 |
# @param user The username to connect to the PostgreSQL server with. This |
|---|
| 301 |
# parameter is mandatory. |
|---|
| 302 |
# |
|---|
| 303 |
# @keyparam port The TCP/IP port of the PostgreSQL server instance. This |
|---|
| 304 |
# parameter defaults to 5432, the registered and common port of PostgreSQL |
|---|
| 305 |
# TCP/IP servers. |
|---|
| 306 |
# |
|---|
| 307 |
# @keyparam database The name of the database instance to connect with. This |
|---|
| 308 |
# parameter is optional, if omitted the PostgreSQL server will assume the |
|---|
| 309 |
# database name is the same as the username. |
|---|
| 310 |
# |
|---|
| 311 |
# @keyparam password The user password to connect to the server with. This |
|---|
| 312 |
# parameter is optional. If omitted, and the database server requests password |
|---|
| 313 |
# based authentication, the connection will fail. On the other hand, if this |
|---|
| 314 |
# parameter is provided and the database does not request password |
|---|
| 315 |
# authentication, then the password will not be used. |
|---|
| 316 |
# |
|---|
| 317 |
# @keyparam socket_timeout Socket connect timeout measured in seconds. |
|---|
| 318 |
# Defaults to 60 seconds. |
|---|
| 319 |
class Connection(Cursor): |
|---|
| 320 |
def __init__(self, host, user, port=5432, database=None, password=None, socket_timeout=60): |
|---|
| 321 |
self._row_desc = None |
|---|
| 322 |
try: |
|---|
| 323 |
self.c = Protocol.Connection(host, port, socket_timeout=socket_timeout) |
|---|
| 324 |
self.c.connect() |
|---|
| 325 |
self.c.authenticate(user, password=password, database=database) |
|---|
| 326 |
except socket.error, e: |
|---|
| 327 |
raise InterfaceError("communication error", e) |
|---|
| 328 |
Cursor.__init__(self, self) |
|---|
| 329 |
self._begin = PreparedStatement(self, "BEGIN TRANSACTION") |
|---|
| 330 |
self._commit = PreparedStatement(self, "COMMIT TRANSACTION") |
|---|
| 331 |
self._rollback = PreparedStatement(self, "ROLLBACK TRANSACTION") |
|---|
| 332 |
|
|---|
| 333 |
## |
|---|
| 334 |
# Begins a new transaction. |
|---|
| 335 |
# <p> |
|---|
| 336 |
# Stability: Added in v1.00, stability guaranteed for v1.xx. |
|---|
| 337 |
def begin(self): |
|---|
| 338 |
self._begin.execute() |
|---|
| 339 |
|
|---|
| 340 |
## |
|---|
| 341 |
# Commits the running transaction. |
|---|
| 342 |
# <p> |
|---|
| 343 |
# Stability: Added in v1.00, stability guaranteed for v1.xx. |
|---|
| 344 |
def commit(self): |
|---|
| 345 |
self._commit.execute() |
|---|
| 346 |
|
|---|
| 347 |
## |
|---|
| 348 |
# Rolls back the running transaction. |
|---|
| 349 |
# <p> |
|---|
| 350 |
# Stability: Added in v1.00, stability guaranteed for v1.xx. |
|---|
| 351 |
def rollback(self): |
|---|
| 352 |
self._rollback.execute() |
|---|
| 353 |
|
|---|
| 354 |
|
|---|
| 355 |
class Protocol(object): |
|---|
| 356 |
class StartupMessage(object):
    # First message a client sends.  It contains the protocol version followed
    # by NUL-terminated name/value pairs ("user", plus "database" when one is
    # given), closed by one extra NUL.  Unlike the other frontend messages
    # here, it has no leading type byte.
    def __init__(self, user, database=None):
        self.user = user
        self.database = database

    def serialize(self):
        # 196608 == 3 << 16, i.e. major version 3, minor version 0.
        body = struct.pack("!i", 3 << 16)
        pairs = [("user", self.user)]
        if self.database:
            pairs.append(("database", self.database))
        for name, value in pairs:
            body += name + "\x00" + value + "\x00"
        body += "\x00"
        return struct.pack("!i", len(body) + 4) + body
|---|
| 370 |
|
|---|
| 371 |
class Query(object):
    # Simple-query message: type byte 'Q', Int32 length (counting itself),
    # then the SQL text terminated by a NUL.
    def __init__(self, qs):
        self.qs = qs

    def serialize(self):
        payload = self.qs + "\x00"
        header = "Q" + struct.pack("!i", len(payload) + 4)
        return header + payload
|---|
| 380 |
|
|---|
| 381 |
class Parse(object):
    # Parse message ('P').  Layout: NUL-terminated statement name,
    # NUL-terminated query text, an Int16 count of parameter type OIDs, and
    # then each OID as an Int32.
    def __init__(self, ps, qs, type_oids):
        self.ps = ps
        self.qs = qs
        self.type_oids = type_oids

    def serialize(self):
        body = self.ps + "\x00" + self.qs + "\x00"
        body += struct.pack("!h", len(self.type_oids))
        body += "".join([struct.pack("!i", oid) for oid in self.type_oids])
        return "P" + struct.pack("!i", len(body) + 4) + body
|---|
| 395 |
|
|---|
| 396 |
class Bind(object):
    # Bind message ('B'): binds parameter values to a prepared statement,
    # creating a portal.
    #
    # in_fc / out_fc are lists of format codes.  For parameters, an empty list
    # means every parameter uses code 0, and a single entry applies to every
    # parameter.  Otherwise there is one code per parameter.
    def __init__(self, portal, ps, in_fc, params, out_fc, client_encoding):
        self.portal = portal
        self.ps = ps
        self.in_fc = in_fc
        # Parameter values encoded by Types.pg_value (None stands for NULL).
        self.params = []
        for idx, value in enumerate(params):
            n_codes = len(self.in_fc)
            if n_codes == 0:
                fmt = 0
            elif n_codes == 1:
                fmt = self.in_fc[0]
            else:
                fmt = self.in_fc[idx]
            self.params.append(Types.pg_value(value, fmt, client_encoding = client_encoding))
        self.out_fc = out_fc

    def serialize(self):
        buf = self.portal + "\x00" + self.ps + "\x00"
        # Int16 count of parameter format codes, then each code.
        buf += struct.pack("!h", len(self.in_fc))
        for code in self.in_fc:
            buf += struct.pack("!h", code)
        # Int16 count of parameter values, then each as Int32 length + bytes.
        buf += struct.pack("!h", len(self.params))
        for encoded in self.params:
            if encoded is None:
                # Length -1 with no value bytes encodes SQL NULL.
                buf += struct.pack("!i", -1)
            else:
                buf += struct.pack("!i", len(encoded)) + encoded
        # Int16 count of result-column format codes, then each code.
        buf += struct.pack("!h", len(self.out_fc))
        for code in self.out_fc:
            buf += struct.pack("!h", code)
        return "B" + struct.pack("!i", len(buf) + 4) + buf
|---|
| 430 |
|
|---|
| 431 |
class Close(object):
    # Close message ('C').  ``typ`` is a single character selecting the kind
    # of object to close ("P" portal, "S" prepared statement, as used by the
    # two subclasses below), followed by the object's NUL-terminated name.
    def __init__(self, typ, name):
        if len(typ) != 1:
            raise InternalError("Close typ must be 1 char")
        self.typ = typ
        self.name = name

    def serialize(self):
        payload = self.typ + self.name + "\x00"
        return "C" + struct.pack("!i", len(payload) + 4) + payload


class ClosePortal(Close):
    # Close the portal called ``name``.
    def __init__(self, name):
        Protocol.Close.__init__(self, "P", name)


class ClosePreparedStatement(Close):
    # Close the prepared statement called ``name``.
    def __init__(self, name):
        Protocol.Close.__init__(self, "S", name)
|---|
| 451 |
|
|---|
| 452 |
class Describe(object):
    # Describe message ('D').  Same layout as Close: a one-character kind
    # ("P" portal, "S" prepared statement, as used by the two subclasses
    # below) followed by the NUL-terminated name.
    def __init__(self, typ, name):
        if len(typ) != 1:
            raise InternalError("Describe typ must be 1 char")
        self.typ = typ
        self.name = name

    def serialize(self):
        payload = self.typ + self.name + "\x00"
        return "D" + struct.pack("!i", len(payload) + 4) + payload


class DescribePortal(Describe):
    # Describe the portal called ``name``.
    def __init__(self, name):
        Protocol.Describe.__init__(self, "P", name)


class DescribePreparedStatement(Describe):
    # Describe the prepared statement called ``name``.
    def __init__(self, name):
        Protocol.Describe.__init__(self, "S", name)
|---|
| 472 |
|
|---|
| 473 |
class Flush(object):
    # Flush message: type byte 'H' plus an Int32 length of 4 and no body.
    def serialize(self):
        return "H" + "\x00\x00\x00\x04"


class Sync(object):
    # Sync message: type byte 'S' plus an Int32 length of 4 and no body.
    def serialize(self):
        return "S" + "\x00\x00\x00\x04"
|---|
| 480 |
|
|---|
| 481 |
class PasswordMessage(object):
    # Password message ('p'): the NUL-terminated password string.  For MD5
    # authentication this is the already-hashed "md5..." string.
    def __init__(self, pwd):
        self.pwd = pwd

    def serialize(self):
        payload = self.pwd + "\x00"
        header = "p" + struct.pack("!i", len(payload) + 4)
        return header + payload
|---|
| 490 |
|
|---|
| 491 |
class Execute(object):
    # Execute message ('E'): NUL-terminated portal name followed by an Int32
    # row count.
    def __init__(self, portal, row_count):
        self.portal = portal
        self.row_count = row_count

    def serialize(self):
        payload = self.portal + "\x00" + struct.pack("!i", self.row_count)
        header = "E" + struct.pack("!i", len(payload) + 4)
        return header + payload
|---|
| 501 |
|
|---|
| 502 |
class AuthenticationRequest(object):
    # Base class for the server's authentication messages.  The first Int32
    # of the payload selects the concrete subclass via
    # Protocol.authentication_codes; the remaining bytes go to its
    # constructor.
    def __init__(self, data):
        # The base request keeps nothing from its payload.
        pass

    @staticmethod
    def createFromData(data):
        (ident,) = struct.unpack("!i", data[:4])
        klass = Protocol.authentication_codes.get(ident)
        if klass is None:
            raise NotSupportedError("authentication method %r not supported" % (ident,))
        return klass(data[4:])

    # Subclasses carry out their part of the exchange in ok() and return a
    # true value once the server has accepted the client.
    def ok(self, conn, user, **kwargs):
        raise InternalError("ok method should be overridden on AuthenticationRequest instance")


class AuthenticationOk(AuthenticationRequest):
    # Authentication succeeded; nothing further is required.
    def ok(self, conn, user, **kwargs):
        return True
|---|
| 521 |
|
|---|
| 522 |
class AuthenticationMD5Password(AuthenticationRequest):
    # Server requests an MD5-hashed password.  The payload after the auth code
    # is the 4-byte salt.
    def __init__(self, data):
        self.salt = "".join(struct.unpack("4c", data))

    # Send the salted MD5 password and process the server's reply.
    # Returns the result of the follow-up AuthenticationRequest's ok().
    # Raises InterfaceError if no password was given or the server rejects
    # it, and InternalError for any other unexpected reply.
    def ok(self, conn, user, password=None, **kwargs):
        if password is None:
            raise InterfaceError("server requesting MD5 password authentication, but no password was provided")
        # PostgreSQL's scheme: "md5" + md5(md5(password + user) + salt).
        pwd = "md5" + md5.new(md5.new(password + user).hexdigest() + self.salt).hexdigest()
        conn._send(Protocol.PasswordMessage(pwd))
        msg = conn._read_message()
        if isinstance(msg, Protocol.AuthenticationRequest):
            return msg.ok(conn, user)
        elif isinstance(msg, Protocol.ErrorResponse):
            # 28000 is invalid_authorization_specification; newer PostgreSQL
            # servers report a wrong password as 28P01 (invalid_password).
            # Both mean the credentials were rejected.
            if msg.code in ("28000", "28P01"):
                raise InterfaceError("md5 password authentication failed")
            else:
                raise InternalError("server returned unexpected error %r" % msg)
        else:
            raise InternalError("server returned unexpected response %r" % msg)
|---|
| 541 |
|
|---|
| 542 |
# Authentication codes this module understands, mapped to the class that
# handles each one; any other code makes
# AuthenticationRequest.createFromData raise NotSupportedError.
authentication_codes = {0: AuthenticationOk, 5: AuthenticationMD5Password}
|---|
| 546 |
|
|---|
| 547 |
class ParameterStatus(object):
    # Backend message reporting a run-time parameter: a NUL-terminated name
    # followed by a NUL-terminated value.
    def __init__(self, key, value):
        self.key = key
        self.value = value

    @staticmethod
    def createFromData(data):
        nul = data.find("\x00")
        # The value runs up to, but not including, the trailing NUL.
        return Protocol.ParameterStatus(data[:nul], data[nul + 1:-1])
|---|
| 557 |
|
|---|
| 558 |
class BackendKeyData(object):
    # Backend message carrying two Int32 values: the backend process id and a
    # secret key.
    def __init__(self, process_id, secret_key):
        self.process_id = process_id
        self.secret_key = secret_key

    @staticmethod
    def createFromData(data):
        pid, key = struct.unpack("!2i", data)
        return Protocol.BackendKeyData(pid, key)
|---|
| 567 |
|
|---|
| 568 |
# Backend messages whose payload is not inspected: each createFromData ignores
# its data argument and returns a new marker instance of its class.
class NoData(object):
    @staticmethod
    def createFromData(data):
        return Protocol.NoData()


class ParseComplete(object):
    @staticmethod
    def createFromData(data):
        return Protocol.ParseComplete()


class BindComplete(object):
    @staticmethod
    def createFromData(data):
        return Protocol.BindComplete()


class CloseComplete(object):
    @staticmethod
    def createFromData(data):
        return Protocol.CloseComplete()


class PortalSuspended(object):
    @staticmethod
    def createFromData(data):
        return Protocol.PortalSuspended()
|---|
| 592 |
|
|---|
| 593 |
class ReadyForQuery(object):
    # Backend ReadyForQuery message.  ``status`` is the one-character
    # transaction status: "I", "T" or "E" (see _status_names).
    _status_names = {
        "I": "Idle",
        "T": "Idle in Transaction",
        "E": "Idle in Failed Transaction",
    }

    def __init__(self, status):
        self.status = status

    def __repr__(self):
        # Fall back to the raw status so repr() never raises.  repr() is used
        # when formatting error messages (e.g. "%r" % msg), and a KeyError
        # there would hide the real problem.
        return "<ReadyForQuery %s>" % self._status_names.get(self.status, repr(self.status))

    @staticmethod
    def createFromData(data):
        return Protocol.ReadyForQuery(data)
|---|
| 604 |
|
|---|
| 605 |
class NoticeResponse(object):
    # Server notice.  Its contents are discarded, and the connection's
    # message reader skips these messages.
    def __init__(self):
        pass

    @staticmethod
    def createFromData(data):
        # The notice fields could be parsed here, but nothing uses them yet.
        return Protocol.NoticeResponse()
|---|
| 612 |
|
|---|
| 613 |
class ErrorResponse(object):
    # Backend error report.  Only the severity ('S'), SQLSTATE code ('C')
    # and message ('M') fields are kept.
    def __init__(self, severity, code, msg):
        self.severity = severity
        self.code = code
        self.msg = msg

    def __repr__(self):
        return "<ErrorResponse %s %s %r>" % (self.severity, self.code, self.msg)

    # Build the exception raised for this error; every server error is
    # surfaced as ProgrammingError.
    def createException(self):
        return ProgrammingError(self.severity, self.code, self.msg)

    @staticmethod
    def createFromData(data):
        # The payload is a series of NUL-terminated fields, each starting with
        # a one-character type; types not listed here are ignored.
        wanted = {"S": "severity", "C": "code", "M": "msg"}
        kwargs = {}
        for field in data.split("\x00"):
            if field and field[0] in wanted:
                kwargs[wanted[field[0]]] = field[1:]
        return Protocol.ErrorResponse(**kwargs)
|---|
| 638 |
|
|---|
| 639 |
class ParameterDescription(object):
    # Backend message listing the type OID of each statement parameter:
    # an Int16 count followed by that many Int32 OIDs.
    def __init__(self, type_oids):
        self.type_oids = type_oids

    @staticmethod
    def createFromData(data):
        (count,) = struct.unpack("!h", data[:2])
        oids = struct.unpack("!" + count * "i", data[2:])
        return Protocol.ParameterDescription(oids)
|---|
| 647 |
|
|---|
| 648 |
class RowDescription(object):
    # Backend message describing result columns: an Int16 column count, then
    # for each column a NUL-terminated name followed by 18 bytes of metadata.
    def __init__(self, fields):
        # List of dicts, one per column.
        self.fields = fields

    @staticmethod
    def createFromData(data):
        # Keys for the fixed-size metadata unpacked with "!ihihih".
        meta_keys = ("table_oid", "column_attrnum", "type_oid",
                     "type_size", "type_modifier", "format")
        (count,) = struct.unpack("!h", data[:2])
        rest = data[2:]
        columns = []
        for _ in range(count):
            end = rest.find("\x00")
            column = {"name": rest[:end]}
            rest = rest[end + 1:]
            column.update(zip(meta_keys, struct.unpack("!ihihih", rest[:18])))
            rest = rest[18:]
            columns.append(column)
        return Protocol.RowDescription(columns)
|---|
| 665 |
|
|---|
| 666 |
class CommandComplete(object):
    # Backend message carrying the command tag of a finished command.
    def __init__(self, tag):
        self.tag = tag

    @staticmethod
    def createFromData(data):
        # Drop the trailing NUL terminator from the tag string.
        tag = data[:-1]
        return Protocol.CommandComplete(tag)
|---|
| 673 |
|
|---|
| 674 |
class DataRow(object):
    # Backend message with one result row: an Int16 column count, then each
    # column as an Int32 length followed by that many bytes.
    def __init__(self, fields):
        # Raw column values; None for SQL NULL.
        self.fields = fields

    @staticmethod
    def createFromData(data):
        (count,) = struct.unpack("!h", data[:2])
        rest = data[2:]
        values = []
        for _ in range(count):
            (size,) = struct.unpack("!i", rest[:4])
            rest = rest[4:]
            if size == -1:
                # Length -1 marks NULL; no value bytes follow.
                values.append(None)
                continue
            values.append(rest[:size])
            rest = rest[size:]
        return Protocol.DataRow(values)
|---|
| 692 |
|
|---|
| 693 |
class Connection(object):
    """Client side of a PostgreSQL frontend/backend message session.

    Owns a single TCP socket and tracks progress in ``_state``:
    "unconnected" -> connect() -> "noauth" -> authenticate() -> "ready".
    query() moves to "in_query", and getrow() returns to "ready" once the
    result set is exhausted. Most operations check the state they require
    via verifyState().
    """

    def __init__(self, host=None, port=5432, socket_timeout=60):
        self._state = "unconnected"
        # replaced by the server's client_encoding ParameterStatus in authenticate()
        self._client_encoding = "ascii"
        self._host = host
        self._port = port
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._sock.settimeout(socket_timeout)
        self._backend_key_data = None

    def verifyState(self, state):
        """Raise InternalError unless the connection is currently in ``state``."""
        if self._state != state:
            raise InternalError("connection state must be %s, is %s" % (state, self._state))

    def _send(self, msg):
        """Serialize ``msg`` and write every byte of it to the socket.

        ``socket.send`` may accept only part of the buffer. ``sendall`` keeps
        writing until everything has been sent, or raises.
        """
        self._sock.sendall(msg.serialize())

    def _recv_exact(self, count):
        """Read exactly ``count`` bytes from the socket.

        A single ``recv`` may return fewer bytes than requested, so this
        loops until the full amount has arrived.

        Raises:
            InterfaceError: if the server closes the connection mid-read.
        """
        if count <= 0:
            return ""
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = self._sock.recv(remaining)
            if not chunk:
                raise InterfaceError("server closed the connection unexpectedly")
            chunks.append(chunk)
            remaining -= len(chunk)
        # chunks[0][:0] is an empty value of recv's own type, so the join
        # works whether the socket yields str or bytes
        return chunks[0][:0].join(chunks)

    def _read_message(self):
        """Read and decode the next backend message, skipping NoticeResponse.

        Each message is a one-byte type code followed by a 4-byte big-endian
        length that includes itself but not the type code.

        Raises:
            InternalError: on a malformed length or an unknown type code.
        """
        while 1:
            header = self._recv_exact(5)
            message_code = header[0]
            data_len = struct.unpack("!i", header[1:])[0] - 4
            if data_len < 0:
                raise InternalError("invalid message length %d" % (data_len + 4))
            payload = self._recv_exact(data_len)
            msg_class = Protocol.message_types.get(message_code)
            if msg_class is None:
                raise InternalError("unrecognized message code %r" % (message_code,))
            msg = msg_class.createFromData(payload)
            if not isinstance(msg, Protocol.NoticeResponse):
                return msg
            # NoticeResponse is informational only; keep reading (a loop
            # rather than recursion, so many notices cannot exhaust the stack)

    def connect(self):
        """Open the TCP connection. Moves the state from "unconnected" to "noauth"."""
        self.verifyState("unconnected")
        self._sock.connect((self._host, self._port))
        self._state = "noauth"

    def authenticate(self, user, **kwargs):
        """Send the StartupMessage and complete authentication.

        ``kwargs`` may carry ``database`` (sent in the StartupMessage). The
        full ``kwargs`` is also handed to the AuthenticationRequest's ``ok``
        method. After authentication, ParameterStatus and BackendKeyData
        messages are recorded until ReadyForQuery moves the state to "ready".

        Raises:
            InterfaceError: if the authentication exchange reports failure.
            InternalError: on an unexpected backend message.
        """
        self.verifyState("noauth")
        self._send(Protocol.StartupMessage(user, database=kwargs.get("database", None)))
        msg = self._read_message()
        if isinstance(msg, Protocol.ErrorResponse):
            # the server rejected the startup itself (e.g. unknown database)
            raise msg.createException()
        if not isinstance(msg, Protocol.AuthenticationRequest):
            raise InternalError("StartupMessage was responded to with non-AuthenticationRequest msg %r" % msg)
        if not msg.ok(self, user, **kwargs):
            raise InterfaceError("authentication method %s failed" % msg.__class__.__name__)
        self._state = "auth"
        while 1:
            msg = self._read_message()
            if isinstance(msg, Protocol.ReadyForQuery):
                # done reading messages
                self._state = "ready"
                break
            elif isinstance(msg, Protocol.ParameterStatus):
                if msg.key == "client_encoding":
                    self._client_encoding = msg.value
            elif isinstance(msg, Protocol.BackendKeyData):
                self._backend_key_data = msg
            elif isinstance(msg, Protocol.ErrorResponse):
                raise msg.createException()
            else:
                raise InternalError("unexpected msg %r" % msg)

    def parse(self, statement, qs, types):
        """Prepare ``qs`` as the named ``statement`` and describe it.

        ``types`` lists the Python types of the parameters. Each is mapped
        through Types.pg_type_info to a (type oid, format code) pair.

        Returns ``(row_desc, param_fc)``. ``row_desc`` is the statement's
        RowDescription, or None when the server answers NoData.
        ``param_fc`` is the list of parameter format codes. The tuple is
        shaped for passing to bind() as ``parse_data``.
        """
        self.verifyState("ready")
        type_info = [Types.pg_type_info(x) for x in types]
        # zip(*type_info) would fail on an empty list, so unpack explicitly
        param_types = [x[0] for x in type_info]
        param_fc = [x[1] for x in type_info]
        self._send(Protocol.Parse(statement, qs, param_types))
        self._send(Protocol.DescribePreparedStatement(statement))
        self._send(Protocol.Flush())
        while 1:
            msg = self._read_message()
            if isinstance(msg, Protocol.ParseComplete):
                pass
            elif isinstance(msg, Protocol.ParameterDescription):
                # parameters are sent as-is and the server coerces them, so
                # the described parameter types are not needed
                pass
            elif isinstance(msg, Protocol.NoData):
                # the statement returns no rows; bind() checks for None
                return (None, param_fc)
            elif isinstance(msg, Protocol.RowDescription):
                return (msg, param_fc)
            elif isinstance(msg, Protocol.ErrorResponse):
                raise msg.createException()
            else:
                raise InternalError("Unexpected response msg %r" % (msg))

    def bind(self, portal, statement, params, parse_data):
        """Bind ``params`` to ``statement`` as ``portal``.

        ``parse_data`` is the tuple returned by parse(). Returns the portal's
        RowDescription, which reflects the requested output formats. When
        the statement produces no data, it is executed immediately and
        None is returned.
        """
        self.verifyState("ready")
        row_desc, param_fc = parse_data
        if row_desc is None:
            # no data coming out
            output_fc = ()
        else:
            # request an output format for each described column
            output_fc = [Types.py_type_info(f) for f in row_desc.fields]
        self._send(Protocol.Bind(portal, statement, param_fc, params, output_fc, self._client_encoding))
        # Describe the portal after binding: its format codes should now be
        # the ones requested above, unlike the statement's description.
        self._send(Protocol.DescribePortal(portal))
        self._send(Protocol.Flush())
        while 1:
            msg = self._read_message()
            if isinstance(msg, Protocol.BindComplete):
                pass
            elif isinstance(msg, Protocol.NoData):
                # No data means we should execute this command right away.
                self._send(Protocol.Execute(portal, 0))
                self._send(Protocol.Sync())
                while 1:
                    msg = self._read_message()
                    if isinstance(msg, Protocol.CommandComplete):
                        pass
                    elif isinstance(msg, Protocol.ReadyForQuery):
                        break
                    elif isinstance(msg, Protocol.ErrorResponse):
                        raise msg.createException()
                    else:
                        raise InternalError("unexpected response")
                return None
            elif isinstance(msg, Protocol.RowDescription):
                # the portal's description carries the format codes we asked for
                return msg
            elif isinstance(msg, Protocol.ErrorResponse):
                raise msg.createException()
            else:
                raise InternalError("Unexpected response msg %r" % (msg))

    def fetch_rows(self, portal, row_count, row_desc):
        """Execute ``portal`` for up to ``row_count`` rows.

        Values are converted with Types.py_value according to ``row_desc``.
        Returns ``(end_of_data, rows)``. ``end_of_data`` is True once the
        portal is exhausted, in which case the portal has been closed and
        synced.
        """
        self.verifyState("ready")
        self._send(Protocol.Execute(portal, row_count))
        self._send(Protocol.Flush())
        rows = []
        end_of_data = False
        while 1:
            msg = self._read_message()
            if isinstance(msg, Protocol.DataRow):
                rows.append(
                    [Types.py_value(value, row_desc.fields[i], client_encoding=self._client_encoding)
                     for i, value in enumerate(msg.fields)]
                )
            elif isinstance(msg, Protocol.PortalSuspended):
                # got all the rows we asked for, but not all that exist
                break
            elif isinstance(msg, Protocol.CommandComplete):
                self._send(Protocol.ClosePortal(portal))
                self._send(Protocol.Sync())
                while 1:
                    msg = self._read_message()
                    if isinstance(msg, Protocol.ReadyForQuery):
                        self._state = "ready"
                        break
                    elif isinstance(msg, Protocol.CloseComplete):
                        pass
                    elif isinstance(msg, Protocol.ErrorResponse):
                        raise msg.createException()
                    else:
                        raise InternalError("unexpected response msg %r" % msg)
                end_of_data = True
                break
            elif isinstance(msg, Protocol.ErrorResponse):
                raise msg.createException()
            else:
                raise InternalError("Unexpected response msg %r" % msg)
        return end_of_data, rows

    def _close_and_sync(self, close_msg):
        """Send ``close_msg`` followed by Sync, then consume replies through ReadyForQuery."""
        self._send(close_msg)
        self._send(Protocol.Sync())
        while 1:
            msg = self._read_message()
            if isinstance(msg, Protocol.CloseComplete):
                pass
            elif isinstance(msg, Protocol.ReadyForQuery):
                return
            elif isinstance(msg, Protocol.ErrorResponse):
                raise msg.createException()
            else:
                raise InternalError("Unexpected response msg %r" % msg)

    def close_statement(self, statement):
        """Close the named prepared ``statement`` on the server."""
        self._close_and_sync(Protocol.ClosePreparedStatement(statement))

    def close_portal(self, portal):
        """Close the named ``portal`` on the server."""
        self._close_and_sync(Protocol.ClosePortal(portal))

    def query(self, qs):
        """Send ``qs`` as a simple query and return its RowDescription.

        Moves the state to "in_query". Rows are then read with getrow().
        """
        self.verifyState("ready")
        self._send(Protocol.Query(qs))
        msg = self._read_message()
        if isinstance(msg, Protocol.RowDescription):
            self._state = "in_query"
            return msg
        elif isinstance(msg, Protocol.ErrorResponse):
            raise msg.createException()
        else:
            raise InternalError("RowDescription expected, other message recv'd")

    def getrow(self):
        """Return the next DataRow of a simple query, or None when finished.

        On CommandComplete, the trailing ReadyForQuery is consumed and the
        connection returns to the "ready" state.
        """
        self.verifyState("in_query")
        msg = self._read_message()
        if isinstance(msg, Protocol.DataRow):
            return msg
        elif isinstance(msg, Protocol.CommandComplete):
            self._waitForReady()
            return None
        elif isinstance(msg, Protocol.ErrorResponse):
            raise msg.createException()
        else:
            raise InternalError("Unexpected response msg %r" % msg)

    def _waitForReady(self):
        """Read messages until ReadyForQuery, then set the state to "ready"."""
        while 1:
            msg = self._read_message()
            if isinstance(msg, Protocol.ReadyForQuery):
                self._state = "ready"
                return
            elif isinstance(msg, Protocol.ErrorResponse):
                raise msg.createException()
            else:
                raise InternalError("Unexpected response msg %r" % msg)
|---|
| 927 |
|
|---|
| 928 |
# Backend message type code (first byte of each message, as read by
# Connection._read_message) -> class whose createFromData parses the payload.
message_types = {
    # asynchronous / session-level messages
    "N": NoticeResponse,        # skipped by Connection._read_message
    "R": AuthenticationRequest,
    "S": ParameterStatus,
    "K": BackendKeyData,
    "Z": ReadyForQuery,
    # result-set and completion messages
    "T": RowDescription,
    "E": ErrorResponse,
    "D": DataRow,
    "C": CommandComplete,
    # replies used by parse / bind / fetch_rows / close_*
    "1": ParseComplete,
    "2": BindComplete,
    "3": CloseComplete,
    "s": PortalSuspended,
    "n": NoData,
    "t": ParameterDescription,
}
|---|
| 945 |
|
|---|
| 946 |
class Types(object): |
|---|
| 947 |
def pg_type_info(typ):
    """Return ``(type_oid, format_code)`` for sending values of Python type ``typ``.

    The entry for ``typ`` in ``Types.py_types`` supplies the type oid
    ("tid") and the available converters ("bin_out" / "txt_out"). Format
    code 1 means binary and 0 means text. An explicit "prefer" of "bin" or
    "txt" selects that format. Without a preference, binary is chosen when
    available.

    Raises:
        NotSupportedError: if ``typ`` has no entry in ``Types.py_types``.
        InternalError: if the entry is incomplete or inconsistent.
    """
    info = Types.py_types.get(typ)
    if info is None:
        raise NotSupportedError("type %r not mapped to pg type" % typ)
    type_oid = info.get("tid")
    if type_oid is None:
        raise InternalError("type %r has no type_oid" % typ)
    prefer = info.get("prefer")
    if prefer is None:
        # by default, prefer bin, but go with whatever exists
        if info.get("bin_out"):
            return type_oid, 1
        if info.get("txt_out"):
            return type_oid, 0
        raise InternalError("no conversion fuction for type %r" % typ)
    if prefer == "bin":
        if info.get("bin_out") is None:
            raise InternalError("bin format prefered but not avail for type %r" % typ)
        return type_oid, 1
    if prefer == "txt":
        if info.get("txt_out") is None:
            raise InternalError("txt format prefered but not avail for type %r" % typ)
        return type_oid, 0
    raise InternalError("prefer flag not recognized for type %r" % typ)
pg_type_info = staticmethod(pg_type_info)
|---|
| 976 |
|
|---|
| 977 |
def pg_value(v, fc, **kwargs):
    """Convert ``v`` to its wire representation for format code ``fc``.

    Looks up ``type(v)`` in ``Types.py_types``. Format code 0 uses the
    entry's "txt_out" converter and 1 uses "bin_out". Extra ``kwargs`` are
    passed to the converter. Returns None for types whose "tid" is -1 (the
    NULL mapping).

    Raises:
        NotSupportedError: if the type is unmapped or lacks the converter.
        InternalError: if ``fc`` is neither 0 nor 1.
    """
    typ = type(v)
    info = Types.py_types.get(typ)
    if info is None:
        raise NotSupportedError("type %r not mapped to pg type" % typ)
    if info.get("tid") == -1:
        # special case: NULL values
        return None
    if fc == 0:
        converter_key = "txt_out"
    elif fc == 1:
        converter_key = "bin_out"
    else:
        raise InternalError("unrecognized format code %r" % fc)
    converter = info.get(converter_key)
    if converter is None:
        raise NotSupportedError("type %r, format code %r not supported" % (typ, fc))
    return converter(v, **kwargs)
pg_value = staticmethod(pg_value)
|---|
| 995 |
|
|---|
| 996 |
def py_type_info(description): |
|---|
| 997 |
type_oid = description['type_oid'] |
|---|
| 998 |
data = Types.pg_types.get(type_oid) |
|---|
| 999 |
if data == None: |
|---|
| 1000 |
raise NotSupportedError("type oid %r not mapped to py type" % type_oid) |
|---|
| 1001 |
prefer = data.get("prefer") |
|---|
| 1002 |
if prefer != None: |
|---|
| 1003 |
if prefer == "bin": |
|---|
| 1004 |
if data.get("bin_in") == None: |
|---|
| 1005 |
raise InternalError("bin format prefered but not avail for type oid %r" % type_oid) |
|---|
| 1006 |
format = 1 |
|---|
| 1007 |
elif prefer == "txt": |
|---|
| 1008 |
if data.get("txt_in") == None: |
|---|
| 1009 |
raise InternalError("txt format prefered but not avail for type oid %r" % type_oid) |
|---|
| 1010 |
format = 0 |
|---|
| 1011 |
else: |
|---|
| 1012 |
raise InternalError("prefer flag not recognized for type oid %r" % type_oid) |
|---|
| 1013 |
else: |
|---|
| 1014 |
# by default, prefer bin, but go with whatever exists |
|---|
| 1015 |
if data.get("bin_in"): |
|---|
| 1016 |
format = 1 |
|---|
| 1017 |
elif data.get("txt_in"): |
|---|
|
|---|