| 1 |
# vim: sw=4:expandtab:foldmethod=marker |
|---|
| 2 |
# |
|---|
| 3 |
# Copyright (c) 2007, Mathieu Fenniak |
|---|
| 4 |
# All rights reserved. |
|---|
| 5 |
# |
|---|
| 6 |
# Redistribution and use in source and binary forms, with or without |
|---|
| 7 |
# modification, are permitted provided that the following conditions are |
|---|
| 8 |
# met: |
|---|
| 9 |
# |
|---|
| 10 |
# * Redistributions of source code must retain the above copyright notice, |
|---|
| 11 |
# this list of conditions and the following disclaimer. |
|---|
| 12 |
# * Redistributions in binary form must reproduce the above copyright notice, |
|---|
| 13 |
# this list of conditions and the following disclaimer in the documentation |
|---|
| 14 |
# and/or other materials provided with the distribution. |
|---|
| 15 |
# * The name of the author may not be used to endorse or promote products |
|---|
| 16 |
# derived from this software without specific prior written permission. |
|---|
| 17 |
# |
|---|
| 18 |
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
|---|
| 19 |
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
|---|
| 20 |
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE |
|---|
| 21 |
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE |
|---|
| 22 |
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR |
|---|
| 23 |
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF |
|---|
| 24 |
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS |
|---|
| 25 |
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN |
|---|
| 26 |
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) |
|---|
| 27 |
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE |
|---|
| 28 |
# POSSIBILITY OF SUCH DAMAGE. |
|---|
| 29 |
|
|---|
| 30 |
# Module metadata: original author of pg8000.
__author__ = "Mathieu Fenniak"
|---|
| 31 |
|
|---|
| 32 |
import datetime
import decimal
import md5
import os
import socket
import struct
import threading
import time
|---|
| 39 |
|
|---|
| 40 |
debug_log = file("/Users/mfenniak/SQLAlchemy-0.3.5/pg8000_debug.log", "w") |
|---|
| 41 |
|
|---|
| 42 |
class Warning(StandardError):
    """PEP 249 ``Warning`` exception; re-exported as ``DBAPI.Warning``."""
|---|
| 44 |
|
|---|
| 45 |
class Error(StandardError):
    """PEP 249 ``Error``: base class of every pg8000 error exception."""
|---|
| 47 |
|
|---|
| 48 |
class InterfaceError(Error):
    """Problem with the driver interface rather than the database itself.

    Raised, for example, when using a closed cursor or connection, on socket
    communication failure while connecting, or when a row cannot be turned
    into a dict because two columns share a name.
    """
|---|
| 50 |
|
|---|
| 51 |
class DatabaseError(Error):
    """PEP 249 ``DatabaseError``: base class for database-related errors."""
|---|
| 53 |
|
|---|
| 54 |
class DataError(DatabaseError):
    """PEP 249 ``DataError``; defined for DB-API completeness."""
|---|
| 56 |
|
|---|
| 57 |
class OperationalError(DatabaseError):
    """PEP 249 ``OperationalError``; defined for DB-API completeness."""
|---|
| 59 |
|
|---|
| 60 |
class IntegrityError(DatabaseError):
    """PEP 249 ``IntegrityError``; defined for DB-API completeness."""
|---|
| 62 |
|
|---|
| 63 |
class InternalError(DatabaseError):
    """Internal inconsistency in the driver.

    Raised, for example, when a prepared statement is asked to refill a row
    cache that still holds rows.
    """
|---|
| 65 |
|
|---|
| 66 |
class ProgrammingError(DatabaseError):
    """Misuse of the API by the caller.

    Raised, for example, for malformed or unsupported parameter placeholders,
    too few parameters for the placeholders, or reading from a cursor that
    has not executed anything.
    """
|---|
| 68 |
|
|---|
| 69 |
class NotSupportedError(DatabaseError):
    """PEP 249 ``NotSupportedError``; defined for DB-API completeness."""
|---|
| 71 |
|
|---|
| 72 |
class DataIterator(object):
    """Iterator that calls ``func(obj)`` repeatedly until it returns None.

    PreparedStatement.iterate_tuple / iterate_dict build one of these around
    read_tuple / read_dict, so iteration ends after the last row.
    """

    def __init__(self, obj, func):
        self.obj = obj
        self.func = func

    def __iter__(self):
        return self

    def next(self):
        retval = self.func(self.obj)
        # Identity test: a row value with an unusual __eq__ must never be
        # able to terminate the iteration early.
        if retval is None:
            raise StopIteration()
        return retval

    # Python 3 name of the iterator protocol method.
    __next__ = next
|---|
| 85 |
|
|---|
| 86 |
class DBAPI(object):
    """DB-API 2.0 (PEP 249) facade over the native pg8000 classes.

    PEP 249 expects its names at module level; here they are gathered as
    class attributes: the exception classes, ``apilevel`` / ``threadsafety``
    / ``paramstyle``, ``connect``, the type constructors and the type codes.
    """

    # PEP 249 exception classes, re-exported from module level.
    Warning = Warning
    Error = Error
    InterfaceError = InterfaceError
    InternalError = InternalError
    DatabaseError = DatabaseError
    DataError = DataError
    OperationalError = OperationalError
    IntegrityError = IntegrityError
    ProgrammingError = ProgrammingError
    NotSupportedError = NotSupportedError

    apilevel = "2.0"
    # NOTE(review): level 3 advertises that threads may share connections and
    # cursors, but ConnectionWrapper.commit documents a race between COMMIT
    # and the following BEGIN -- confirm this level is warranted.
    threadsafety = 3
    paramstyle = 'format' # paramstyle can be changed to any DB-API paramstyle

    def convert_paramstyle(src_style, query, args):
        """Translate *query* from DB-API paramstyle *src_style* to PostgreSQL's
        native ``$1, $2, ...`` placeholders.

        Handles ``qmark`` (?), ``numeric`` (:1), ``named`` (:name),
        ``format`` (%s) and ``pyformat`` (%(name)s).  Text inside
        single-quoted strings, E'...' escape strings and double-quoted
        identifiers is copied without placeholder substitution.  For the
        ``format`` and ``pyformat`` styles a literal percent sign must be
        written as ``%%`` everywhere, including inside quotes.  In the
        ``numeric`` and ``named`` styles a ``::`` type cast is passed through.

        Returns ``(query, args)`` where *args* is a tuple ordered to match
        the ``$n`` placeholders.  Positional styles consume *args* in order;
        named styles look each name up in *args* (a missing name raises
        KeyError) and a repeated name reuses its ``$n``.  Raises
        ProgrammingError for malformed or unsupported placeholders, and when
        there are more positional placeholders than parameters.
        """
        debug_log.write("convert_paramstyle(%r, %r, %r)\n" % (src_style, query, args))
        # Scanner states:
        #   0 -- outside any quoted text
        #   1 -- inside a single-quoted string '...'
        #   2 -- inside a quoted identifier "..."
        #   3 -- inside an escape string E'...' (backslash escapes apply)
        state = 0
        out = []
        output_args = []
        mapping_to_idx = {}
        if src_style == "numeric":
            # :N already names the Nth argument, so args pass through as-is.
            output_args = args
        percent_escapes = src_style in ("pyformat", "format")
        n = len(query)

        def skip_escaped_percent(pos):
            # Inside quoted text the only legal use of '%' is '%%'.  The first
            # '%' has already been copied to the output, so step over the
            # second one.  A lone '%' at the very end of the query is kept.
            if pos < n:
                if query[pos] != "%":
                    raise ProgrammingError("'%" + query[pos] + "' not supported in quoted string")
                pos += 1
            return pos

        def bind_name(name):
            # First use of a name appends args[name]; later uses share its $n.
            idx = mapping_to_idx.get(name)
            if idx is None:
                output_args.append(args[name])
                idx = len(output_args)
                mapping_to_idx[name] = idx
            return "$" + str(idx)

        def is_ident_char(ch):
            return ch.isalnum() or ch in "_$"

        i = 0
        while i < n:
            c = query[i]
            if state == 0:
                if c == "'":
                    out.append(c)
                    i += 1
                    state = 1
                elif c == '"':
                    out.append(c)
                    i += 1
                    state = 2
                elif (c in "Ee" and i + 1 < n and query[i + 1] == "'"
                        and (i == 0 or not is_ident_char(query[i - 1]))):
                    # E'...' (either case) starts an escape string, but only
                    # when the E is not the tail of a longer identifier.
                    out.append(c + "'")
                    i += 2
                    state = 3
                elif src_style == "qmark" and c == "?":
                    i += 1
                    param_idx = len(output_args)
                    if param_idx == len(args):
                        raise ProgrammingError("too many parameter fields, not enough parameters")
                    output_args.append(args[param_idx])
                    out.append("$" + str(param_idx + 1))
                elif (src_style in ("numeric", "named") and c == ":"
                        and i + 1 < n and query[i + 1] == ":"):
                    # '::' is a PostgreSQL type cast, not a placeholder.
                    out.append("::")
                    i += 2
                elif src_style == "numeric" and c == ":":
                    i += 1
                    start = i
                    while i < n and query[i].isdigit():
                        i += 1
                    if start == i:
                        raise ProgrammingError("numeric parameter : does not have numeric arg")
                    out.append("$" + query[start:i])
                elif src_style == "named" and c == ":":
                    i += 1
                    start = i
                    while i < n and (query[i].isalnum() or query[i] == "_"):
                        i += 1
                    name = query[start:i]
                    if not name:
                        raise ProgrammingError("empty name of named parameter")
                    out.append(bind_name(name))
                elif src_style == "format" and c == "%":
                    i += 1
                    if i >= n:
                        raise ProgrammingError("incomplete format placeholder at end of query")
                    if query[i] == "s":
                        param_idx = len(output_args)
                        if param_idx == len(args):
                            raise ProgrammingError("too many parameter fields, not enough parameters")
                        output_args.append(args[param_idx])
                        out.append("$" + str(param_idx + 1))
                    elif query[i] == "%":
                        out.append("%")
                    else:
                        raise ProgrammingError("Only %s and %% are supported")
                    i += 1
                elif src_style == "pyformat" and c == "%":
                    i += 1
                    if i >= n:
                        raise ProgrammingError("incomplete format placeholder at end of query")
                    if query[i] == "(":
                        end_idx = query.find(")", i + 1)
                        if end_idx == -1:
                            raise ProgrammingError("began pyformat dict read, but couldn't find end of name")
                        name = query[i + 1:end_idx]
                        i = end_idx + 1
                        if i < n and query[i] == "s":
                            i += 1
                            out.append(bind_name(name))
                        else:
                            raise ProgrammingError("format not specified or not supported (only %(...)s supported)")
                    elif query[i] == "%":
                        # Consume both characters of '%%'.
                        out.append("%")
                        i += 1
                    else:
                        raise ProgrammingError("Only %(...)s and %% are supported")
                else:
                    out.append(c)
                    i += 1
            elif state == 1:
                out.append(c)
                i += 1
                if c == "'":
                    if i < n and query[i] == "'":
                        # '' is an escaped quote; the string continues.
                        out.append("'")
                        i += 1
                    else:
                        state = 0
                elif percent_escapes and c == "%":
                    i = skip_escaped_percent(i)
            elif state == 2:
                out.append(c)
                i += 1
                if c == '"':
                    state = 0
                elif percent_escapes and c == "%":
                    i = skip_escaped_percent(i)
            elif state == 3:
                out.append(c)
                i += 1
                if c == "\\":
                    # Copy an escaped quote or backslash verbatim so it can
                    # neither end the string nor escape the next character.
                    if i < n and query[i] in "'\\":
                        out.append(query[i])
                        i += 1
                elif c == "'":
                    if i < n and query[i] == "'":
                        # '' is also an escaped quote inside E'...'.
                        out.append("'")
                        i += 1
                    else:
                        state = 0
                elif percent_escapes and c == "%":
                    i = skip_escaped_percent(i)

        return "".join(out), tuple(output_args)
    convert_paramstyle = staticmethod(convert_paramstyle)


    class CursorWrapper(object):
        """DB-API cursor on top of a native Cursor.

        Queries given to execute() are written in ``DBAPI.paramstyle`` and
        converted to ``$n`` placeholders before being sent.
        """

        def __init__(self, conn):
            self.cursor = Cursor(conn)
            self.arraysize = 1

        rowcount = property(lambda self: self._getRowCount())
        def _getRowCount(self):
            # Row counts are not tracked; PEP 249 allows -1 for "unknown".
            return -1

        description = property(lambda self: self._getDescription())
        def _getDescription(self):
            # Returns None when closed or when no statement with result
            # columns has run.
            # NOTE(review): PEP 249 describes 7-item column sequences; only
            # (name, type_code) pairs are produced here.
            if self.cursor is None or self.cursor.row_description is None:
                return None
            return [(col["name"], col["type_oid"]) for col in self.cursor.row_description]

        def execute(self, operation, args=()):
            """Run *operation* with *args* in ``DBAPI.paramstyle``.

            Raises InterfaceError if the cursor is closed.  On any error the
            current transaction is rolled back and a new one begun before the
            error is re-raised.
            """
            debug_log.write("execute(%r, %r)\n" % (operation, args))
            if self.cursor is None:
                raise InterfaceError("cursor is closed")
            new_query, new_args = DBAPI.convert_paramstyle(DBAPI.paramstyle, operation, args)
            try:
                self.cursor.execute(new_query, *new_args)
            except:
                # Any error aborts the transaction to date.  Roll it back and
                # immediately open a new one: ConnectionWrapper keeps the
                # connection inside a transaction at all times, and without
                # the BEGIN every later statement would silently autocommit.
                conn = self.cursor.connection
                conn.rollback()
                conn.begin()
                raise

        def executemany(self, operation, parameter_sets):
            for parameters in parameter_sets:
                self.execute(operation, parameters)

        def fetchone(self):
            """Return the next row as a tuple, or None when exhausted."""
            if self.cursor is None:
                raise InterfaceError("cursor is closed")
            return self.cursor.read_tuple()

        def fetchmany(self, size=None):
            """Return up to *size* rows (default ``arraysize``).

            Fewer rows are returned once the result set is exhausted; the
            list never contains None placeholders.
            """
            if size is None:
                size = self.arraysize
            rows = []
            for _ in range(size):
                row = self.fetchone()
                if row is None:
                    break
                rows.append(row)
            return rows

        def fetchall(self):
            """Return all remaining rows as a tuple of tuples."""
            if self.cursor is None:
                raise InterfaceError("cursor is closed")
            return tuple(self.cursor.iterate_tuple())

        def close(self):
            self.cursor = None

        def setinputsizes(self, sizes):
            # Permitted no-op under PEP 249.
            pass

        def setoutputsize(self, size, column=None):
            # Permitted no-op under PEP 249.
            pass

    class ConnectionWrapper(object):
        """DB-API connection around a native Connection.

        A transaction is begun on construction and again after every commit
        or rollback, so statements always run inside a transaction.
        """

        def __init__(self, **kwargs):
            self.conn = Connection(**kwargs)
            self.conn.begin()

        def cursor(self):
            """Return a new CursorWrapper; raises InterfaceError if closed."""
            if self.conn is None:
                raise InterfaceError("connection is closed")
            return DBAPI.CursorWrapper(self.conn)

        def commit(self):
            # There's a threading bug here. If a query is sent after the
            # commit, but before the begin, it will be executed immediately
            # without a surrounding transaction. Like all threading bugs -- it
            # sounds unlikely, until it happens every time in one
            # application... however, to fix this, we need to lock the
            # database connection entirely, so that no cursors can execute
            # statements on other threads. Support for that type of lock will
            # be done later.
            if self.conn is None:
                raise InterfaceError("connection is closed")
            self.conn.commit()
            self.conn.begin()

        def rollback(self):
            # Same COMMIT/BEGIN race as described in commit().
            if self.conn is None:
                raise InterfaceError("connection is closed")
            self.conn.rollback()
            self.conn.begin()

        def close(self):
            self.conn = None

    def connect(user, host=None, unix_sock=None, port=5432, database=None, password=None, socket_timeout=60, ssl=False):
        """Open a connection and return it as a DB-API ConnectionWrapper.

        All arguments are forwarded unchanged to Connection.
        """
        return DBAPI.ConnectionWrapper(user=user, host=host,
                unix_sock=unix_sock, port=port, database=database,
                password=password, socket_timeout=socket_timeout, ssl=ssl)
    connect = staticmethod(connect)

    def Date(year, month, day):
        return datetime.date(year, month, day)
    Date = staticmethod(Date)

    def Time(hour, minute, second):
        return datetime.time(hour, minute, second)
    Time = staticmethod(Time)

    def Timestamp(year, month, day, hour, minute, second):
        return datetime.datetime(year, month, day, hour, minute, second)
    Timestamp = staticmethod(Timestamp)

    def DateFromTicks(ticks):
        # Interprets *ticks* in local time.
        return DBAPI.Date(*time.localtime(ticks)[:3])
    DateFromTicks = staticmethod(DateFromTicks)

    def TimeFromTicks(ticks):
        # Interprets *ticks* in local time.
        return DBAPI.Time(*time.localtime(ticks)[3:6])
    TimeFromTicks = staticmethod(TimeFromTicks)

    def TimestampFromTicks(ticks):
        # Interprets *ticks* in local time.
        return DBAPI.Timestamp(*time.localtime(ticks)[:6])
    TimestampFromTicks = staticmethod(TimestampFromTicks)

    def Binary(value):
        return Bytea(value)
    Binary = staticmethod(Binary)

    # PEP 249 type codes, compared against the type_code (PostgreSQL type oid)
    # that CursorWrapper.description reports for each column.  Each constant
    # is a single oid, so columns of related types (e.g. VARCHAR vs STRING)
    # will not compare equal.
    STRING = 25     # text
    BINARY = 17     # bytea
    NUMBER = 1700   # numeric
    DATETIME = 1114 # timestamp
    ROWID = 26      # oid
|---|
| 418 |
|
|---|
| 419 |
|
|---|
| 420 |
##
# A prepared statement: the SQL text is parsed once on the server and can then
# be executed repeatedly without being re-parsed. Parameters are written as
# $1, $2, $3, etc. in the statement, and the Python type of each parameter
# must be given when the statement is created.
# <p>
# As of v1.01, instances of this class are thread-safe: internal state is
# guarded by a lock, so several threads may use one PreparedStatement without
# corrupting it. It is still the client application's responsibility to make
# sure that a thread reading results is not disturbed by another thread
# re-executing the same statement.
# <p>
# Stability: Added in v1.00, stability guaranteed for v1.xx.
#
# @param connection An instance of {@link Connection Connection}.
#
# @param statement The SQL statement to be represented, often containing
# parameters in the form of $1, $2, $3, etc.
#
# @param types Python type objects for each parameter in the SQL
# statement. For example, int, float, str.
class PreparedStatement(object):

    ##
    # Number of rows fetched from the server per round trip. Larger values
    # trade memory for speed. The default is 100 rows. The effect is
    # transparent: more rows are fetched automatically whenever the cache
    # runs empty.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx. It is
    # possible that implementation changes in the future could cause this
    # parameter to be ignored.
    row_cache_size = 100

    def __init__(self, connection, statement, *types):
        self.c = connection.c
        # Names embed the protocol connection's and this object's ids so that
        # concurrent statements/portals on one connection do not collide.
        self._portal_name = "pg8000_portal_%s_%s" % (id(self.c), id(self))
        self._statement_name = "pg8000_statement_%s_%s" % (id(self.c), id(self))
        self._row_desc = None
        self._cached_rows = []
        self._command_complete = True
        # Only becomes True once parse() has returned, so __del__ can tell
        # whether a server-side statement exists to be closed.
        self._statement_parsed = False
        self._parse_row_desc = self.c.parse(self._statement_name, statement, types)
        self._statement_parsed = True
        self._lock = threading.RLock()

    def __del__(self):
        # This __del__ should work with garbage collection / non-instant
        # cleanup. It only really needs to be called right away if the same
        # object id (and therefore the same statement name) might be reused
        # soon, and clearly that wouldn't happen in a GC situation.
        # Skip the close when __init__ failed before/while parsing: there is
        # no statement to close (and self.c may not even exist).
        if getattr(self, "_statement_parsed", False):
            self.c.close_statement(self._statement_name)

    row_description = property(lambda self: self._getRowDescription())
    def _getRowDescription(self):
        # Field list of the most recent execute(), or None when there is none.
        if self._row_desc is None:
            return None
        return self._row_desc.fields

    ##
    # Run the SQL prepared statement with the given parameters.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def execute(self, *args):
        self._lock.acquire()
        try:
            if not self._command_complete:
                # Discard the unread remainder of the previous execution.
                self._cached_rows = []
                self.c.close_portal(self._portal_name)
            self._command_complete = False
            self._row_desc = self.c.bind(self._portal_name, self._statement_name, args, self._parse_row_desc)
            if self._row_desc:
                # We execute our cursor right away to fill up our cache. This
                # prevents the cursor from being destroyed, apparently, by a rogue
                # Sync between Bind and Execute. Since it is quite likely that
                # data will be read from us right away anyways, this seems a safe
                # move for now.
                self._fill_cache()
        finally:
            self._lock.release()

    def _fill_cache(self):
        # Fetch up to row_cache_size rows into the (empty) cache and note
        # whether the server reported the end of the data.
        self._lock.acquire()
        try:
            if self._cached_rows:
                raise InternalError("attempt to fill cache that isn't empty")
            end_of_data, rows = self.c.fetch_rows(self._portal_name, self.row_cache_size, self._row_desc)
            self._cached_rows = rows
            if end_of_data:
                self._command_complete = True
        finally:
            self._lock.release()

    def _fetch(self):
        # Pop the next cached row as a tuple, refilling the cache when needed;
        # None once all rows have been consumed.
        self._lock.acquire()
        try:
            if not self._cached_rows:
                if self._command_complete:
                    return None
                self._fill_cache()
                if self._command_complete and not self._cached_rows:
                    # fill cache tells us the command is complete, but yet we
                    # have no rows after filling our cache. This is a special
                    # case when a query returns no rows.
                    return None
            return tuple(self._cached_rows.pop(0))
        finally:
            self._lock.release()

    ##
    # Read a row from the database server, and return it in a dictionary
    # indexed by column name/alias. This method will raise an error if two
    # columns have the same name. Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_dict(self):
        row = self._fetch()
        if row is None:
            return None
        retval = {}
        for i, field in enumerate(self._row_desc.fields):
            col_name = field['name']
            if col_name in retval:
                raise InterfaceError("cannot return dict of row when two columns have the same name (%r)" % (col_name,))
            retval[col_name] = row[i]
        return retval

    ##
    # Read a row from the database server, and return it as a tuple of values.
    # Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_tuple(self):
        return self._fetch()

    ##
    # Return an iterator for the output of this statement. The iterator will
    # return a tuple for each row, in the same manner as {@link
    # #PreparedStatement.read_tuple read_tuple}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_tuple(self):
        return DataIterator(self, PreparedStatement.read_tuple)

    ##
    # Return an iterator for the output of this statement. The iterator will
    # return a dict for each row, in the same manner as {@link
    # #PreparedStatement.read_dict read_dict}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_dict(self):
        return DataIterator(self, PreparedStatement.read_dict)
|---|
| 579 |
|
|---|
| 580 |
##
# A Cursor lets several queries be in progress at once over a single
# PostgreSQL connection. Internally each execute() creates a {@link
# PreparedStatement PreparedStatement}, so a statement that will be run many
# times is better created as a PreparedStatement directly, saving a little
# re-parsing.
# <p>
# As of v1.01, instances of this class are thread-safe. See {@link
# PreparedStatement PreparedStatement} for more information.
# <p>
# Stability: Added in v1.00, stability guaranteed for v1.xx.
#
# @param connection An instance of {@link Connection Connection}.
class Cursor(object):
    def __init__(self, connection):
        self.connection = connection
        # Statement from the latest execute(); None until then.
        self._stmt = None

    def _getRowDescription(self):
        stmt = self._stmt
        if stmt is not None:
            return stmt.row_description
        return None
    row_description = property(lambda self: self._getRowDescription())

    def _require_stmt(self):
        # Shared guard for the read/iterate methods: reading before any
        # execute() is a usage error.
        if self._stmt is None:
            raise ProgrammingError("attempting to read from unexecuted cursor")
        return self._stmt

    ##
    # Execute an SQL statement on this cursor. Placeholders $1, $2, $3, etc.
    # in the statement are filled from the extra positional arguments.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    # @param query The SQL statement to execute.
    def execute(self, query, *args):
        # The statement is prepared with each argument's Python type.
        param_types = [type(arg) for arg in args]
        self._stmt = PreparedStatement(self.connection, query, *param_types)
        self._stmt.execute(*args)

    ##
    # Fetch the next row as a dictionary keyed by column name/alias. Raises
    # an error if two columns share a name. Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_dict(self):
        return self._require_stmt().read_dict()

    ##
    # Fetch the next row as a tuple of values. Returns None after the last
    # row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_tuple(self):
        return self._require_stmt().read_tuple()

    ##
    # Iterator over the remaining rows, yielding tuples as {@link
    # #PreparedStatement.read_tuple read_tuple} does.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_tuple(self):
        return self._require_stmt().iterate_tuple()

    ##
    # Iterator over the remaining rows, yielding dicts as {@link
    # #PreparedStatement.read_dict read_dict} does.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_dict(self):
        return self._require_stmt().iterate_dict()
|---|
| 657 |
|
|---|
| 658 |
## |
|---|
| 659 |
# This class represents a connection to a PostgreSQL database. |
|---|
| 660 |
# <p> |
|---|
| 661 |
# The database connection is derived from the {@link #Cursor Cursor} class, |
|---|
| 662 |
# which provides a default cursor for running queries. It also provides |
|---|
| 663 |
# transaction control via the 'begin', 'commit', and 'rollback' methods. |
|---|
| 664 |
# Without beginning a transaction explicitly, all statements will autocommit to |
|---|
| 665 |
# the database. |
|---|
| 666 |
# <p> |
|---|
| 667 |
# As of v1.01, instances of this class are thread-safe. See {@link |
|---|
| 668 |
# PreparedStatement PreparedStatement} for more information. |
|---|
| 669 |
# <p> |
|---|
| 670 |
# Stability: Added in v1.00, stability guaranteed for v1.xx. |
|---|
| 671 |
# |
|---|
| 672 |
# @param user The username to connect to the PostgreSQL server with. This |
|---|
| 673 |
# parameter is required. |
|---|
| 674 |
# |
|---|
| 675 |
# @keyparam host The hostname of the PostgreSQL server to connect with. |
|---|
| 676 |
# Providing this parameter is necessary for TCP/IP connections. One of either |
|---|
| 677 |
# host, or unix_sock, must be provided. |
|---|
| 678 |
# |
|---|
| 679 |
# @keyparam unix_sock The path to the UNIX socket to access the database |
|---|
| 680 |
# through, for example, '/tmp/.s.PGSQL.5432'. One of either unix_sock or host |
|---|
| 681 |
# must be provided. The port parameter will have no affect if unix_sock is |
|---|
| 682 |
# provided. |
|---|
| 683 |
# |
|---|
| 684 |
# @keyparam port The TCP/IP port of the PostgreSQL server instance. This |
|---|
| 685 |
# parameter defaults to 5432, the registered and common port of PostgreSQL |
|---|
| 686 |
# TCP/IP servers. |
|---|
| 687 |
# |
|---|
| 688 |
# @keyparam database The name of the database instance to connect with. This |
|---|
| 689 |
# parameter is optional, if omitted the PostgreSQL server will assume the |
|---|
| 690 |
# database name is the same as the username. |
|---|
| 691 |
# |
|---|
| 692 |
# @keyparam password The user password to connect to the server with. This |
|---|
| 693 |
# parameter is optional. If omitted, and the database server requests password |
|---|
| 694 |
# based authentication, the connection will fail. On the other hand, if this |
|---|
| 695 |
# parameter is provided and the database does not request password |
|---|
| 696 |
# authentication, then the password will not be used. |
|---|
| 697 |
# |
|---|
| 698 |
# @keyparam socket_timeout Socket connect timeout measured in seconds. |
|---|
| 699 |
# Defaults to 60 seconds. |
|---|
| 700 |
# |
|---|
| 701 |
# @keyparam ssl Use SSL encryption for TCP/IP socket. Defaults to False. |
|---|
| 702 |
class Connection(Cursor): |
|---|
| 703 |
def __init__(self, user, host=None, unix_sock=None, port=5432, database=None, password=None, socket_timeout=60, ssl=False): |
|---|
| 704 |
self._row_desc = None |
|---|
| 705 |
try: |
|---|
| 706 |
self.c = Protocol.Connection(unix_sock=unix_sock, host=host, port=port, socket_timeout=socket_timeout, ssl=ssl) |
|---|
| 707 |
#self.c.connect() |
|---|
| 708 |
self.c.authenticate(user, password=password, database=database) |
|---|
| 709 |
except socket.error, e: |
|---|
| 710 |
raise InterfaceError("communication error", e) |
|---|
| 711 |
Cursor.__init__(self, self) |
|---|
| 712 |
self._begin = PreparedStatement(self, "BEGIN TRANSACTION") |
|---|
| 713 |
self._commit = PreparedStatement(self, "COMMIT TRANSACTION") |
|---|
| 714 |
self._rollback = PreparedStatement(self, "ROLLBACK TRANSACTION") |
|---|
| 715 |
|
|---|
| 716 |
## |
|---|
| 717 |
# Begins a new transaction. |
|---|
| 718 |
# <p> |
|---|
| 719 |
# Stability: Added in v1.00, stability guaranteed for v1.xx. |
|---|
| 720 |
def begin(self): |
|---|
| 721 |
self._begin.execute() |
|---|
| 722 |
|
|---|
| 723 |
## |
|---|
| 724 |
# Commits the running transaction. |
|---|
| 725 |
# <p> |
|---|
| 726 |
# Stability: Added in v1.00, stability guaranteed for v1.xx. |
|---|
| 727 |
def commit(self): |
|---|
| 728 |
self._commit.execute() |
|---|
| 729 |
|
|---|
| 730 |
## |
|---|
| 731 |
# Rolls back the running transaction. |
|---|
| 732 |
# <p> |
|---|
| 733 |
# Stability: Added in v1.00, stability guaranteed for v1.xx. |
|---|
| 734 |
def rollback(self): |
|---|
| 735 |
self._rollback.execute() |
|---|
| 736 |
|
|---|
| 737 |
|
|---|
| 738 |
class Protocol(object): |
|---|
| 739 |
|
|---|
| 740 |
class SSLRequest(object):
    # Request code sent in place of a protocol version: 1234 in the high
    # 16 bits and 5679 in the low 16 bits, i.e. 80877103.
    _REQUEST_CODE = (1234 << 16) | 5679

    def serialize(self):
        # Message length (8, counting itself) followed by the request code;
        # this message has no leading type byte.
        return struct.pack("!ii", 8, self._REQUEST_CODE)
|---|
| 746 |
|
|---|
| 747 |
class StartupMessage(object):
    def __init__(self, user, database=None):
        self.user = user
        self.database = database

    def serialize(self):
        # Protocol version 3.0: major version 3 in the high 16 bits (196608).
        body = struct.pack("!i", 3 << 16)
        # NUL-terminated name/value pairs; database is only sent when truthy.
        settings = [("user", self.user)]
        if self.database:
            settings.append(("database", self.database))
        for name, setting in settings:
            body += name + "\x00" + setting + "\x00"
        # An empty name terminates the list.
        body += "\x00"
        # Int32 length prefix counts itself but there is no type byte.
        return struct.pack("!i", len(body) + 4) + body
|---|
| 761 |
|
|---|
| 762 |
class Query(object):
    def __init__(self, qs):
        self.qs = qs

    def serialize(self):
        # Simple-query message: 'Q', int32 length (including itself), then
        # the NUL-terminated query text.
        payload = self.qs + "\x00"
        return "Q" + struct.pack("!i", 4 + len(payload)) + payload
|---|
| 771 |
|
|---|
| 772 |
class Parse(object):
    def __init__(self, ps, qs, type_oids):
        self.ps = ps
        self.qs = qs
        self.type_oids = type_oids

    def serialize(self):
        # Statement name and query text, each NUL-terminated, followed by an
        # int16 count and one int32 OID per parameter.
        body = self.ps + "\x00" + self.qs + "\x00"
        body += struct.pack("!h", len(self.type_oids))
        for type_oid in self.type_oids:
            # Parse doesn't appear to accept the -1 OID used for NULL values
            # by other messages, so substitute 705, PG's "unknown" type.
            if type_oid == -1:
                type_oid = 705
            body += struct.pack("!i", type_oid)
        return "P" + struct.pack("!i", len(body) + 4) + body
|---|
| 790 |
|
|---|
| 791 |
class Bind(object):
    def __init__(self, portal, ps, in_fc, params, out_fc, client_encoding):
        self.portal = portal
        self.ps = ps
        self.in_fc = in_fc
        self.params = []
        for i, value in enumerate(params):
            # The input format-code list may be empty (every parameter uses
            # format 0, text), hold a single code shared by all parameters,
            # or give one code per parameter.
            if len(self.in_fc) == 0:
                fc = 0
            elif len(self.in_fc) == 1:
                fc = self.in_fc[0]
            else:
                fc = self.in_fc[i]
            self.params.append(Types.pg_value(value, fc, client_encoding = client_encoding))
        self.out_fc = out_fc

    def serialize(self):
        # Portal and statement names, NUL-terminated.
        val = self.portal + "\x00" + self.ps + "\x00"
        # int16 count of input format codes, then each as int16.
        val = val + struct.pack("!h", len(self.in_fc))
        for fc in self.in_fc:
            val = val + struct.pack("!h", fc)
        # int16 parameter count, then each value as int32 length + bytes.
        val = val + struct.pack("!h", len(self.params))
        for param in self.params:
            if param is None:
                # NULL is encoded as a length of -1 with no value bytes.
                val = val + struct.pack("!i", -1)
            else:
                val = val + struct.pack("!i", len(param)) + param
        # int16 count of result-column format codes, then each as int16.
        val = val + struct.pack("!h", len(self.out_fc))
        for fc in self.out_fc:
            val = val + struct.pack("!h", fc)
        val = struct.pack("!i", len(val) + 4) + val
        return "B" + val
|---|
| 825 |
|
|---|
| 826 |
class Close(object):
    def __init__(self, typ, name):
        # typ selects what is closed and must be exactly one character.
        if len(typ) != 1:
            raise InternalError("Close typ must be 1 char")
        self.typ, self.name = typ, name

    def serialize(self):
        # 'C', int32 length (including itself), typ byte, NUL-terminated name.
        payload = self.typ + self.name + "\x00"
        header = "C" + struct.pack("!i", len(payload) + 4)
        return header + payload
|---|
| 838 |
|
|---|
| 839 |
class ClosePortal(Close):
    # Close message whose target type is "P".
    def __init__(self, name):
        Protocol.Close.__init__(self, typ="P", name=name)
|---|
| 842 |
|
|---|
| 843 |
class ClosePreparedStatement(Close):
    # Close message whose target type is "S".
    def __init__(self, name):
        Protocol.Close.__init__(self, typ="S", name=name)
|---|
| 846 |
|
|---|
| 847 |
class Describe(object):
    def __init__(self, typ, name):
        # typ selects what is described and must be exactly one character.
        if len(typ) != 1:
            raise InternalError("Describe typ must be 1 char")
        self.typ, self.name = typ, name

    def serialize(self):
        # 'D', int32 length (including itself), typ byte, NUL-terminated name.
        payload = self.typ + self.name + "\x00"
        header = "D" + struct.pack("!i", len(payload) + 4)
        return header + payload
|---|
| 859 |
|
|---|
| 860 |
class DescribePortal(Describe):
    # Describe message whose target type is "P".
    def __init__(self, name):
        Protocol.Describe.__init__(self, typ="P", name=name)
|---|
| 863 |
|
|---|
| 864 |
class DescribePreparedStatement(Describe):
    # Describe message whose target type is "S".
    def __init__(self, name):
        Protocol.Describe.__init__(self, typ="S", name=name)
|---|
| 867 |
|
|---|
| 868 |
class Flush(object):
    # Fixed wire form: type byte 'H' and an int32 length of 4 (no payload).
    _WIRE = 'H\x00\x00\x00\x04'

    def serialize(self):
        return self._WIRE
|---|
| 871 |
|
|---|
| 872 |
class Sync(object):
    # Fixed wire form: type byte 'S' and an int32 length of 4 (no payload).
    _WIRE = 'S\x00\x00\x00\x04'

    def serialize(self):
        return self._WIRE
|---|
| 875 |
|
|---|
| 876 |
class PasswordMessage(object):
    def __init__(self, pwd):
        self.pwd = pwd

    def serialize(self):
        # 'p', int32 length (including itself), NUL-terminated password.
        body = self.pwd + "\x00"
        length = struct.pack("!i", 4 + len(body))
        return "p" + length + body
|---|
| 885 |
|
|---|
| 886 |
class Execute(object):
    def __init__(self, portal, row_count):
        self.portal = portal
        self.row_count = row_count

    def serialize(self):
        # NUL-terminated portal name followed by the int32 row limit.
        body = self.portal + "\x00" + struct.pack("!i", self.row_count)
        return "E" + struct.pack("!i", len(body) + 4) + body
|---|
| 896 |
|
|---|
| 897 |
class AuthenticationRequest(object):
    # Base class for authentication requests; the base ignores its data.
    def __init__(self, data):
        pass

    @staticmethod
    def createFromData(data):
        # The first int32 selects the handler class, which receives the rest.
        ident = struct.unpack("!i", data[:4])[0]
        klass = Protocol.authentication_codes.get(ident)
        if klass is None:
            raise NotSupportedError("authentication method %r not supported" % (ident,))
        return klass(data[4:])

    def ok(self, conn, user, **kwargs):
        # Subclasses must implement the method-specific exchange.
        raise InternalError("ok method should be overridden on AuthenticationRequest instance")
|---|
| 912 |
|
|---|
| 913 |
class AuthenticationOk(AuthenticationRequest):
    # Authentication has already succeeded; nothing is sent or read.
    def ok(self, conn, user, **kwargs):
        succeeded = True
        return succeeded
|---|
| 916 |
|
|---|
| 917 |
class AuthenticationMD5Password(AuthenticationRequest):
    def __init__(self, data):
        # The request carries exactly four salt bytes.
        self.salt = "".join(struct.unpack("4c", data))

    def ok(self, conn, user, password=None, **kwargs):
        if password is None:
            raise InterfaceError("server requesting MD5 password authentication, but no password was provided")
        # Response is "md5" + hex(md5(hex(md5(password + user)) + salt)).
        pwd = "md5" + md5.new(md5.new(password + user).hexdigest() + self.salt).hexdigest()
        conn._send(Protocol.PasswordMessage(pwd))
        msg = conn._read_message()
        if isinstance(msg, Protocol.AuthenticationRequest):
            return msg.ok(conn, user)
        elif isinstance(msg, Protocol.ErrorResponse):
            # 28000 is invalid_authorization_specification; PostgreSQL
            # reports a wrong password as 28P01 (invalid_password). Both
            # mean the credentials were rejected, not a protocol fault.
            if msg.code in ("28000", "28P01"):
                raise InterfaceError("md5 password authentication failed")
            else:
                raise InternalError("server returned unexpected error %r" % msg)
        else:
            raise InternalError("server returned unexpected response %r" % msg)
|---|
| 936 |
|
|---|
| 937 |
# Maps the int32 code at the start of an authentication request to the
# class that handles it; createFromData raises NotSupportedError for any
# code not listed here.
authentication_codes = dict([
    (0, AuthenticationOk),
    (5, AuthenticationMD5Password),
])
|---|
| 941 |
|
|---|
| 942 |
class ParameterStatus(object):
    def __init__(self, key, value):
        self.key = key
        self.value = value

    @staticmethod
    def createFromData(data):
        # Name runs up to the first NUL; value runs from just after it up to,
        # but excluding, the final byte (its NUL terminator).
        sep = data.find("\x00")
        return Protocol.ParameterStatus(data[:sep], data[sep + 1:-1])
|---|
| 952 |
|
|---|
| 953 |
class BackendKeyData(object):
    def __init__(self, process_id, secret_key):
        self.process_id = process_id
        self.secret_key = secret_key

    @staticmethod
    def createFromData(data):
        # Payload is two big-endian int32s: process id, then secret key.
        pid_and_key = struct.unpack("!2i", data)
        return Protocol.BackendKeyData(*pid_and_key)
|---|
| 962 |
|
|---|
| 963 |
class NoData(object):
    # The payload is ignored; the message type alone is meaningful.
    @staticmethod
    def createFromData(data):
        return Protocol.NoData()
|---|
| 967 |
|
|---|
| 968 |
class ParseComplete(object):
    # The payload is ignored; the message type alone is meaningful.
    @staticmethod
    def createFromData(data):
        return Protocol.ParseComplete()
|---|
| 972 |
|
|---|
| 973 |
class BindComplete(object):
    # The payload is ignored; the message type alone is meaningful.
    @staticmethod
    def createFromData(data):
        return Protocol.BindComplete()
|---|
| 977 |
|
|---|
| 978 |
class CloseComplete(object):
    # The payload is ignored; the message type alone is meaningful.
    @staticmethod
    def createFromData(data):
        return Protocol.CloseComplete()
|---|
| 982 |
|
|---|
| 983 |
class PortalSuspended(object):
    # The payload is ignored; the message type alone is meaningful.
    @staticmethod
    def createFromData(data):
        return Protocol.PortalSuspended()
|---|
| 987 |
|
|---|
| 988 |
class ReadyForQuery(object):
    # Readable names for the transaction-status indicator.
    _status_names = {
        "I": "Idle",
        "T": "Idle in Transaction",
        "E": "Idle in Failed Transaction",
    }

    def __init__(self, status):
        self.status = status

    def __repr__(self):
        # Fall back to the raw status for unrecognised values so that repr()
        # never raises: messages are %r-formatted into error text, and a
        # KeyError here would mask the original error.
        return "<ReadyForQuery %s>" % self._status_names.get(self.status, self.status)

    @staticmethod
    def createFromData(data):
        return Protocol.ReadyForQuery(data)
|---|
| 999 |
|
|---|
| 1000 |
class NoticeResponse(object):
    @staticmethod
    def createFromData(data):
        # Notice fields are discarded for now; only the arrival of a notice
        # is reported to the caller.
        return Protocol.NoticeResponse()
|---|
| 1007 |
|
|---|
| 1008 |
class ErrorResponse(object): |
|---|
| 1009 |
def __init__(self, severity, code, msg): |
|---|
| 1010 |
self.severity = severity |
|---|
| 1011 |
self.code = code |
|---|
| 1012 |
self.msg = msg |
|---|
| 1013 |
|
|---|
| 1014 |
def __repr__(self): |
|---|
| 1015 |
return "<ErrorResponse %s %s %r>" % (self.severity, self.code, self.msg) |
|---|
| 1016 |
|
|---|
| 1017 |
def createException(self): |
|---|
| 1018 |
return ProgrammingError(self.severity, self.code, self.msg) |
|---|
| 1019 |
|
|---|
| 1020 |
def createFromData(data): |
|---|
| 1021 |
args = {} |
|---|
| 1022 |
for s in data.split("\x00"): |
|---|
| 1023 |
if not s: |
|---|
| 1024 |
continue |
|---|
| 1025 |
elif s[0] == "S": |
|---|
| 1026 |
args["severity"] = s[1:] |
|---|
| 1027 |
elif s[0] == "C": |
|---|
| 1028 |
args["code"] = s[1:] |
|---|
| 1029 |
&nbs |
|---|