Files
web2py/gluon/contrib/pg8000/dbapi.py
T
2012-01-12 13:59:00 -06:00

796 lines
30 KiB
Python

# vim: sw=4:expandtab:foldmethod=marker
#
# Copyright (c) 2007-2009, Mathieu Fenniak
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * The name of the author may not be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
__author__ = "Mathieu Fenniak"
__version__ = "1.10"
import datetime
import time
import interface
import types
import threading
from errors import *
from warnings import warn
##
# The DBAPI level supported. Currently 2.0. This property is part of the
# DBAPI 2.0 specification.
apilevel = "2.0"
##
# Integer constant stating the level of thread safety the DBAPI interface
# supports. This DBAPI interface supports sharing of the module, connections,
# and cursors. This property is part of the DBAPI 2.0 specification.
threadsafety = 3
##
# String property stating the type of parameter marker formatting expected by
# the interface. This value defaults to "format". This property is part of
# the DBAPI 2.0 specification.
# <p>
# Unlike the DBAPI specification, this value is not constant. It can be
# changed to any standard paramstyle value (ie. qmark, numeric, named, format,
# and pyformat).
paramstyle = 'format' # paramstyle can be changed to any DB-API paramstyle
def _skip_escaped_percent(query, i):
    """Consume the second ``%`` of an escaped ``%%`` inside quoted text.

    Called right after a ``%`` found inside a quoted string or quoted
    identifier has been copied to the output; ``i`` is the index of the
    character that follows it.

    Returns the index at which scanning continues.  Raises
    QueryParameterParseError when the ``%`` is followed by anything other
    than another ``%``.  A ``%`` that is the last character of the query is
    accepted as-is.
    """
    if i < len(query):
        if query[i] != "%":
            raise QueryParameterParseError("'%" + query[i] + "' not supported in quoted string")
        i += 1
    return i


def convert_paramstyle(src_style, query, args):
    """Rewrite a query written in a DB-API paramstyle into ``$1, $2, ...`` form.

    src_style -- one of "qmark", "numeric", "named", "format", "pyformat".
    query     -- the SQL text to convert.
    args      -- a sequence (qmark/numeric/format) or a mapping
                 (named/pyformat) of parameter values.

    Returns ``(converted_query, args_tuple)`` where ``args_tuple`` lists the
    values in the order of the ``$n`` placeholders.  A named parameter that
    appears several times is bound once and its ``$n`` is reused.

    Raises QueryParameterIndexError when the query has more positional
    placeholders than ``args`` has values, and QueryParameterParseError for a
    malformed placeholder.  A missing mapping key propagates as whatever
    ``args[name]`` raises.

    Placeholders are only recognised outside quoted text.  Inside quoted
    strings and identifiers the text is copied verbatim, except that for the
    "format" and "pyformat" styles ``%%`` collapses to ``%`` and any other
    ``%x`` is rejected.
    """
    # The query is scanned character by character with a small state machine:
    # 0 -- outside quoted string
    # 1 -- inside single-quote string '...'
    # 2 -- inside quoted identifier "..."
    # 3 -- inside escaped single-quote string, E'...'
    state = 0
    percent_styles = ("pyformat", "format")
    pieces = []          # output fragments, joined once at the end
    output_args = []
    mapping_to_idx = {}  # parameter name -> 1-based placeholder number
    if src_style == "numeric":
        # numeric placeholders already carry their index; pass args through
        output_args = args
    query_len = len(query)

    def positional_placeholder():
        # Bind the next unused positional argument and return its "$n".
        param_idx = len(output_args)
        if param_idx == len(args):
            raise QueryParameterIndexError("too many parameter fields, not enough parameters")
        output_args.append(args[param_idx])
        return "$" + str(param_idx + 1)

    def named_placeholder(name):
        # Bind a mapping argument on first use; later uses share its "$n".
        idx = mapping_to_idx.get(name)
        if idx is None:
            output_args.append(args[name])
            idx = len(output_args)
            mapping_to_idx[name] = idx
        return "$" + str(idx)

    i = 0
    while i < query_len:
        c = query[i]
        if state == 0:
            if c == "'":
                i += 1
                pieces.append(c)
                state = 1
            elif c == '"':
                i += 1
                pieces.append(c)
                state = 2
            elif c == 'E':
                # check for escaped single-quote string
                i += 1
                if i < query_len and query[i] == "'":
                    i += 1
                    pieces.append("E'")
                    state = 3
                else:
                    # plain letter; the next character is handled on the
                    # following iteration
                    pieces.append(c)
            elif src_style == "qmark" and c == "?":
                i += 1
                pieces.append(positional_placeholder())
            elif src_style == "numeric" and c == ":":
                i += 1
                if i < query_len and query[i].isdigit():
                    # Only the first digit is consumed here; any further
                    # digits are copied as ordinary characters right after
                    # it, so ":12" still comes out as "$12".
                    pieces.append("$" + query[i])
                    i += 1
                else:
                    raise QueryParameterParseError("numeric parameter : does not have numeric arg")
            elif src_style == "named" and c == ":":
                end = i + 1
                while end < query_len and (query[end].isalnum() or query[end] == '_'):
                    end += 1
                name = query[i + 1:end]
                # continue scanning at the character that ended the name
                i = end
                if name == "":
                    raise QueryParameterParseError("empty name of named parameter")
                pieces.append(named_placeholder(name))
            elif src_style == "format" and c == "%":
                i += 1
                if i >= query_len:
                    raise QueryParameterParseError("format parameter % does not have format code")
                if query[i] == "s":
                    pieces.append(positional_placeholder())
                elif query[i] == "%":
                    pieces.append("%")
                else:
                    raise QueryParameterParseError("Only %s and %% are supported")
                i += 1
            elif src_style == "pyformat" and c == "%":
                i += 1
                if i >= query_len:
                    # a dangling '%' used to be dropped silently; reject it
                    # the same way the "format" style does
                    raise QueryParameterParseError("format parameter % does not have format code")
                if query[i] == "(":
                    i += 1
                    # begin mapping name
                    end_idx = query.find(')', i)
                    if end_idx == -1:
                        raise QueryParameterParseError("began pyformat dict read, but couldn't find end of name")
                    name = query[i:end_idx]
                    i = end_idx + 1
                    if i < query_len and query[i] == "s":
                        i += 1
                        pieces.append(named_placeholder(name))
                    else:
                        raise QueryParameterParseError("format not specified or not supported (only %(...)s supported)")
                elif query[i] == "%":
                    pieces.append("%")
                    # step past the second '%' so it is not re-read as the
                    # start of another format specifier
                    i += 1
                elif query[i] == "s":
                    # we have a %s in a pyformat query string.  Assume
                    # support for format instead: back up onto the '%' and
                    # re-scan it under the "format" rules.
                    i -= 1
                    src_style = "format"
                else:
                    raise QueryParameterParseError("Only %(name)s, %s and %% are supported")
            else:
                i += 1
                pieces.append(c)
        elif state == 1:
            pieces.append(c)
            i += 1
            if c == "'":
                # Could be a double ''
                if i < query_len and query[i] == "'":
                    # is a double quote.
                    pieces.append(query[i])
                    i += 1
                else:
                    state = 0
            elif c == "%" and src_style in percent_styles:
                # only an escaped percent sign is supported in quoted text
                i = _skip_escaped_percent(query, i)
        elif state == 2:
            pieces.append(c)
            i += 1
            if c == '"':
                state = 0
            elif c == "%" and src_style in percent_styles:
                i = _skip_escaped_percent(query, i)
        elif state == 3:
            pieces.append(c)
            i += 1
            if c == "\\":
                # A backslash escapes a following quote or backslash.  The
                # escaped backslash must be consumed here, otherwise the
                # closing quote of E'...\\' is mistaken for an escaped quote.
                if i < query_len and query[i] in ("'", "\\"):
                    pieces.append(query[i])
                    i += 1
            elif c == "'":
                if i < query_len and query[i] == "'":
                    # '' is an escaped quote in E'...' strings as well
                    pieces.append(query[i])
                    i += 1
                else:
                    state = 0
            elif c == "%" and src_style in percent_styles:
                i = _skip_escaped_percent(query, i)
    return "".join(pieces), tuple(output_args)
def require_open_cursor(fn):
    """Decorate a method so that it refuses to run on a closed cursor.

    The wrapper raises CursorClosedError when ``self.cursor`` is None and
    otherwise forwards every argument to ``fn`` and returns its result.  The
    wrapped function's name and docstring are preserved.
    """
    # imported locally so the decorator does not depend on a module-level import
    import functools

    @functools.wraps(fn)
    def _fn(self, *args, **kwargs):
        if self.cursor is None:
            raise CursorClosedError()
        return fn(self, *args, **kwargs)
    return _fn
##
# The class of object returned by the {@link #ConnectionWrapper.cursor cursor method}.
class CursorWrapper(object):
    """DB-API 2.0 cursor built on top of an ``interface.Cursor``.

    ``self.cursor`` holds the wrapped low-level cursor; close() sets it to
    None, after which every method guarded by ``require_open_cursor`` raises
    CursorClosedError.
    """

    def __init__(self, conn, connection):
        # conn is handed to interface.Cursor; connection is the wrapper
        # object exposed through the ``connection`` property and used for
        # begin()/rollback().
        self.cursor = interface.Cursor(conn)
        self.arraysize = 1
        self._connection = connection
        # When not None this value is reported by ``rowcount`` instead of the
        # wrapped cursor's row_count (executemany() accumulates into it).
        self._override_rowcount = None

    ##
    # This read-only attribute returns a reference to the connection object on
    # which the cursor was created.
    # <p>
    # Stability: Part of a DBAPI 2.0 extension.  A warning "DB-API extension
    # cursor.connection used" will be fired.
    connection = property(lambda self: self._getConnection())

    def _getConnection(self):
        # stacklevel=3 skips this method and the property lambda so the
        # warning is attributed to the code that read the attribute.
        warn("DB-API extension cursor.connection used", stacklevel=3)
        return self._connection

    ##
    # This read-only attribute specifies the number of rows that the last
    # .execute*() produced (for DQL statements like 'select') or affected (for
    # DML statements like 'update' or 'insert').
    # <p>
    # The attribute is -1 in case no .execute*() has been performed on the
    # cursor or the rowcount of the last operation cannot be determined by
    # the interface.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    rowcount = property(lambda self: self._getRowCount())

    @require_open_cursor
    def _getRowCount(self):
        if self._override_rowcount is not None:
            return self._override_rowcount
        return self.cursor.row_count

    ##
    # This read-only attribute is a sequence of 7-item sequences.  Each value
    # contains information describing one result column.  The 7 items returned
    # for each column are (name, type_code, display_size, internal_size,
    # precision, scale, null_ok).  Only the first two values are provided by
    # this interface implementation.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    description = property(lambda self: self._getDescription())

    @require_open_cursor
    def _getDescription(self):
        if self.cursor.row_description is None:
            return None
        return [(col["name"], col["type_oid"], None, None, None, None, None)
                for col in self.cursor.row_description]

    ##
    # Executes a database operation.  Parameters may be provided as a sequence
    # or mapping and will be bound to variables in the operation.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_cursor
    def execute(self, operation, args=()):
        if not self._connection.in_transaction:
            self._connection.begin()
        self._override_rowcount = None
        self._execute(operation, args)

    def _execute(self, operation, args=()):
        """Convert ``operation`` from the module-level paramstyle and run it.

        Any failure other than ConnectionClosedError rolls the connection
        back before the exception is re-raised.
        """
        new_query, new_args = convert_paramstyle(paramstyle, operation, args)
        try:
            self.cursor.execute(new_query, *new_args)
        except ConnectionClosedError:
            # can't rollback in this case
            raise
        except:
            # any error will rollback the transaction to-date
            self._connection.rollback()
            raise

    ##
    # Runs a COPY ... FROM stdout statement, streaming its data from fileobj.
    # Either a complete query or a table name must be given; otherwise
    # CopyQueryOrTableRequiredError is raised.
    # <p>
    # NOTE: table, sep and null are interpolated into the statement verbatim;
    # they are not quoted or escaped here.
    def copy_from(self, fileobj, table=None, sep='\t', null=None, query=None):
        if query is None:
            if table is None:
                raise CopyQueryOrTableRequiredError()
            query = "COPY %s FROM stdout DELIMITER '%s'" % (table, sep)
            if null is not None:
                query += " NULL '%s'" % (null,)
        self.copy_execute(fileobj, query)

    ##
    # Runs a COPY ... TO stdout statement, streaming its output to fileobj.
    # Either a complete query or a table name must be given; otherwise
    # CopyQueryOrTableRequiredError is raised.
    # <p>
    # NOTE: table, sep and null are interpolated into the statement verbatim;
    # they are not quoted or escaped here.
    def copy_to(self, fileobj, table=None, sep='\t', null=None, query=None):
        if query is None:
            if table is None:
                raise CopyQueryOrTableRequiredError()
            query = "COPY %s TO stdout DELIMITER '%s'" % (table, sep)
            if null is not None:
                query += " NULL '%s'" % (null,)
        self.copy_execute(fileobj, query)

    ##
    # Executes a COPY query with fileobj as the data stream.  Errors are
    # handled exactly as in _execute: roll back, then re-raise.
    @require_open_cursor
    def copy_execute(self, fileobj, query):
        try:
            self.cursor.execute(query, stream=fileobj)
        except ConnectionClosedError:
            # can't rollback in this case
            raise
        except:
            # any error will rollback the transaction to-date; the exception
            # itself is re-raised, so nothing is printed here
            self._connection.rollback()
            raise

    ##
    # Prepare a database operation and then execute it against all parameter
    # sequences or mappings provided.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_cursor
    def executemany(self, operation, parameter_sets):
        if not self._connection.in_transaction:
            self._connection.begin()
        self._override_rowcount = 0
        for parameters in parameter_sets:
            self._execute(operation, parameters)
            # one unknown count (-1) makes the accumulated total unknown too
            if self.cursor.row_count == -1 or self._override_rowcount == -1:
                self._override_rowcount = -1
            else:
                self._override_rowcount += self.cursor.row_count

    ##
    # Fetch the next row of a query result set, returning a single sequence, or
    # None when no more data is available.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_cursor
    def fetchone(self):
        return self.cursor.read_tuple()

    ##
    # Fetch the next set of rows of a query result, returning a sequence of
    # sequences.  An empty sequence is returned when no more rows are
    # available.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    # @param size The number of rows to fetch when called.  If not provided,
    # the arraysize property value is used instead.
    def fetchmany(self, size=None):
        if size is None:
            size = self.arraysize
        rows = []
        for _ in range(size):
            value = self.fetchone()
            if value is None:
                break
            rows.append(value)
        return rows

    ##
    # Fetch all remaining rows of a query result, returning them as a sequence
    # of sequences.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_cursor
    def fetchall(self):
        return tuple(self.cursor.iterate_tuple())

    ##
    # Close the cursor.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_cursor
    def close(self):
        self.cursor.close()
        self.cursor = None
        self._override_rowcount = None

    def next(self):
        "Return the next row, raising StopIteration when there is none"
        warn("DB-API extension cursor.next() used", stacklevel=2)
        retval = self.fetchone()
        if retval is None:
            raise StopIteration()
        return retval

    def __iter__(self):
        warn("DB-API extension cursor.__iter__() used", stacklevel=2)
        return self

    def setinputsizes(self, sizes):
        "Accepted and ignored"
        pass

    def setoutputsize(self, size, column=None):
        "Accepted and ignored"
        pass

    @require_open_cursor
    def fileno(self):
        return self.cursor.fileno()

    @require_open_cursor
    def isready(self):
        return self.cursor.isready()
def require_open_connection(fn):
    """Decorate a method so that it refuses to run on a closed connection.

    The wrapper raises ConnectionClosedError when ``self.conn`` is None and
    otherwise forwards every argument to ``fn`` and returns its result.  The
    wrapped function's name and docstring are preserved.
    """
    # imported locally so the decorator does not depend on a module-level import
    import functools

    @functools.wraps(fn)
    def _fn(self, *args, **kwargs):
        if self.conn is None:
            raise ConnectionClosedError()
        return fn(self, *args, **kwargs)
    return _fn
##
# The class of object returned by the {@link #connect connect method}.
class ConnectionWrapper(object):
    """DB-API 2.0 connection built on top of an ``interface.Connection``.

    ``self.conn`` holds the wrapped low-level connection; close() sets it to
    None, after which every method guarded by ``require_open_connection``
    raises ConnectionClosedError.
    """

    # DBAPI Extension: supply exceptions as attributes on the connection
    Warning = property(lambda self: self._getError(Warning))
    Error = property(lambda self: self._getError(Error))
    InterfaceError = property(lambda self: self._getError(InterfaceError))
    DatabaseError = property(lambda self: self._getError(DatabaseError))
    OperationalError = property(lambda self: self._getError(OperationalError))
    IntegrityError = property(lambda self: self._getError(IntegrityError))
    InternalError = property(lambda self: self._getError(InternalError))
    ProgrammingError = property(lambda self: self._getError(ProgrammingError))
    NotSupportedError = property(lambda self: self._getError(NotSupportedError))

    def _getError(self, error):
        # stacklevel=3 skips this method and the property lambda so the
        # warning is attributed to the code that read the attribute.
        warn("DB-API extension connection.%s used" % error.__name__, stacklevel=3)
        return error

    @property
    def in_transaction(self):
        "True while the wrapped connection reports an open transaction; False once closed"
        if self.conn:
            return self.conn.in_transaction
        return False

    def __init__(self, **kwargs):
        # All keyword arguments are forwarded to interface.Connection.
        self.conn = interface.Connection(**kwargs)
        # (backend_pid, condition) pairs appended by _notificationReceived;
        # guarded by notifies_lock.
        self.notifies = []
        self.notifies_lock = threading.Lock()
        self.conn.NotificationReceived += self._notificationReceived
        # Two Phase Commit internal attributes:
        self.__tpc_xid = None
        self.__tpc_prepared = None

    def set_autocommit(self, state):
        "Set autocommit on the wrapped connection, warning if enabled mid-transaction"
        if self.conn.in_transaction and state and not self.conn.autocommit:
            warn("enabling autocommit in an open transaction!")
        self.conn.autocommit = state

    def get_autocommit(self):
        "Return the wrapped connection's autocommit setting"
        return self.conn.autocommit

    autocommit = property(get_autocommit, set_autocommit)

    @require_open_connection
    def begin(self):
        self.conn.begin()

    def _notificationReceived(self, notice):
        # psycopg2 compatible notification interface
        # Acquire before entering the try block so that the finally clause
        # never releases a lock this call did not obtain.
        self.notifies_lock.acquire()
        try:
            self.notifies.append((notice.backend_pid, notice.condition))
        finally:
            self.notifies_lock.release()

    ##
    # Creates a {@link #CursorWrapper CursorWrapper} object bound to this
    # connection.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_connection
    def cursor(self):
        return CursorWrapper(self.conn, self)

    ##
    # Commits the current database transaction.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_connection
    def commit(self):
        # There's a threading bug here.  If a query is sent after the
        # commit, but before the begin, it will be executed immediately
        # without a surrounding transaction.  Like all threading bugs -- it
        # sounds unlikely, until it happens every time in one
        # application...  however, to fix this, we need to lock the
        # database connection entirely, so that no cursors can execute
        # statements on other threads.  Support for that type of lock will
        # be done later.
        if self.__tpc_xid:
            raise ProgrammingError("Cannot do a normal commit() inside a "
                                   "TPC transaction!")
        self.conn.commit()

    ##
    # Rolls back the current database transaction.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_connection
    def rollback(self):
        # see bug description in commit.
        if self.__tpc_xid:
            raise ProgrammingError("Cannot do a normal rollback() inside a "
                                   "TPC transaction!")
        self.conn.rollback()

    ##
    # Closes the database connection.
    # <p>
    # Stability: Part of the DBAPI 2.0 specification.
    @require_open_connection
    def close(self):
        self.conn.close()
        self.conn = None

    ##
    # Returns the "server_version" string provided by the connected server.
    # <p>
    # Stability: Extension of the DBAPI 2.0 specification.
    @property
    @require_open_connection
    def server_version(self):
        return self.conn.server_version()

    # Stability: psycopg2 compatibility
    @require_open_connection
    def set_client_encoding(self, encoding=None):
        "Set the client encoding for the current session"
        # NOTE: encoding is interpolated into the statement verbatim.
        if encoding:
            self.conn.execute("SET client_encoding TO '%s';" % (encoding, ), simple_query=True)
        return self.conn.encoding()

    def xid(self, format_id, global_transaction_id, branch_qualifier):
        """Create a transaction ID (only global_transaction_id is used in pg)
        format_id and branch_qualifier are not used in postgres
        global_transaction_id may be any string identifier supported by postgres
        returns a tuple (format_id, global_transaction_id, branch_qualifier)"""
        return (format_id, global_transaction_id, branch_qualifier)

    def _tpc_is_pending(self, xid):
        "Tell whether xid names a transaction currently listed by tpc_recover()"
        # Only the global transaction id (element 1) is ever sent to the
        # server, and tpc_recover() rebuilds every xid as (0, gid, ''), so the
        # comparison has to be made on the gid; comparing whole tuples would
        # reject any xid created with another format_id or branch_qualifier.
        pending_gids = [pending[1] for pending in self.tpc_recover()]
        return xid[1] in pending_gids

    @require_open_connection
    def tpc_begin(self, xid):
        "Begin a two-phase transaction"
        # set auto-commit mode to begin a TPC transaction
        self.autocommit = False
        # (actually in postgres at this point it is a normal one)
        if self.conn.in_transaction:
            warn("tpc_begin() should be called outside a transaction block",
                 stacklevel=3)
        self.conn.begin()
        # store actual TPC transaction id
        self.__tpc_xid = xid
        self.__tpc_prepared = False

    @require_open_connection
    def tpc_prepare(self):
        "Prepare a two-phase transaction"
        if not self.__tpc_xid:
            raise ProgrammingError("tpc_prepare() outside a TPC transaction "
                                   "is not allowed!")
        # Prepare the TPC
        self.conn.execute("PREPARE TRANSACTION '%s';" % (self.__tpc_xid[1],),
                          simple_query=True)
        self.conn.in_transaction = False
        self.__tpc_prepared = True

    @require_open_connection
    def tpc_commit(self, xid=None):
        """Commit a two-phase transaction

        Without xid the current TPC transaction is committed: with COMMIT
        PREPARED if tpc_prepare() was called, otherwise as a normal
        single-phase commit.  With xid, that (recovered) prepared transaction
        is committed.  Raises ProgrammingError if there is no transaction to
        commit or the requested xid is not prepared."""
        try:
            # save current autocommit status (to be recovered later)
            previous_autocommit_mode = self.autocommit
            if not xid:
                # use current tpc transaction
                tpc_xid = self.__tpc_xid
            else:
                # use a recovered tpc transaction
                tpc_xid = xid
                if not self._tpc_is_pending(xid):
                    raise ProgrammingError("Requested TPC transaction is not "
                                           "prepared!")
            if not tpc_xid:
                raise ProgrammingError("Cannot tpc_commit() without a TPC "
                                       "transaction!")
            two_phase = self.__tpc_prepared or (xid != self.__tpc_xid and xid)
            try:
                if two_phase:
                    # a two-phase commit:
                    # set the auto-commit mode for TPC commit
                    self.autocommit = True
                    self.conn.execute("COMMIT PREPARED '%s';" % (tpc_xid[1], ),
                                      simple_query=True)
                else:
                    # a single-phase commit
                    self.conn.commit()
            finally:
                # return to previous auto-commit mode
                self.autocommit = previous_autocommit_mode
        finally:
            # transaction is done, clear xid
            self.__tpc_xid = None

    @require_open_connection
    def tpc_rollback(self, xid=None):
        """Roll back a two-phase transaction

        Without xid the current TPC transaction is rolled back: with ROLLBACK
        PREPARED if tpc_prepare() was called, otherwise as a normal
        single-phase rollback.  With xid, that (recovered) prepared
        transaction is rolled back.  Raises ProgrammingError if there is no
        transaction to roll back or the requested xid is not prepared."""
        try:
            # save current autocommit status (to be recovered later)
            previous_autocommit_mode = self.autocommit
            if not xid:
                # use current tpc transaction
                tpc_xid = self.__tpc_xid
            else:
                # use a recovered tpc transaction
                tpc_xid = xid
                if not self._tpc_is_pending(xid):
                    raise ProgrammingError("Requested TPC transaction is not prepared!")
            if not tpc_xid:
                raise ProgrammingError("Cannot tpc_rollback() without a TPC prepared transaction!")
            two_phase = self.__tpc_prepared or (xid != self.__tpc_xid and xid)
            try:
                if two_phase:
                    # a two-phase rollback
                    # set auto-commit for the TPC rollback
                    self.autocommit = True
                    self.conn.execute("ROLLBACK PREPARED '%s';" % (tpc_xid[1],),
                                      simple_query=True)
                else:
                    # a single-phase rollback
                    self.conn.rollback()
            finally:
                # return to previous auto-commit mode
                self.autocommit = previous_autocommit_mode
        finally:
            # transaction is done, clear xid
            self.__tpc_xid = None

    @require_open_connection
    def tpc_recover(self):
        "Returns a list of pending transaction IDs"
        previous_autocommit_mode = self.autocommit
        if not self.conn.in_transaction and not self.autocommit:
            self.autocommit = True
        elif not self.autocommit:
            warn("tpc_recover() will open a transaction block", stacklevel=3)
        curs = self.cursor()
        xids = []
        try:
            # query system view that stores open (prepared) TPC transactions
            curs.execute("SELECT gid FROM pg_prepared_xacts;")
            # Read with fetchone() rather than by iterating the cursor: the
            # iteration protocol emits "DB-API extension" warnings that the
            # caller of tpc_recover() did nothing to deserve.
            row = curs.fetchone()
            while row is not None:
                xids.append(self.xid(0, row[0], ''))
                row = curs.fetchone()
        finally:
            curs.close()
            # return to previous auto-commit mode
            self.autocommit = previous_autocommit_mode
        # return a list of TPC transaction ids (xid)
        return xids
##
# Creates a DBAPI 2.0 compatible interface to a PostgreSQL database.
# <p>
# Stability: Part of the DBAPI 2.0 specification.
#
# @param user The username to connect to the PostgreSQL server with. This
# parameter is required.
#
# @keyparam host The hostname of the PostgreSQL server to connect with.
# Providing this parameter is necessary for TCP/IP connections. One of either
# host, or unix_sock, must be provided.
#
# @keyparam unix_sock The path to the UNIX socket to access the database
# through, for example, '/tmp/.s.PGSQL.5432'. One of either unix_sock or host
# must be provided. The port parameter will have no effect if unix_sock is
# provided.
#
# @keyparam port The TCP/IP port of the PostgreSQL server instance. This
# parameter defaults to 5432, the registered and common port of PostgreSQL
# TCP/IP servers.
#
# @keyparam database The name of the database instance to connect with. This
# parameter is optional, if omitted the PostgreSQL server will assume the
# database name is the same as the username.
#
# @keyparam password The user password to connect to the server with. This
# parameter is optional. If omitted, and the database server requests password
# based authentication, the connection will fail. On the other hand, if this
# parameter is provided and the database does not request password
# authentication, then the password will not be used.
#
# @keyparam socket_timeout Socket connect timeout measured in seconds.
# Defaults to 60 seconds.
#
# @keyparam ssl Use SSL encryption for TCP/IP socket. Defaults to False.
#
# @return An instance of {@link #ConnectionWrapper ConnectionWrapper}.
def connect(dsn="", user=None, host=None, unix_sock=None, port=5432, database=None, password=None, socket_timeout=60, ssl=False):
    """Open a connection and return it wrapped in a ConnectionWrapper.

    Every argument is passed through unchanged, by keyword, to the
    ConnectionWrapper constructor.
    """
    settings = dict(
        dsn=dsn,
        user=user,
        host=host,
        unix_sock=unix_sock,
        port=port,
        database=database,
        password=password,
        socket_timeout=socket_timeout,
        ssl=ssl,
    )
    return ConnectionWrapper(**settings)
def Date(year, month, day):
    """Build a ``datetime.date`` from its year, month and day."""
    value = datetime.date(year=year, month=month, day=day)
    return value
def Time(hour, minute, second):
    """Build a ``datetime.time`` from its hour, minute and second."""
    value = datetime.time(hour=hour, minute=minute, second=second)
    return value
def Timestamp(year, month, day, hour, minute, second):
    """Build a ``datetime.datetime`` from its six calendar and clock fields."""
    value = datetime.datetime(
        year=year, month=month, day=day,
        hour=hour, minute=minute, second=second,
    )
    return value
def DateFromTicks(ticks):
    """Build a date from seconds since the epoch, interpreted in local time."""
    local = time.localtime(ticks)
    return datetime.date(local.tm_year, local.tm_mon, local.tm_mday)
def TimeFromTicks(ticks):
    """Build a time of day from seconds since the epoch, in local time."""
    local = time.localtime(ticks)
    return datetime.time(local.tm_hour, local.tm_min, local.tm_sec)
def TimestampFromTicks(ticks):
    """Build a datetime from seconds since the epoch, in local time."""
    local = time.localtime(ticks)
    return datetime.datetime(
        local.tm_year, local.tm_mon, local.tm_mday,
        local.tm_hour, local.tm_min, local.tm_sec,
    )
##
# Construct an object holding binary data.
def Binary(value):
    """Wrap ``value`` in the ``Bytea`` type from this module's ``types`` import."""
    bytea_type = types.Bytea
    return bytea_type(value)
# Type-code constants.  Each one is a single PostgreSQL type OID, meant to be
# compared with the type_code (second item) of the entries in a cursor's
# description.
#
# I have no idea what this would be used for by a client app.  Should it be
# TEXT, VARCHAR, CHAR?  It will only compare against row_description's
# type_code if it is this one type.  It is the varchar type oid for now, this
# appears to match expectations in the DB API 2.0 compliance test suite.
STRING = 1043
# bytea type_oid
BINARY = 17
# numeric type_oid
NUMBER = 1700
# timestamp type_oid
DATETIME = 1114
# oid type_oid
ROWID = 26