diff --git a/VERSION b/VERSION index 2e52af03..e5836eea 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -Version 1.99.4 (2012-01-12 13:22:21) stable +Version 1.99.4 (2012-01-12 13:58:45) stable diff --git a/gluon/contrib/pg8000/__init__.py b/gluon/contrib/pg8000/__init__.py new file mode 100644 index 00000000..57de8e80 --- /dev/null +++ b/gluon/contrib/pg8000/__init__.py @@ -0,0 +1,37 @@ +# vim: sw=4:expandtab:foldmethod=marker +# +# Copyright (c) 2007-2009, Mathieu Fenniak +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * The name of the author may not be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +__author__ = "Mathieu Fenniak" + +import dbapi as DBAPI +pg8000_dbapi = DBAPI + +from interface import * +from types import Bytea + diff --git a/gluon/contrib/pg8000/dbapi.py b/gluon/contrib/pg8000/dbapi.py new file mode 100644 index 00000000..4518b192 --- /dev/null +++ b/gluon/contrib/pg8000/dbapi.py @@ -0,0 +1,795 @@ +# vim: sw=4:expandtab:foldmethod=marker +# +# Copyright (c) 2007-2009, Mathieu Fenniak +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * The name of the author may not be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +__author__ = "Mathieu Fenniak" + +__version__ = "1.10" + +import datetime +import time +import interface +import types +import threading +from errors import * + +from warnings import warn + +## +# The DBAPI level supported. Currently 2.0. This property is part of the +# DBAPI 2.0 specification. +apilevel = "2.0" + +## +# Integer constant stating the level of thread safety the DBAPI interface +# supports. This DBAPI interface supports sharing of the module, connections, +# and cursors. This property is part of the DBAPI 2.0 specification. +threadsafety = 3 + +## +# String property stating the type of parameter marker formatting expected by +# the interface. This value defaults to "format". This property is part of +# the DBAPI 2.0 specification. +#

+# Unlike the DBAPI specification, this value is not constant. It can be +# changed to any standard paramstyle value (ie. qmark, numeric, named, format, +# and pyformat). +paramstyle = 'format' # paramstyle can be changed to any DB-API paramstyle + +def convert_paramstyle(src_style, query, args): + # I don't see any way to avoid scanning the query string char by char, + # so we might as well take that careful approach and create a + # state-based scanner. We'll use int variables for the state. + # 0 -- outside quoted string + # 1 -- inside single-quote string '...' + # 2 -- inside quoted identifier "..." + # 3 -- inside escaped single-quote string, E'...' + state = 0 + output_query = "" + output_args = [] + if src_style == "numeric": + output_args = args + elif src_style in ("pyformat", "named"): + mapping_to_idx = {} + i = 0 + while 1: + if i == len(query): + break + c = query[i] + # print "begin loop", repr(i), repr(c), repr(state) + if state == 0: + if c == "'": + i += 1 + output_query += c + state = 1 + elif c == '"': + i += 1 + output_query += c + state = 2 + elif c == 'E': + # check for escaped single-quote string + i += 1 + if i < len(query) and i > 1 and query[i] == "'": + i += 1 + output_query += "E'" + state = 3 + else: + output_query += c + elif src_style == "qmark" and c == "?": + i += 1 + param_idx = len(output_args) + if param_idx == len(args): + raise QueryParameterIndexError("too many parameter fields, not enough parameters") + output_args.append(args[param_idx]) + output_query += "$" + str(param_idx + 1) + elif src_style == "numeric" and c == ":": + i += 1 + if i < len(query) and i > 1 and query[i].isdigit(): + output_query += "$" + query[i] + i += 1 + else: + raise QueryParameterParseError("numeric parameter : does not have numeric arg") + elif src_style == "named" and c == ":": + name = "" + while 1: + i += 1 + if i == len(query): + break + c = query[i] + if c.isalnum() or c == '_': + name += c + else: + break + if name == "": + raise QueryParameterParseError("empty name of named parameter") + idx = mapping_to_idx.get(name) + if idx == None: + idx = len(output_args) + output_args.append(args[name]) + idx += 1 + mapping_to_idx[name] = idx + output_query += "$" + str(idx) + elif src_style == "format" and c == "%": + i += 1 + if i < len(query) and i > 1: + if query[i] == "s": + param_idx = len(output_args) + if param_idx == len(args): + raise QueryParameterIndexError("too many parameter fields, not enough parameters") + output_args.append(args[param_idx]) + output_query += "$" + str(param_idx + 1) + elif query[i] == "%": + output_query += "%" + else: + raise QueryParameterParseError("Only %s and %% are supported") + i += 1 + else: + raise QueryParameterParseError("format parameter % does not have format code") + elif src_style == "pyformat" and c == "%": + i += 1 + if i < len(query) and i > 1: + if query[i] == "(": + i += 1 + # begin mapping name + end_idx = query.find(')', i) + if end_idx == -1: + raise QueryParameterParseError("began pyformat dict read, but couldn't find end of name") + else: + name = query[i:end_idx] + i = end_idx + 1 + if i < len(query) and query[i] == "s": + i += 1 + idx = mapping_to_idx.get(name) + if idx == None: + idx = len(output_args) + output_args.append(args[name]) + idx += 1 + mapping_to_idx[name] = idx + output_query += "$" + str(idx) + else: + raise QueryParameterParseError("format not specified or not supported (only %(...)s supported)") + elif query[i] == "%": + output_query += "%" + elif query[i] == "s": + # we have a %s in a pyformat query string. Assume + # support for format instead. + i -= 1 + src_style = "format" + else: + raise QueryParameterParseError("Only %(name)s, %s and %% are supported") + else: + i += 1 + output_query += c + elif state == 1: + output_query += c + i += 1 + if c == "'": + # Could be a double '' + if i < len(query) and query[i] == "'": + # is a double quote. + output_query += query[i] + i += 1 + else: + state = 0 + elif src_style in ("pyformat","format") and c == "%": + # hm... we're only going to support an escaped percent sign + if i < len(query): + if query[i] == "%": + # good. We already output the first percent sign. + i += 1 + else: + raise QueryParameterParseError("'%" + query[i] + "' not supported in quoted string") + elif state == 2: + output_query += c + i += 1 + if c == '"': + state = 0 + elif src_style in ("pyformat","format") and c == "%": + # hm... we're only going to support an escaped percent sign + if i < len(query): + if query[i] == "%": + # good. We already output the first percent sign. + i += 1 + else: + raise QueryParameterParseError("'%" + query[i] + "' not supported in quoted string") + elif state == 3: + output_query += c + i += 1 + if c == "\\": + # check for escaped single-quote + if i < len(query) and query[i] == "'": + output_query += "'" + i += 1 + elif c == "'": + state = 0 + elif src_style in ("pyformat","format") and c == "%": + # hm... we're only going to support an escaped percent sign + if i < len(query): + if query[i] == "%": + # good. We already output the first percent sign. + i += 1 + else: + raise QueryParameterParseError("'%" + query[i] + "' not supported in quoted string") + + return output_query, tuple(output_args) + + +def require_open_cursor(fn): + def _fn(self, *args, **kwargs): + if self.cursor == None: + raise CursorClosedError() + return fn(self, *args, **kwargs) + return _fn + +## +# The class of object returned by the {@link #ConnectionWrapper.cursor cursor method}. +class CursorWrapper(object): + def __init__(self, conn, connection): + self.cursor = interface.Cursor(conn) + self.arraysize = 1 + self._connection = connection + self._override_rowcount = None + + ## + # This read-only attribute returns a reference to the connection object on + # which the cursor was created. + #

+ # Stability: Part of a DBAPI 2.0 extension. A warning "DB-API extension + # cursor.connection used" will be fired. + connection = property(lambda self: self._getConnection()) + + def _getConnection(self): + warn("DB-API extension cursor.connection used", stacklevel=3) + return self._connection + + ## + # This read-only attribute specifies the number of rows that the last + # .execute*() produced (for DQL statements like 'select') or affected (for + # DML statements like 'update' or 'insert'). + #

+ # The attribute is -1 in case no .execute*() has been performed on the + # cursor or the rowcount of the last operation is cannot be determined by + # the interface. + #

+ # Stability: Part of the DBAPI 2.0 specification. + rowcount = property(lambda self: self._getRowCount()) + + @require_open_cursor + def _getRowCount(self): + if self._override_rowcount != None: + return self._override_rowcount + return self.cursor.row_count + + ## + # This read-only attribute is a sequence of 7-item sequences. Each value + # contains information describing one result column. The 7 items returned + # for each column are (name, type_code, display_size, internal_size, + # precision, scale, null_ok). Only the first two values are provided by + # this interface implementation. + #

+ # Stability: Part of the DBAPI 2.0 specification. + description = property(lambda self: self._getDescription()) + + @require_open_cursor + def _getDescription(self): + if self.cursor.row_description == None: + return None + columns = [] + for col in self.cursor.row_description: + columns.append((col["name"], col["type_oid"], None, None, None, None, None)) + return columns + + ## + # Executes a database operation. Parameters may be provided as a sequence + # or mapping and will be bound to variables in the operation. + #

+ # Stability: Part of the DBAPI 2.0 specification. + @require_open_cursor + def execute(self, operation, args=()): + if not self._connection.in_transaction: + self._connection.begin() + self._override_rowcount = None + self._execute(operation, args) + + def _execute(self, operation, args=()): + new_query, new_args = convert_paramstyle(paramstyle, operation, args) + try: + self.cursor.execute(new_query, *new_args) + except ConnectionClosedError: + # can't rollback in this case + raise + except: + # any error will rollback the transaction to-date + self._connection.rollback() + raise + + def copy_from(self, fileobj, table=None, sep='\t', null=None, query=None): + if query == None: + if table == None: + raise CopyQueryOrTableRequiredError() + query = "COPY %s FROM stdout DELIMITER '%s'" % (table, sep) + if null is not None: + query += " NULL '%s'" % (null,) + self.copy_execute(fileobj, query) + + def copy_to(self, fileobj, table=None, sep='\t', null=None, query=None): + if query == None: + if table == None: + raise CopyQueryOrTableRequiredError() + query = "COPY %s TO stdout DELIMITER '%s'" % (table, sep) + if null is not None: + query += " NULL '%s'" % (null,) + self.copy_execute(fileobj, query) + + @require_open_cursor + def copy_execute(self, fileobj, query): + try: + self.cursor.execute(query, stream=fileobj) + except ConnectionClosedError: + # can't rollback in this case + raise + except: + # any error will rollback the transaction to-date + import traceback; traceback.print_exc() + self._connection.rollback() + raise + + ## + # Prepare a database operation and then execute it against all parameter + # sequences or mappings provided. + #

+ # Stability: Part of the DBAPI 2.0 specification. + @require_open_cursor + def executemany(self, operation, parameter_sets): + if not self._connection.in_transaction: + self._connection.begin() + self._override_rowcount = 0 + for parameters in parameter_sets: + self._execute(operation, parameters) + if self.cursor.row_count == -1 or self._override_rowcount == -1: + self._override_rowcount = -1 + else: + self._override_rowcount += self.cursor.row_count + + ## + # Fetch the next row of a query result set, returning a single sequence, or + # None when no more data is available. + #

+ # Stability: Part of the DBAPI 2.0 specification. + @require_open_cursor + def fetchone(self): + return self.cursor.read_tuple() + + ## + # Fetch the next set of rows of a query result, returning a sequence of + # sequences. An empty sequence is returned when no more rows are + # available. + #

+ # Stability: Part of the DBAPI 2.0 specification. + # @param size The number of rows to fetch when called. If not provided, + # the arraysize property value is used instead. + def fetchmany(self, size=None): + if size == None: + size = self.arraysize + rows = [] + for i in range(size): + value = self.fetchone() + if value == None: + break + rows.append(value) + return rows + + ## + # Fetch all remaining rows of a query result, returning them as a sequence + # of sequences. + #

+ # Stability: Part of the DBAPI 2.0 specification. + @require_open_cursor + def fetchall(self): + return tuple(self.cursor.iterate_tuple()) + + ## + # Close the cursor. + #

+ # Stability: Part of the DBAPI 2.0 specification. + @require_open_cursor + def close(self): + self.cursor.close() + self.cursor = None + self._override_rowcount = None + + def next(self): + warn("DB-API extension cursor.next() used", stacklevel=2) + retval = self.fetchone() + if retval == None: + raise StopIteration() + return retval + + def __iter__(self): + warn("DB-API extension cursor.__iter__() used", stacklevel=2) + return self + + def setinputsizes(self, sizes): + pass + + def setoutputsize(self, size, column=None): + pass + + @require_open_cursor + def fileno(self): + return self.cursor.fileno() + + @require_open_cursor + def isready(self): + return self.cursor.isready() + +def require_open_connection(fn): + def _fn(self, *args, **kwargs): + if self.conn == None: + raise ConnectionClosedError() + return fn(self, *args, **kwargs) + return _fn + +## +# The class of object returned by the {@link #connect connect method}. +class ConnectionWrapper(object): + # DBAPI Extension: supply exceptions as attributes on the connection + Warning = property(lambda self: self._getError(Warning)) + Error = property(lambda self: self._getError(Error)) + InterfaceError = property(lambda self: self._getError(InterfaceError)) + DatabaseError = property(lambda self: self._getError(DatabaseError)) + OperationalError = property(lambda self: self._getError(OperationalError)) + IntegrityError = property(lambda self: self._getError(IntegrityError)) + InternalError = property(lambda self: self._getError(InternalError)) + ProgrammingError = property(lambda self: self._getError(ProgrammingError)) + NotSupportedError = property(lambda self: self._getError(NotSupportedError)) + + def _getError(self, error): + warn("DB-API extension connection.%s used" % error.__name__, stacklevel=3) + return error + + @property + def in_transaction(self): + if self.conn: + return self.conn.in_transaction + return False + + def __init__(self, **kwargs): + self.conn = interface.Connection(**kwargs) + self.notifies = [] + self.notifies_lock = threading.Lock() + self.conn.NotificationReceived += self._notificationReceived + # Two Phase Commit internal attributes: + self.__tpc_xid = None + self.__tpc_prepared = None + + def set_autocommit(self, state): + if self.conn.in_transaction and state and not self.conn.autocommit: + warn("enabling autocommit in an open transaction!") + self.conn.autocommit = state + + def get_autocommit(self): + return self.conn.autocommit + + autocommit = property(get_autocommit, set_autocommit) + + @require_open_connection + def begin(self): + self.conn.begin() + + def _notificationReceived(self, notice): + try: + # psycopg2 compatible notification interface + self.notifies_lock.acquire() + self.notifies.append((notice.backend_pid, notice.condition)) + finally: + self.notifies_lock.release() + + ## + # Creates a {@link #CursorWrapper CursorWrapper} object bound to this + # connection. + #

+ # Stability: Part of the DBAPI 2.0 specification. + @require_open_connection + def cursor(self): + return CursorWrapper(self.conn, self) + + ## + # Commits the current database transaction. + #

+ # Stability: Part of the DBAPI 2.0 specification. + @require_open_connection + def commit(self): + # There's a threading bug here. If a query is sent after the + # commit, but before the begin, it will be executed immediately + # without a surrounding transaction. Like all threading bugs -- it + # sounds unlikely, until it happens every time in one + # application... however, to fix this, we need to lock the + # database connection entirely, so that no cursors can execute + # statements on other threads. Support for that type of lock will + # be done later. + if self.__tpc_xid: + raise ProgrammingError("Cannot do a normal commit() inside a " + "TPC transaction!") + self.conn.commit() + + ## + # Rolls back the current database transaction. + #

+ # Stability: Part of the DBAPI 2.0 specification. + @require_open_connection + def rollback(self): + # see bug description in commit. + if self.__tpc_xid: + raise ProgrammingError("Cannot do a normal rollback() inside a " + "TPC transaction!") + self.conn.rollback() + + ## + # Closes the database connection. + #

+ # Stability: Part of the DBAPI 2.0 specification. + @require_open_connection + def close(self): + self.conn.close() + self.conn = None + + ## + # Returns the "server_version" string provided by the connected server. + #

+ # Stability: Extension of the DBAPI 2.0 specification. + @property + @require_open_connection + def server_version(self): + return self.conn.server_version() + + # Stability: psycopg2 compatibility + @require_open_connection + def set_client_encoding(self, encoding=None): + "Set the client encoding for the current session" + if encoding: + self.conn.execute("SET client_encoding TO '%s';" % (encoding, ), simple_query=True) + return self.conn.encoding() + + + def xid(self,format_id, global_transaction_id, branch_qualifier): + """Create a Transaction IDs (only global_transaction_id is used in pg) + format_id and branch_qualifier are not used in postgres + global_transaction_id may be any string identifier supported by postgres + returns a tuple (format_id, global_transaction_id, branch_qualifier)""" + return (format_id, global_transaction_id, branch_qualifier) + + @require_open_connection + def tpc_begin(self,xid): + "Begin a two-phase transaction" + # set auto-commit mode to begin a TPC transaction + self.autocommit = False + # (actually in postgres at this point it is a normal one) + if self.conn.in_transaction: + warn("tpc_begin() should be called outside a transaction block", + stacklevel=3) + self.conn.begin() + # store actual TPC transaction id + self.__tpc_xid = xid + self.__tpc_prepared = False + + @require_open_connection + def tpc_prepare(self): + "Prepare a two-phase transaction" + if not self.__tpc_xid: + raise ProgrammingError("tpc_prepare() outside a TPC transaction " + "is not allowed!") + # Prepare the TPC + self.conn.execute("PREPARE TRANSACTION '%s';" % (self.__tpc_xid[1],), + simple_query=True) + self.conn.in_transaction = False + self.__tpc_prepared = True + + @require_open_connection + def tpc_commit(self, xid=None): + "Commit a prepared two-phase transaction" + try: + # save current autocommit status (to be recovered later) + previous_autocommit_mode = self.autocommit + if not xid: + # use current tpc transaction + tpc_xid = self.__tpc_xid + else: + # use a recovered tpc transaction + tpc_xid = xid + if not xid in self.tpc_recover(): + raise ProgrammingError("Requested TPC transaction is not " + "prepared!") + if not tpc_xid: + raise ProgrammingError("Cannot tpc_commit() without a TPC " + "transaction!") + if self.__tpc_prepared or (xid != self.__tpc_xid and xid): + # a two-phase commit: + # set the auto-commit mode for TPC commit + self.autocommit = True + try: + self.conn.execute("COMMIT PREPARED '%s';" % (tpc_xid[1], ), + simple_query=True) + finally: + # return to previous auto-commit mode + self.autocommit = previous_autocommit_mode + else: + try: + # a single-phase commit + self.conn.commit() + finally: + # return to previous auto-commit mode + self.autocommit = previous_autocommit_mode + finally: + # transaction is done, clear xid + self.__tpc_xid = None + + @require_open_connection + def tpc_rollback(self, xid=None): + "Commit a prepared two-phase transaction" + try: + # save current autocommit status (to be recovered later) + previous_autocommit_mode = self.autocommit + if not xid: + # use current tpc transaction + tpc_xid = self.__tpc_xid + else: + # use a recovered tpc transaction + tpc_xid = xid + if not xid in self.tpc_recover(): + raise ProgrammingError("Requested TPC transaction is not prepared!") + if not tpc_xid: + raise ProgrammingError("Cannot tpc_rollback() without a TPC prepared transaction!") + if self.__tpc_prepared or (xid != self.__tpc_xid and xid): + # a two-phase rollback + # set auto-commit for the TPC rollback + self.autocommit = True + try: + self.conn.execute("ROLLBACK PREPARED '%s';" % (tpc_xid[1],), + simple_query=True) + finally: + # return to previous auto-commit mode + self.autocommit = previous_autocommit_mode + else: + # a single-phase rollback + try: + self.conn.rollback() + finally: + # return to previous auto-commit mode + self.autocommit = previous_autocommit_mode + finally: + # transaction is done, clear xid + self.__tpc_xid = None + + @require_open_connection + def tpc_recover(self): + "Returns a list of pending transaction IDs" + previous_autocommit_mode = self.autocommit + if not self.conn.in_transaction and not self.autocommit: + self.autocommit = True + elif not self.autocommit: + warn("tpc_recover() will open a transaction block", stacklevel=3) + + curs = self.cursor() + xids = [] + try: + # query system view that stores open (prepared) TPC transactions + curs.execute("SELECT gid FROM pg_prepared_xacts;"); + xids.extend([self.xid(0,row[0],'') for row in curs]) + finally: + curs.close() + # return to previous auto-commit mode + self.autocommit = previous_autocommit_mode + # return a list of TPC transaction ids (xid) + return xids + + +## +# Creates a DBAPI 2.0 compatible interface to a PostgreSQL database. +#

+# Stability: Part of the DBAPI 2.0 specification. +# +# @param user The username to connect to the PostgreSQL server with. This +# parameter is required. +# +# @keyparam host The hostname of the PostgreSQL server to connect with. +# Providing this parameter is necessary for TCP/IP connections. One of either +# host, or unix_sock, must be provided. +# +# @keyparam unix_sock The path to the UNIX socket to access the database +# through, for example, '/tmp/.s.PGSQL.5432'. One of either unix_sock or host +# must be provided. The port parameter will have no affect if unix_sock is +# provided. +# +# @keyparam port The TCP/IP port of the PostgreSQL server instance. This +# parameter defaults to 5432, the registered and common port of PostgreSQL +# TCP/IP servers. +# +# @keyparam database The name of the database instance to connect with. This +# parameter is optional, if omitted the PostgreSQL server will assume the +# database name is the same as the username. +# +# @keyparam password The user password to connect to the server with. This +# parameter is optional. If omitted, and the database server requests password +# based authentication, the connection will fail. On the other hand, if this +# parameter is provided and the database does not request password +# authentication, then the password will not be used. +# +# @keyparam socket_timeout Socket connect timeout measured in seconds. +# Defaults to 60 seconds. +# +# @keyparam ssl Use SSL encryption for TCP/IP socket. Defaults to False. +# +# @return An instance of {@link #ConnectionWrapper ConnectionWrapper}. +def connect(dsn="", user=None, host=None, unix_sock=None, port=5432, database=None, password=None, socket_timeout=60, ssl=False): + return ConnectionWrapper(dsn=dsn, user=user, host=host, + unix_sock=unix_sock, port=port, database=database, + password=password, socket_timeout=socket_timeout, ssl=ssl) + +def Date(year, month, day): + return datetime.date(year, month, day) + +def Time(hour, minute, second): + return datetime.time(hour, minute, second) + +def Timestamp(year, month, day, hour, minute, second): + return datetime.datetime(year, month, day, hour, minute, second) + +def DateFromTicks(ticks): + return Date(*time.localtime(ticks)[:3]) + +def TimeFromTicks(ticks): + return Time(*time.localtime(ticks)[3:6]) + +def TimestampFromTicks(ticks): + return Timestamp(*time.localtime(ticks)[:6]) + +## +# Construct an object holding binary data. +def Binary(value): + return types.Bytea(value) + +# I have no idea what this would be used for by a client app. Should it be +# TEXT, VARCHAR, CHAR? It will only compare against row_description's +# type_code if it is this one type. It is the varchar type oid for now, this +# appears to match expectations in the DB API 2.0 compliance test suite. +STRING = 1043 + +# bytea type_oid +BINARY = 17 + +# numeric type_oid +NUMBER = 1700 + +# timestamp type_oid +DATETIME = 1114 + +# oid type_oid +ROWID = 26 + + diff --git a/gluon/contrib/pg8000/errors.py b/gluon/contrib/pg8000/errors.py new file mode 100644 index 00000000..b8b5acfb --- /dev/null +++ b/gluon/contrib/pg8000/errors.py @@ -0,0 +1,115 @@ +# vim: sw=4:expandtab:foldmethod=marker +# +# Copyright (c) 2007-2009, Mathieu Fenniak +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * The name of the author may not be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +__author__ = "Mathieu Fenniak" + +class Warning(StandardError): + pass + +class Error(StandardError): + pass + +class InterfaceError(Error): + pass + +class ConnectionClosedError(InterfaceError): + def __init__(self): + InterfaceError.__init__(self, "connection is closed") + +class CursorClosedError(InterfaceError): + def __init__(self): + InterfaceError.__init__(self, "cursor is closed") + +class DatabaseError(Error): + pass + +class DataError(DatabaseError): + pass + +class OperationalError(DatabaseError): + pass + +class IntegrityError(DatabaseError): + pass + +class InternalError(DatabaseError): + pass + +class ProgrammingError(DatabaseError): + pass + +class NotSupportedError(DatabaseError): + pass + +## +# An exception that is thrown when an internal error occurs trying to +# decode binary array data from the server. +class ArrayDataParseError(InternalError): + pass + +## +# Thrown when attempting to transmit an array of unsupported data types. +class ArrayContentNotSupportedError(NotSupportedError): + pass + +## +# Thrown when attempting to send an array that doesn't contain all the same +# type of objects (eg. some floats, some ints). +class ArrayContentNotHomogenousError(ProgrammingError): + pass + +## +# Attempted to pass an empty array in, but it's not possible to determine the +# data type for an empty array. +class ArrayContentEmptyError(ProgrammingError): + pass + +## +# Attempted to use a multidimensional array with inconsistent array sizes. +class ArrayDimensionsNotConsistentError(ProgrammingError): + pass + +# A cursor's copy_to or copy_from argument was not provided a table or query +# to operate on. +class CopyQueryOrTableRequiredError(ProgrammingError): + pass + +# Raised if a COPY query is executed without using copy_to or copy_from +# functions to provide a data stream. +class CopyQueryWithoutStreamError(ProgrammingError): + pass + +# When query parameters don't match up with query args. +class QueryParameterIndexError(ProgrammingError): + pass + +# Some sort of parse error occured during query parameterization. +class QueryParameterParseError(ProgrammingError): + pass + diff --git a/gluon/contrib/pg8000/interface.py b/gluon/contrib/pg8000/interface.py new file mode 100644 index 00000000..e3ccda00 --- /dev/null +++ b/gluon/contrib/pg8000/interface.py @@ -0,0 +1,660 @@ +# vim: sw=4:expandtab:foldmethod=marker +# +# Copyright (c) 2007-2009, Mathieu Fenniak +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * The name of the author may not be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +__author__ = "Mathieu Fenniak" + +import socket +import protocol +import threading +from errors import * + +def conninfo_parse(conninfo): + "Conninfo parser routine based on libpq conninfo_parse" + options = {} + buf = conninfo + " " + tmp = pname = "" + quoted_string = False + cp = 0 + while cp < len(buf): + # Skip blanks before the parameter name + c = buf[cp] + if c.isspace() and tmp and not quoted_string and pname: + options[pname] = tmp + tmp = pname = "" + elif c == "'": + quoted_string = not quoted_string + elif c == '\\': + cp += 1 + tmp += buf[cp] + elif c == "=": + if not tmp: + raise RuntimeError("missing parameter name (conninfo:%s)" % cp) + pname = tmp + tmp = "" + elif not c.isspace() or quoted_string: + tmp += c + cp += 1 + if quoted_string: + raise RuntimeError("unterminated quoted string (conninfo:%s)" % cp) + return options + +class DataIterator(object): + def __init__(self, obj, func): + self.obj = obj + self.func = func + + def __iter__(self): + return self + + def next(self): + retval = self.func(self.obj) + if retval == None: + raise StopIteration() + return retval + +statement_number_lock = threading.Lock() +statement_number = 0 + +## +# This class represents a prepared statement. A prepared statement is +# pre-parsed on the server, which reduces the need to parse the query every +# time it is run. The statement can have parameters in the form of $1, $2, $3, +# etc. When parameters are used, the types of the parameters need to be +# specified when creating the prepared statement. +#

+# As of v1.01, instances of this class are thread-safe. This means that a +# single PreparedStatement can be accessed by multiple threads without the +# internal consistency of the statement being altered. However, the +# responsibility is on the client application to ensure that one thread reading +# from a statement isn't affected by another thread starting a new query with +# the same statement. +#

+# Stability: Added in v1.00, stability guaranteed for v1.xx. +# +# @param connection An instance of {@link Connection Connection}. +# +# @param statement The SQL statement to be represented, often containing +# parameters in the form of $1, $2, $3, etc. +# +# @param types Python type objects for each parameter in the SQL +# statement. For example, int, float, str. +class PreparedStatement(object): + + ## + # Determines the number of rows to read from the database server at once. + # Reading more rows increases performance at the cost of memory. The + # default value is 100 rows. The affect of this parameter is transparent. + # That is, the library reads more rows when the cache is empty + # automatically. + #

+ # Stability: Added in v1.00, stability guaranteed for v1.xx. It is + # possible that implementation changes in the future could cause this + # parameter to be ignored. + row_cache_size = 100 + + def __init__(self, connection, statement, *types, **kwargs): + global statement_number + if connection == None or connection.c == None: + raise InterfaceError("connection not provided") + try: + statement_number_lock.acquire() + self._statement_number = statement_number + statement_number += 1 + finally: + statement_number_lock.release() + self.c = connection.c + self._portal_name = None + self._statement_name = kwargs.get("statement_name", "pg8000_statement_%s" % self._statement_number) + self._row_desc = None + self._cached_rows = [] + self._ongoing_row_count = 0 + self._command_complete = True + self._parse_row_desc = self.c.parse(self._statement_name, statement, types) + self._lock = threading.RLock() + + def close(self): + if self._statement_name != "": # don't close unnamed statement + self.c.close_statement(self._statement_name) + if self._portal_name != None: + self.c.close_portal(self._portal_name) + self._portal_name = None + + row_description = property(lambda self: self._getRowDescription()) + def _getRowDescription(self): + if self._row_desc == None: + return None + return self._row_desc.fields + + ## + # Run the SQL prepared statement with the given parameters. + #

+ # Stability: Added in v1.00, stability guaranteed for v1.xx. + def execute(self, *args, **kwargs): + self._lock.acquire() + try: + if not self._command_complete: + # cleanup last execute + self._cached_rows = [] + self._ongoing_row_count = 0 + if self._portal_name != None: + self.c.close_portal(self._portal_name) + self._command_complete = False + self._portal_name = "pg8000_portal_%s" % self._statement_number + self._row_desc, cmd = self.c.bind(self._portal_name, self._statement_name, args, self._parse_row_desc, kwargs.get("stream")) + if self._row_desc: + # We execute our cursor right away to fill up our cache. This + # prevents the cursor from being destroyed, apparently, by a rogue + # Sync between Bind and Execute. Since it is quite likely that + # data will be read from us right away anyways, this seems a safe + # move for now. + self._fill_cache() + else: + self._command_complete = True + self._ongoing_row_count = -1 + if cmd != None and cmd.rows != None: + self._ongoing_row_count = cmd.rows + finally: + self._lock.release() + + def _fill_cache(self): + self._lock.acquire() + try: + if self._cached_rows: + raise InternalError("attempt to fill cache that isn't empty") + end_of_data, rows = self.c.fetch_rows(self._portal_name, self.row_cache_size, self._row_desc) + self._cached_rows = rows + if end_of_data: + self._command_complete = True + finally: + self._lock.release() + + def _fetch(self): + if not self._row_desc: + raise ProgrammingError("no result set") + self._lock.acquire() + try: + if not self._cached_rows: + if self._command_complete: + return None + self._fill_cache() + if self._command_complete and not self._cached_rows: + # fill cache tells us the command is complete, but yet we have + # no rows after filling our cache. This is a special case when + # a query returns no rows. + return None + row = self._cached_rows.pop(0) + self._ongoing_row_count += 1 + return tuple(row) + finally: + self._lock.release() + + ## + # Return a count of the number of rows relevant to the executed statement. + # For a SELECT, this is the number of rows returned. For UPDATE or DELETE, + # this the number of rows affected. For INSERT, the number of rows + # inserted. This property may have a value of -1 to indicate that there + # was no row count. + #

+ # During a result-set query (eg. SELECT, or INSERT ... RETURNING ...), + # accessing this property requires reading the entire result-set into + # memory, as reading the data to completion is the only way to determine + # the total number of rows. Avoid using this property in with + # result-set queries, as it may cause unexpected memory usage. + #

+ # Stability: Added in v1.03, stability guaranteed for v1.xx. + row_count = property(lambda self: self._get_row_count()) + def _get_row_count(self): + self._lock.acquire() + try: + if not self._command_complete: + end_of_data, rows = self.c.fetch_rows(self._portal_name, 0, self._row_desc) + self._cached_rows += rows + if end_of_data: + self._command_complete = True + else: + raise InternalError("fetch_rows(0) did not hit end of data") + return self._ongoing_row_count + len(self._cached_rows) + finally: + self._lock.release() + + ## + # Read a row from the database server, and return it in a dictionary + # indexed by column name/alias. This method will raise an error if two + # columns have the same name. Returns None after the last row. + #

+ # Stability: Added in v1.00, stability guaranteed for v1.xx. + def read_dict(self): + row = self._fetch() + if row == None: + return row + retval = {} + for i in range(len(self._row_desc.fields)): + col_name = self._row_desc.fields[i]['name'] + if retval.has_key(col_name): + raise InterfaceError("cannot return dict of row when two columns have the same name (%r)" % (col_name,)) + retval[col_name] = row[i] + return retval + + ## + # Read a row from the database server, and return it as a tuple of values. + # Returns None after the last row. + #

+ # Stability: Added in v1.00, stability guaranteed for v1.xx. + def read_tuple(self): + return self._fetch() + + ## + # Return an iterator for the output of this statement. The iterator will + # return a tuple for each row, in the same manner as {@link + # #PreparedStatement.read_tuple read_tuple}. + #

+ # Stability: Added in v1.00, stability guaranteed for v1.xx. + def iterate_tuple(self): + return DataIterator(self, PreparedStatement.read_tuple) + + ## + # Return an iterator for the output of this statement. The iterator will + # return a dict for each row, in the same manner as {@link + # #PreparedStatement.read_dict read_dict}. + #

+ # Stability: Added in v1.00, stability guaranteed for v1.xx. + def iterate_dict(self): + return DataIterator(self, PreparedStatement.read_dict) + + +class SimpleStatement(PreparedStatement): + "Internal wrapper to Simple Query protocol emulating a PreparedStatement" + + # This should be used internally only for trivial queries + # (not a true Prepared Statement, in fact it can have multiple statements) + # See Simple Query Protocol limitations and trade-offs (send_simple_query) + + row_cache_size = None + + def __init__(self, connection, statement): + if connection == None or connection.c == None: + raise InterfaceError("connection not provided") + self.c = connection.c + self._row_desc = None + self._cached_rows = [] + self._ongoing_row_count = -1 + self._command_complete = True + self.statement = statement + self._lock = threading.RLock() + + def close(self): + # simple query doesn't have portals + pass + + def execute(self, *args, **kwargs): + "Run the SQL simple query stataments" + self._lock.acquire() + try: + self._row_desc, cmd_complete, self._cached_rows = \ + self.c.send_simple_query(self.statement, kwargs.get("stream")) + self._command_complete = True + self._ongoing_row_count = -1 + if cmd_complete is not None and cmd_complete.rows is not None: + self._ongoing_row_count = cmd_complete.rows + finally: + self._lock.release() + + def _fill_cache(self): + # data rows are already fetched in _cached_rows + pass + + def _fetch(self): + if not self._row_desc: + raise ProgrammingError("no result set") + self._lock.acquire() + try: + if not self._cached_rows: + return None + row = self._cached_rows.pop(0) + return tuple(row) + finally: + self._lock.release() + + def _get_row_count(self): + return self._ongoing_row_count + + +## +# The Cursor class allows multiple queries to be performed concurrently with a +# single PostgreSQL connection. The Cursor object is implemented internally by +# using a {@link PreparedStatement PreparedStatement} object, so if you plan to +# use a statement multiple times, you might as well create a PreparedStatement +# and save a small amount of reparsing time. +#

+# As of v1.01, instances of this class are thread-safe. See {@link +# PreparedStatement PreparedStatement} for more information. +#

+# Stability: Added in v1.00, stability guaranteed for v1.xx. +# +# @param connection An instance of {@link Connection Connection}. +class Cursor(object): + def __init__(self, connection): + self.connection = connection + self._stmt = None + + def require_stmt(func): + def retval(self, *args, **kwargs): + if self._stmt == None: + raise ProgrammingError("attempting to use unexecuted cursor") + return func(self, *args, **kwargs) + return retval + + row_description = property(lambda self: self._getRowDescription()) + def _getRowDescription(self): + if self._stmt == None: + return None + return self._stmt.row_description + + ## + # Run an SQL statement using this cursor. The SQL statement can have + # parameters in the form of $1, $2, $3, etc., which will be filled in by + # the additional arguments passed to this function. + #

+ # Stability: Added in v1.00, stability guaranteed for v1.xx. + # @param query The SQL statement to execute. + def execute(self, query, *args, **kwargs): + if self.connection.is_closed: + raise ConnectionClosedError() + self.connection._unnamed_prepared_statement_lock.acquire() + try: + if kwargs.get("simple_query"): + # no arguments and no statement name, + # use PostgreSQL Simple Query Protocol + ## print "SimpleQuery:", query + self._stmt = SimpleStatement(self.connection, query) + else: + # use PostgreSQL Extended Query Protocol + self._stmt = PreparedStatement(self.connection, query, statement_name="", *[{"type": type(x), "value": x} for x in args]) + self._stmt.execute(*args, **kwargs) + finally: + self.connection._unnamed_prepared_statement_lock.release() + + ## + # Return a count of the number of rows currently being read. If possible, + # please avoid using this function. It requires reading the entire result + # set from the database to determine the number of rows being returned. + #

+ # Stability: Added in v1.03, stability guaranteed for v1.xx. + # Implementation currently requires caching entire result set into memory, + # avoid using this property. + row_count = property(lambda self: self._get_row_count()) + + @require_stmt + def _get_row_count(self): + return self._stmt.row_count + + ## + # Read a row from the database server, and return it in a dictionary + # indexed by column name/alias. This method will raise an error if two + # columns have the same name. Returns None after the last row. + #

+ # Stability: Added in v1.00, stability guaranteed for v1.xx. + @require_stmt + def read_dict(self): + return self._stmt.read_dict() + + ## + # Read a row from the database server, and return it as a tuple of values. + # Returns None after the last row. + #

+ # Stability: Added in v1.00, stability guaranteed for v1.xx. + @require_stmt + def read_tuple(self): + return self._stmt.read_tuple() + + ## + # Return an iterator for the output of this statement. The iterator will + # return a tuple for each row, in the same manner as {@link + # #PreparedStatement.read_tuple read_tuple}. + #

+ # Stability: Added in v1.00, stability guaranteed for v1.xx. + @require_stmt + def iterate_tuple(self): + return self._stmt.iterate_tuple() + + ## + # Return an iterator for the output of this statement. The iterator will + # return a dict for each row, in the same manner as {@link + # #PreparedStatement.read_dict read_dict}. + #

+ # Stability: Added in v1.00, stability guaranteed for v1.xx. + @require_stmt + def iterate_dict(self): + return self._stmt.iterate_dict() + + def close(self): + if self._stmt != None: + self._stmt.close() + self._stmt = None + + + ## + # Return the fileno of the underlying socket for this cursor's connection. + #

+ # Stability: Added in v1.07, stability guaranteed for v1.xx. + def fileno(self): + return self.connection.fileno() + + ## + # Poll the underlying socket for this cursor and sync if there is data waiting + # to be read. This has the effect of flushing asynchronous messages from the + # backend. Returns True if messages were read, False otherwise. + #

+ # Stability: Added in v1.07, stability guaranteed for v1.xx. + def isready(self): + return self.connection.isready() + + +## +# This class represents a connection to a PostgreSQL database. +#

+# The database connection is derived from the {@link #Cursor Cursor} class, +# which provides a default cursor for running queries. It also provides +# transaction control via the 'begin', 'commit', and 'rollback' methods. +# Without beginning a transaction explicitly, all statements will autocommit to +# the database. +#

+# As of v1.01, instances of this class are thread-safe. See {@link +# PreparedStatement PreparedStatement} for more information. +#

+# Stability: Added in v1.00, stability guaranteed for v1.xx. +# +# @param user The username to connect to the PostgreSQL server with. This +# parameter is required. +# +# @keyparam host The hostname of the PostgreSQL server to connect with. +# Providing this parameter is necessary for TCP/IP connections. One of either +# host, or unix_sock, must be provided. +# +# @keyparam unix_sock The path to the UNIX socket to access the database +# through, for example, '/tmp/.s.PGSQL.5432'. One of either unix_sock or host +# must be provided. The port parameter will have no affect if unix_sock is +# provided. +# +# @keyparam port The TCP/IP port of the PostgreSQL server instance. This +# parameter defaults to 5432, the registered and common port of PostgreSQL +# TCP/IP servers. +# +# @keyparam database The name of the database instance to connect with. This +# parameter is optional, if omitted the PostgreSQL server will assume the +# database name is the same as the username. +# +# @keyparam password The user password to connect to the server with. This +# parameter is optional. If omitted, and the database server requests password +# based authentication, the connection will fail. On the other hand, if this +# parameter is provided and the database does not request password +# authentication, then the password will not be used. +# +# @keyparam socket_timeout Socket connect timeout measured in seconds. +# Defaults to 60 seconds. +# +# @keyparam ssl Use SSL encryption for TCP/IP socket. Defaults to False. +class Connection(Cursor): + def __init__(self, dsn="", user=None, host=None, unix_sock=None, port=5432, database=None, password=None, socket_timeout=60, ssl=False): + self._row_desc = None + if dsn: + # update connection parameters parsed of the conninfo dsn + opts = conninfo_parse(dsn) + database = opts.get("dbname", database) + user = opts.get("user", user) + password = opts.get("password", user) + host = opts.get("host", host) + port = int(opts.get("port", port)) + ssl = opts.get("sslmode", 'disable') != 'disable' + try: + self.c = protocol.Connection(unix_sock=unix_sock, host=host, port=port, socket_timeout=socket_timeout, ssl=ssl) + self.c.authenticate(user, password=password, database=database) + except socket.error, e: + raise InterfaceError("communication error", e) + Cursor.__init__(self, self) + self._begin = PreparedStatement(self, "BEGIN TRANSACTION") + self._commit = PreparedStatement(self, "COMMIT TRANSACTION") + self._rollback = PreparedStatement(self, "ROLLBACK TRANSACTION") + self._unnamed_prepared_statement_lock = threading.RLock() + self.in_transaction = False + self.autocommit = False + + ## + # An event handler that is fired when NOTIFY occurs for a notification that + # has been LISTEN'd for. The value of this property is a + # util.MulticastDelegate. A callback can be added by using + # connection.NotificationReceived += SomeMethod. The method will be called + # with a single argument, an object that has properties: backend_pid, + # condition, and additional_info. Callbacks can be removed with the -= + # operator. + #

+ # Stability: Added in v1.03, stability guaranteed for v1.xx. + NotificationReceived = property( + lambda self: getattr(self.c, "NotificationReceived"), + lambda self, value: setattr(self.c, "NotificationReceived", value) + ) + + ## + # An event handler that is fired when the database server issues a notice. + # The value of this property is a util.MulticastDelegate. A callback can + # be added by using connection.NotificationReceived += SomeMethod. The + # method will be called with a single argument, an object that has + # properties: severity, code, msg, and possibly others (detail, hint, + # position, where, file, line, and routine). Callbacks can be removed with + # the -= operator. + #

+ # Stability: Added in v1.03, stability guaranteed for v1.xx. + NoticeReceived = property( + lambda self: getattr(self.c, "NoticeReceived"), + lambda self, value: setattr(self.c, "NoticeReceived", value) + ) + + ## + # An event handler that is fired when a runtime configuration option is + # changed on the server. The value of this property is a + # util.MulticastDelegate. A callback can be added by using + # connection.NotificationReceived += SomeMethod. Callbacks can be removed + # with the -= operator. The method will be called with a single argument, + # an object that has properties "key" and "value". + #

+ # Stability: Added in v1.03, stability guaranteed for v1.xx. + ParameterStatusReceived = property( + lambda self: getattr(self.c, "ParameterStatusReceived"), + lambda self, value: setattr(self.c, "ParameterStatusReceived", value) + ) + + ## + # Begins a new transaction. + #

+ # Stability: Added in v1.00, stability guaranteed for v1.xx. + def begin(self): + if self.is_closed: + raise ConnectionClosedError() + if self.autocommit: + return + self._begin.execute() + self.in_transaction = True + + + ## + # Commits the running transaction. + #

+ # Stability: Added in v1.00, stability guaranteed for v1.xx. + def commit(self): + if self.is_closed: + raise ConnectionClosedError() + self._commit.execute() + self.in_transaction = False + + ## + # Rolls back the running transaction. + #

+ # Stability: Added in v1.00, stability guaranteed for v1.xx. + def rollback(self): + if self.is_closed: + raise ConnectionClosedError() + self._rollback.execute() + self.in_transaction = False + + ## + # Closes an open connection. + def close(self): + if self.is_closed: + raise ConnectionClosedError() + self.c.close() + self.c = None + + is_closed = property(lambda self: self.c == None) + + ## + # Return the fileno of the underlying socket for this connection. + #

+ # Stability: Added in v1.07, stability guaranteed for v1.xx. + def fileno(self): + return self.c.fileno() + + ## + # Poll the underlying socket for this connection and sync if there is data + # waiting to be read. This has the effect of flushing asynchronous + # messages from the backend. Returns True if messages were read, False + # otherwise. + #

+ # Stability: Added in v1.07, stability guaranteed for v1.xx. + def isready(self): + return self.c.isready() + + ## + # Return the server_version as reported from the connected server. + # Raises InterfaceError if no version has been reported from the server. + def server_version(self): + return self.c.server_version() + + def encoding(self, encoding=None): + "Returns the client_encoding as reported from the connected server" + return self.c.encoding() \ No newline at end of file diff --git a/gluon/contrib/pg8000/protocol.py b/gluon/contrib/pg8000/protocol.py new file mode 100644 index 00000000..16265d28 --- /dev/null +++ b/gluon/contrib/pg8000/protocol.py @@ -0,0 +1,1411 @@ +# vim: sw=4:expandtab:foldmethod=marker +# +# Copyright (c) 2007-2009, Mathieu Fenniak +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * The name of the author may not be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +__author__ = "Mathieu Fenniak" + +import socket +try: + import ssl as sslmodule +except ImportError: + sslmodule = None +import select +import threading +import struct +import hashlib +from cStringIO import StringIO + +from errors import * +from util import MulticastDelegate +import types + +## +# An SSLRequest message. To initiate an SSL-encrypted connection, an +# SSLRequest message is used rather than a {@link StartupMessage +# StartupMessage}. A StartupMessage is still sent, but only after SSL +# negotiation (if accepted). +#

+# Stability: This is an internal class. No stability guarantee is made. +class SSLRequest(object): + def __init__(self): + pass + + # Int32(8) - Message length, including self.
+ # Int32(80877103) - The SSL request code.
+ def serialize(self): + return struct.pack("!ii", 8, 80877103) + + +## +# A StartupMessage message. Begins a DB session, identifying the user to be +# authenticated as and the database to connect to. +#

+# Stability: This is an internal class. No stability guarantee is made. +class StartupMessage(object): + def __init__(self, user, database=None): + self.user = user + self.database = database + + # Int32 - Message length, including self. + # Int32(196608) - Protocol version number. Version 3.0. + # Any number of key/value pairs, terminated by a zero byte: + # String - A parameter name (user, database, or options) + # String - Parameter value + def serialize(self): + protocol = 196608 + val = struct.pack("!i", protocol) + val += "user\x00" + self.user + "\x00" + if self.database: + val += "database\x00" + self.database + "\x00" + val += "\x00" + val = struct.pack("!i", len(val) + 4) + val + return val + + +## +# Parse message. Creates a prepared statement in the DB session. +#

+# Stability: This is an internal class. No stability guarantee is made. +# +# @param ps Name of the prepared statement to create. +# @param qs Query string. +# @param type_oids An iterable that contains the PostgreSQL type OIDs for +# parameters in the query string. +class Parse(object): + def __init__(self, ps, qs, type_oids): + if isinstance(qs, unicode): + raise TypeError("qs must be encoded byte data") + self.ps = ps + self.qs = qs + self.type_oids = type_oids + + def __repr__(self): + return "" % (self.ps, self.qs) + + # Byte1('P') - Identifies the message as a Parse command. + # Int32 - Message length, including self. + # String - Prepared statement name. An empty string selects the unnamed + # prepared statement. + # String - The query string. + # Int16 - Number of parameter data types specified (can be zero). + # For each parameter: + # Int32 - The OID of the parameter data type. + def serialize(self): + val = self.ps + "\x00" + self.qs + "\x00" + val = val + struct.pack("!h", len(self.type_oids)) + for oid in self.type_oids: + # Parse message doesn't seem to handle the -1 type_oid for NULL + # values that other messages handle. So we'll provide type_oid 705, + # the PG "unknown" type. + if oid == -1: oid = 705 + val = val + struct.pack("!i", oid) + val = struct.pack("!i", len(val) + 4) + val + val = "P" + val + return val + + +## +# Bind message. Readies a prepared statement for execution. +#

+# Stability: This is an internal class. No stability guarantee is made. +# +# @param portal Name of the destination portal. +# @param ps Name of the source prepared statement. +# @param in_fc An iterable containing the format codes for input +# parameters. 0 = Text, 1 = Binary. +# @param params The parameters. +# @param out_fc An iterable containing the format codes for output +# parameters. 0 = Text, 1 = Binary. +# @param kwargs Additional arguments to pass to the type conversion +# methods. +class Bind(object): + def __init__(self, portal, ps, in_fc, params, out_fc, **kwargs): + self.portal = portal + self.ps = ps + self.in_fc = in_fc + self.params = [] + for i in range(len(params)): + if len(self.in_fc) == 0: + fc = 0 + elif len(self.in_fc) == 1: + fc = self.in_fc[0] + else: + fc = self.in_fc[i] + self.params.append(types.pg_value(params[i], fc, **kwargs)) + self.out_fc = out_fc + + def __repr__(self): + return "" % (self.portal, self.ps) + + # Byte1('B') - Identifies the Bind command. + # Int32 - Message length, including self. + # String - Name of the destination portal. + # String - Name of the source prepared statement. + # Int16 - Number of parameter format codes. + # For each parameter format code: + # Int16 - The parameter format code. + # Int16 - Number of parameter values. + # For each parameter value: + # Int32 - The length of the parameter value, in bytes, not including this + # this length. -1 indicates a NULL parameter value, in which no + # value bytes follow. + # Byte[n] - Value of the parameter. + # Int16 - The number of result-column format codes. + # For each result-column format code: + # Int16 - The format code. + def serialize(self): + retval = StringIO() + retval.write(self.portal + "\x00") + retval.write(self.ps + "\x00") + retval.write(struct.pack("!h", len(self.in_fc))) + for fc in self.in_fc: + retval.write(struct.pack("!h", fc)) + retval.write(struct.pack("!h", len(self.params))) + for param in self.params: + if param == None: + # special case, NULL value + retval.write(struct.pack("!i", -1)) + else: + retval.write(struct.pack("!i", len(param))) + retval.write(param) + retval.write(struct.pack("!h", len(self.out_fc))) + for fc in self.out_fc: + retval.write(struct.pack("!h", fc)) + val = retval.getvalue() + val = struct.pack("!i", len(val) + 4) + val + val = "B" + val + return val + + +## +# A Close message, used for closing prepared statements and portals. +#

+# Stability: This is an internal class. No stability guarantee is made. +# +# @param typ 'S' for prepared statement, 'P' for portal. +# @param name The name of the item to close. +class Close(object): + def __init__(self, typ, name): + if len(typ) != 1: + raise InternalError("Close typ must be 1 char") + self.typ = typ + self.name = name + + # Byte1('C') - Identifies the message as a close command. + # Int32 - Message length, including self. + # Byte1 - 'S' for prepared statement, 'P' for portal. + # String - The name of the item to close. + def serialize(self): + val = self.typ + self.name + "\x00" + val = struct.pack("!i", len(val) + 4) + val + val = "C" + val + return val + + +## +# A specialized Close message for a portal. +#

+# Stability: This is an internal class. No stability guarantee is made. +class ClosePortal(Close): + def __init__(self, name): + Close.__init__(self, "P", name) + + +## +# A specialized Close message for a prepared statement. +#

+# Stability: This is an internal class. No stability guarantee is made. +class ClosePreparedStatement(Close): + def __init__(self, name): + Close.__init__(self, "S", name) + + +## +# A Describe message, used for obtaining information on prepared statements +# and portals. +#

+# Stability: This is an internal class. No stability guarantee is made. +# +# @param typ 'S' for prepared statement, 'P' for portal. +# @param name The name of the item to close. +class Describe(object): + def __init__(self, typ, name): + if len(typ) != 1: + raise InternalError("Describe typ must be 1 char") + self.typ = typ + self.name = name + + # Byte1('D') - Identifies the message as a describe command. + # Int32 - Message length, including self. + # Byte1 - 'S' for prepared statement, 'P' for portal. + # String - The name of the item to close. + def serialize(self): + val = self.typ + self.name + "\x00" + val = struct.pack("!i", len(val) + 4) + val + val = "D" + val + return val + + +## +# A specialized Describe message for a portal. +#

+# Stability: This is an internal class. No stability guarantee is made. +class DescribePortal(Describe): + def __init__(self, name): + Describe.__init__(self, "P", name) + + def __repr__(self): + return "" % (self.name) + + +## +# A specialized Describe message for a prepared statement. +#

+# Stability: This is an internal class. No stability guarantee is made. +class DescribePreparedStatement(Describe): + def __init__(self, name): + Describe.__init__(self, "S", name) + + def __repr__(self): + return "" % (self.name) + + +## +# A Flush message forces the backend to deliver any data pending in its +# output buffers. +#

+# Stability: This is an internal class. No stability guarantee is made. +class Flush(object): + # Byte1('H') - Identifies the message as a flush command. + # Int32(4) - Length of message, including self. + def serialize(self): + return 'H\x00\x00\x00\x04' + + def __repr__(self): + return "" + +## +# Causes the backend to close the current transaction (if not in a BEGIN/COMMIT +# block), and issue ReadyForQuery. +#

+# Stability: This is an internal class. No stability guarantee is made. +class Sync(object): + # Byte1('S') - Identifies the message as a sync command. + # Int32(4) - Length of message, including self. + def serialize(self): + return 'S\x00\x00\x00\x04' + + def __repr__(self): + return "" + + +## +# Transmits a password. +#

+# Stability: This is an internal class. No stability guarantee is made. +class PasswordMessage(object): + def __init__(self, pwd): + self.pwd = pwd + + # Byte1('p') - Identifies the message as a password message. + # Int32 - Message length including self. + # String - The password. Password may be encrypted. + def serialize(self): + val = self.pwd + "\x00" + val = struct.pack("!i", len(val) + 4) + val + val = "p" + val + return val + + +## +# Requests that the backend execute a portal and retrieve any number of rows. +#

+# Stability: This is an internal class. No stability guarantee is made. +# @param row_count The number of rows to return. Can be zero to indicate the +# backend should return all rows. If the portal represents a +# query that does not return rows, no rows will be returned +# no matter what the row_count. +class Execute(object): + def __init__(self, portal, row_count): + self.portal = portal + self.row_count = row_count + + # Byte1('E') - Identifies the message as an execute message. + # Int32 - Message length, including self. + # String - The name of the portal to execute. + # Int32 - Maximum number of rows to return, if portal contains a query that + # returns rows. 0 = no limit. + def serialize(self): + val = self.portal + "\x00" + struct.pack("!i", self.row_count) + val = struct.pack("!i", len(val) + 4) + val + val = "E" + val + return val + + +class SimpleQuery(object): + "Requests that the backend execute a Simple Query (SQL string)" + + def __init__(self, query_string): + self.query_string = query_string + + # Byte1('Q') - Identifies the message as an query message. + # Int32 - Message length, including self. + # String - The query string itself. + def serialize(self): + val = self.query_string + "\x00" + val = struct.pack("!i", len(val) + 4) + val + val = "Q" + val + return val + + def __repr__(self): + return "" % (self.query_string) + +## +# Informs the backend that the connection is being closed. +#

+# Stability: This is an internal class. No stability guarantee is made. +class Terminate(object): + def __init__(self): + pass + + # Byte1('X') - Identifies the message as a terminate message. + # Int32(4) - Message length, including self. + def serialize(self): + return 'X\x00\x00\x00\x04' + +## +# Base class of all Authentication[*] messages. +#

+# Stability: This is an internal class. No stability guarantee is made. +class AuthenticationRequest(object): + def __init__(self, data): + pass + + # Byte1('R') - Identifies the message as an authentication request. + # Int32(8) - Message length, including self. + # Int32 - An authentication code that represents different + # authentication messages: + # 0 = AuthenticationOk + # 5 = MD5 pwd + # 2 = Kerberos v5 (not supported by pg8000) + # 3 = Cleartext pwd (not supported by pg8000) + # 4 = crypt() pwd (not supported by pg8000) + # 6 = SCM credential (not supported by pg8000) + # 7 = GSSAPI (not supported by pg8000) + # 8 = GSSAPI data (not supported by pg8000) + # 9 = SSPI (not supported by pg8000) + # Some authentication messages have additional data following the + # authentication code. That data is documented in the appropriate class. + def createFromData(data): + ident = struct.unpack("!i", data[:4])[0] + klass = authentication_codes.get(ident, None) + if klass != None: + return klass(data[4:]) + else: + raise NotSupportedError("authentication method %r not supported" % (ident,)) + createFromData = staticmethod(createFromData) + + def ok(self, conn, user, **kwargs): + raise InternalError("ok method should be overridden on AuthenticationRequest instance") + +## +# A message representing that the backend accepting the provided username +# without any challenge. +#

+# Stability: This is an internal class. No stability guarantee is made. +class AuthenticationOk(AuthenticationRequest): + def ok(self, conn, user, **kwargs): + return True + + +## +# A message representing the backend requesting an MD5 hashed password +# response. The response will be sent as md5(md5(pwd + login) + salt). +#

+# Stability: This is an internal class. No stability guarantee is made. +class AuthenticationMD5Password(AuthenticationRequest): + # Additional message data: + # Byte4 - Hash salt. + def __init__(self, data): + self.salt = "".join(struct.unpack("4c", data)) + + def ok(self, conn, user, password=None, **kwargs): + if password == None: + raise InterfaceError("server requesting MD5 password authentication, but no password was provided") + pwd = "md5" + hashlib.md5(hashlib.md5(password + user).hexdigest() + self.salt).hexdigest() + conn._send(PasswordMessage(pwd)) + conn._flush() + + reader = MessageReader(conn) + reader.add_message(AuthenticationRequest, lambda msg, reader: reader.return_value(msg.ok(conn, user)), reader) + reader.add_message(ErrorResponse, self._ok_error) + return reader.handle_messages() + + def _ok_error(self, msg): + if msg.code == "28000": + raise InterfaceError("md5 password authentication failed") + else: + raise msg.createException() + +authentication_codes = { + 0: AuthenticationOk, + 5: AuthenticationMD5Password, +} + + +## +# ParameterStatus message sent from backend, used to inform the frotnend of +# runtime configuration parameter changes. +#

+# Stability: This is an internal class. No stability guarantee is made. +class ParameterStatus(object): + def __init__(self, key, value): + self.key = key + self.value = value + + # Byte1('S') - Identifies ParameterStatus + # Int32 - Message length, including self. + # String - Runtime parameter name. + # String - Runtime parameter value. + def createFromData(data): + key = data[:data.find("\x00")] + value = data[data.find("\x00")+1:-1] + return ParameterStatus(key, value) + createFromData = staticmethod(createFromData) + + +## +# BackendKeyData message sent from backend. Contains a connection's process +# ID and a secret key. Can be used to terminate the connection's current +# actions, such as a long running query. Not supported by pg8000 yet. +#

+# Stability: This is an internal class. No stability guarantee is made. +class BackendKeyData(object): + def __init__(self, process_id, secret_key): + self.process_id = process_id + self.secret_key = secret_key + + # Byte1('K') - Identifier. + # Int32(12) - Message length, including self. + # Int32 - Process ID. + # Int32 - Secret key. + def createFromData(data): + process_id, secret_key = struct.unpack("!2i", data) + return BackendKeyData(process_id, secret_key) + createFromData = staticmethod(createFromData) + + +## +# Message representing a query with no data. +#

+# Stability: This is an internal class. No stability guarantee is made. +class NoData(object): + # Byte1('n') - Identifier. + # Int32(4) - Message length, including self. + def createFromData(data): + return NoData() + createFromData = staticmethod(createFromData) + + +## +# Message representing a successful Parse. +#

+# Stability: This is an internal class. No stability guarantee is made. +class ParseComplete(object): + # Byte1('1') - Identifier. + # Int32(4) - Message length, including self. + def createFromData(data): + return ParseComplete() + createFromData = staticmethod(createFromData) + + +## +# Message representing a successful Bind. +#

+# Stability: This is an internal class. No stability guarantee is made. +class BindComplete(object): + # Byte1('2') - Identifier. + # Int32(4) - Message length, including self. + def createFromData(data): + return BindComplete() + createFromData = staticmethod(createFromData) + + +## +# Message representing a successful Close. +#

+# Stability: This is an internal class. No stability guarantee is made. +class CloseComplete(object): + # Byte1('3') - Identifier. + # Int32(4) - Message length, including self. + def createFromData(data): + return CloseComplete() + createFromData = staticmethod(createFromData) + + +## +# Message representing data from an Execute has been received, but more data +# exists in the portal. +#

+# Stability: This is an internal class. No stability guarantee is made. +class PortalSuspended(object): + # Byte1('s') - Identifier. + # Int32(4) - Message length, including self. + def createFromData(data): + return PortalSuspended() + createFromData = staticmethod(createFromData) + + +## +# Message representing the backend is ready to process a new query. +#

+# Stability: This is an internal class. No stability guarantee is made. +class ReadyForQuery(object): + def __init__(self, status): + self._status = status + + ## + # I = Idle, T = Idle in Transaction, E = idle in failed transaction. + status = property(lambda self: self._status) + + def __repr__(self): + return "" % \ + {"I": "Idle", "T": "Idle in Transaction", "E": "Idle in Failed Transaction"}[self.status] + + # Byte1('Z') - Identifier. + # Int32(5) - Message length, including self. + # Byte1 - Status indicator. + def createFromData(data): + return ReadyForQuery(data) + createFromData = staticmethod(createFromData) + + +## +# Represents a notice sent from the server. This is not the same as a +# notification. A notice is just additional information about a query, such +# as a notice that a primary key has automatically been created for a table. +#

+# A NoticeResponse instance will have properties containing the data sent +# from the server: +#

+#

+# Stability: Added in pg8000 v1.03. Required properties severity, code, and +# msg are guaranteed for v1.xx. Other properties should be checked with +# hasattr before accessing. +class NoticeResponse(object): + responseKeys = { + "S": "severity", # always present + "C": "code", # always present + "M": "msg", # always present + "D": "detail", + "H": "hint", + "P": "position", + "p": "_position", + "q": "_query", + "W": "where", + "F": "file", + "L": "line", + "R": "routine", + } + + def __init__(self, **kwargs): + for arg, value in kwargs.items(): + setattr(self, arg, value) + + def __repr__(self): + return "" % (self.severity, self.code, self.msg) + + def dataIntoDict(data): + retval = {} + for s in data.split("\x00"): + if not s: continue + key, value = s[0], s[1:] + key = NoticeResponse.responseKeys.get(key, key) + retval[key] = value + return retval + dataIntoDict = staticmethod(dataIntoDict) + + # Byte1('N') - Identifier + # Int32 - Message length + # Any number of these, followed by a zero byte: + # Byte1 - code identifying the field type (see responseKeys) + # String - field value + def createFromData(data): + return NoticeResponse(**NoticeResponse.dataIntoDict(data)) + createFromData = staticmethod(createFromData) + + +## +# A message sent in case of a server-side error. Contains the same properties +# that {@link NoticeResponse NoticeResponse} contains. +#

+# Stability: Added in pg8000 v1.03. Required properties severity, code, and +# msg are guaranteed for v1.xx. Other properties should be checked with +# hasattr before accessing. +class ErrorResponse(object): + def __init__(self, **kwargs): + for arg, value in kwargs.items(): + setattr(self, arg, value) + + def __repr__(self): + return "" % (self.severity, self.code, self.msg) + + def createException(self): + return ProgrammingError(self.severity, self.code, self.msg) + + def createFromData(data): + return ErrorResponse(**NoticeResponse.dataIntoDict(data)) + createFromData = staticmethod(createFromData) + + +## +# A message sent if this connection receives a NOTIFY that it was LISTENing for. +#

+# Stability: Added in pg8000 v1.03. When limited to accessing properties from +# a notification event dispatch, stability is guaranteed for v1.xx. +class NotificationResponse(object): + def __init__(self, backend_pid, condition, additional_info): + self._backend_pid = backend_pid + self._condition = condition + self._additional_info = additional_info + + ## + # An integer representing the process ID of the backend that triggered + # the NOTIFY. + #

+ # Stability: Added in pg8000 v1.03, stability guaranteed for v1.xx. + backend_pid = property(lambda self: self._backend_pid) + + ## + # The name of the notification fired. + #

+ # Stability: Added in pg8000 v1.03, stability guaranteed for v1.xx. + condition = property(lambda self: self._condition) + + ## + # Currently unspecified by the PostgreSQL documentation as of v8.3.1. + #

+ # Stability: Added in pg8000 v1.03, stability guaranteed for v1.xx. + additional_info = property(lambda self: self._additional_info) + + def __repr__(self): + return "" % (self.backend_pid, self.condition, self.additional_info) + + def createFromData(data): + backend_pid = struct.unpack("!i", data[:4])[0] + data = data[4:] + null = data.find("\x00") + condition = data[:null] + data = data[null+1:] + null = data.find("\x00") + additional_info = data[:null] + return NotificationResponse(backend_pid, condition, additional_info) + createFromData = staticmethod(createFromData) + + +class ParameterDescription(object): + def __init__(self, type_oids): + self.type_oids = type_oids + def createFromData(data): + count = struct.unpack("!h", data[:2])[0] + type_oids = struct.unpack("!" + "i"*count, data[2:]) + return ParameterDescription(type_oids) + createFromData = staticmethod(createFromData) + + +class RowDescription(object): + def __init__(self, fields): + self.fields = fields + + def createFromData(data): + count = struct.unpack("!h", data[:2])[0] + data = data[2:] + fields = [] + for i in range(count): + null = data.find("\x00") + field = {"name": data[:null]} + data = data[null+1:] + field["table_oid"], field["column_attrnum"], field["type_oid"], field["type_size"], field["type_modifier"], field["format"] = struct.unpack("!ihihih", data[:18]) + data = data[18:] + fields.append(field) + return RowDescription(fields) + createFromData = staticmethod(createFromData) + +class CommandComplete(object): + def __init__(self, command, rows=None, oid=None): + self.command = command + self.rows = rows + self.oid = oid + + def createFromData(data): + values = data[:-1].split(" ") + args = {} + args['command'] = values[0] + if args['command'] in ("INSERT", "DELETE", "UPDATE", "MOVE", "FETCH", "COPY", "SELECT"): + args['rows'] = int(values[-1]) + if args['command'] == "INSERT": + args['oid'] = int(values[1]) + else: + args['command'] = data[:-1] + return CommandComplete(**args) + createFromData = staticmethod(createFromData) + + +class DataRow(object): + def __init__(self, fields): + self.fields = fields + + def createFromData(data): + count = struct.unpack("!h", data[:2])[0] + data = data[2:] + fields = [] + for i in range(count): + val_len = struct.unpack("!i", data[:4])[0] + data = data[4:] + if val_len == -1: + fields.append(None) + else: + fields.append(data[:val_len]) + data = data[val_len:] + return DataRow(fields) + createFromData = staticmethod(createFromData) + + +class CopyData(object): + # "d": CopyData, + def __init__(self, data): + self.data = data + + def createFromData(data): + return CopyData(data) + createFromData = staticmethod(createFromData) + + def serialize(self): + return 'd' + struct.pack('!i', len(self.data) + 4) + self.data + + +class CopyDone(object): + # Byte1('c') - Identifier. + # Int32(4) - Message length, including self. + + def createFromData(data): + return CopyDone() + + createFromData = staticmethod(createFromData) + + def serialize(self): + return 'c\x00\x00\x00\x04' + +class CopyOutResponse(object): + # Byte1('H') + # Int32(4) - Length of message contents in bytes, including self. + # Int8(1) - 0 textual, 1 binary + # Int16(2) - Number of columns + # Int16(N) - Format codes for each column (0 text, 1 binary) + + def __init__(self, is_binary, column_formats): + self.is_binary = is_binary + self.column_formats = column_formats + + def createFromData(data): + is_binary, num_cols = struct.unpack('!bh', data[:3]) + column_formats = struct.unpack('!' + ('h' * num_cols), data[3:]) + return CopyOutResponse(is_binary, column_formats) + + createFromData = staticmethod(createFromData) + + +class CopyInResponse(object): + # Byte1('G') + # Otherwise the same as CopyOutResponse + + def __init__(self, is_binary, column_formats): + self.is_binary = is_binary + self.column_formats = column_formats + + def createFromData(data): + is_binary, num_cols = struct.unpack('!bh', data[:3]) + column_formats = struct.unpack('!' + ('h' * num_cols), data[3:]) + return CopyInResponse(is_binary, column_formats) + + createFromData = staticmethod(createFromData) + + +class EmptyQueryResponse(object): + # Byte1('I') + # Response to an empty query string. (This substitutes for CommandComplete.) + + def createFromData(data): + return EmptyQueryResponse() + createFromData = staticmethod(createFromData) + + +class MessageReader(object): + def __init__(self, connection): + self._conn = connection + self._msgs = [] + + # If true, raise exception from an ErrorResponse after messages are + # processed. This can be used to leave the connection in a usable + # state after an error response, rather than having unconsumed + # messages that won't be understood in another context. + self.delay_raising_exception = False + + self.ignore_unhandled_messages = False + + def add_message(self, msg_class, handler, *args, **kwargs): + self._msgs.append((msg_class, handler, args, kwargs)) + + def clear_messages(self): + self._msgs = [] + + def return_value(self, value): + self._retval = value + + def handle_messages(self): + exc = None + while 1: + msg = self._conn._read_message() + msg_handled = False + for (msg_class, handler, args, kwargs) in self._msgs: + if isinstance(msg, msg_class): + msg_handled = True + retval = handler(msg, *args, **kwargs) + if retval: + # The handler returned a true value, meaning that the + # message loop should be aborted. + if exc != None: + raise exc + return retval + elif hasattr(self, "_retval"): + # The handler told us to return -- used for non-true + # return values + if exc != None: + raise exc + return self._retval + if msg_handled: + continue + elif isinstance(msg, ErrorResponse): + exc = msg.createException() + if not self.delay_raising_exception: + raise exc + elif isinstance(msg, NoticeResponse): + self._conn.handleNoticeResponse(msg) + elif isinstance(msg, ParameterStatus): + self._conn.handleParameterStatus(msg) + elif isinstance(msg, NotificationResponse): + self._conn.handleNotificationResponse(msg) + elif not self.ignore_unhandled_messages: + raise InternalError("Unexpected response msg %r" % (msg)) + +def sync_on_error(fn): + def _fn(self, *args, **kwargs): + try: + self._sock_lock.acquire() + return fn(self, *args, **kwargs) + except: + self._sync() + raise + finally: + self._sock_lock.release() + return _fn + +class Connection(object): + def __init__(self, unix_sock=None, host=None, port=5432, socket_timeout=60, ssl=False): + self._client_encoding = "ascii" + self._integer_datetimes = False + self._server_version = None + self._sock_buf = "" + self._sock_buf_pos = 0 + self._send_sock_buf = [] + self._block_size = 8192 + self._sock_lock = threading.Lock() + if unix_sock == None and host != None: + self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + elif unix_sock != None: + if not hasattr(socket, "AF_UNIX"): + raise InterfaceError("attempt to connect to unix socket on unsupported platform") + self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + else: + raise ProgrammingError("one of host or unix_sock must be provided") + if unix_sock == None and host != None: + self._sock.connect((host, port)) + elif unix_sock != None: + self._sock.connect(unix_sock) + if ssl: + self._sock_lock.acquire() + try: + self._send(SSLRequest()) + self._flush() + resp = self._sock.recv(1) + if resp == 'S' and sslmodule is not None: + self._sock = sslmodule.wrap_socket(self._sock) + elif sslmodule is None: + raise InterfaceError("SSL required but ssl module not available in this python installation") + else: + raise InterfaceError("server refuses SSL") + finally: + self._sock_lock.release() + else: + # settimeout causes ssl failure, on windows. Python bug 1462352. + self._sock.settimeout(socket_timeout) + self._state = "noauth" + self._backend_key_data = None + + self.NoticeReceived = MulticastDelegate() + self.ParameterStatusReceived = MulticastDelegate() + self.NotificationReceived = MulticastDelegate() + + self.ParameterStatusReceived += self._onParameterStatusReceived + + def verifyState(self, state): + if self._state != state: + raise InternalError("connection state must be %s, is %s" % (state, self._state)) + + def _send(self, msg): + assert self._sock_lock.locked() + ##print "_send(%r)" % msg + data = msg.serialize() + if not isinstance(data, str): + raise TypeError("bytes data expected") + self._send_sock_buf.append(data) + + def _flush(self): + assert self._sock_lock.locked() + self._sock.sendall("".join(self._send_sock_buf)) + del self._send_sock_buf[:] + + def _read_bytes(self, byte_count): + retval = [] + bytes_read = 0 + while bytes_read < byte_count: + if self._sock_buf_pos == len(self._sock_buf): + self._sock_buf = self._sock.recv(1024) + self._sock_buf_pos = 0 + rpos = min(len(self._sock_buf), self._sock_buf_pos + (byte_count - bytes_read)) + addt_data = self._sock_buf[self._sock_buf_pos:rpos] + bytes_read += (rpos - self._sock_buf_pos) + assert bytes_read <= byte_count + self._sock_buf_pos = rpos + retval.append(addt_data) + return "".join(retval) + + def _read_message(self): + assert self._sock_lock.locked() + bytes = self._read_bytes(5) + message_code = bytes[0] + data_len = struct.unpack("!i", bytes[1:])[0] - 4 + bytes = self._read_bytes(data_len) + assert len(bytes) == data_len + msg = message_types[message_code].createFromData(bytes) + ##print "_read_message() -> %r" % msg + return msg + + def authenticate(self, user, **kwargs): + self.verifyState("noauth") + self._sock_lock.acquire() + try: + self._send(StartupMessage(user, database=kwargs.get("database",None))) + self._flush() + + reader = MessageReader(self) + reader.add_message(AuthenticationRequest, self._authentication_request(user, **kwargs)) + reader.handle_messages() + finally: + self._sock_lock.release() + + def _authentication_request(self, user, **kwargs): + def _func(msg): + assert self._sock_lock.locked() + if not msg.ok(self, user, **kwargs): + raise InterfaceError("authentication method %s failed" % msg.__class__.__name__) + self._state = "auth" + reader = MessageReader(self) + reader.add_message(ReadyForQuery, self._ready_for_query) + reader.add_message(BackendKeyData, self._receive_backend_key_data) + reader.handle_messages() + return 1 + return _func + + def _ready_for_query(self, msg): + self._state = "ready" + return True + + def _receive_backend_key_data(self, msg): + self._backend_key_data = msg + + @sync_on_error + def parse(self, statement, qs, param_types): + self.verifyState("ready") + + type_info = [types.pg_type_info(x) for x in param_types] + param_types, param_fc = [x[0] for x in type_info], [x[1] for x in type_info] # zip(*type_info) -- fails on empty arr + if isinstance(qs, unicode): + qs = qs.encode(self._client_encoding) + self._send(Parse(statement, qs, param_types)) + self._send(DescribePreparedStatement(statement)) + self._send(Flush()) + self._flush() + + reader = MessageReader(self) + + # ParseComplete is good. + reader.add_message(ParseComplete, lambda msg: 0) + + # Well, we don't really care -- we're going to send whatever we + # want and let the database deal with it. But thanks anyways! + reader.add_message(ParameterDescription, lambda msg: 0) + + # We're not waiting for a row description. Return something + # destinctive to let bind know that there is no output. + reader.add_message(NoData, lambda msg: (None, param_fc)) + + # Common row description response + reader.add_message(RowDescription, lambda msg: (msg, param_fc)) + + return reader.handle_messages() + + @sync_on_error + def bind(self, portal, statement, params, parse_data, copy_stream): + self.verifyState("ready") + + row_desc, param_fc = parse_data + if row_desc == None: + # no data coming out + output_fc = () + else: + # We've got row_desc that allows us to identify what we're going to + # get back from this statement. + output_fc = [types.py_type_info(f) for f in row_desc.fields] + self._send(Bind(portal, statement, param_fc, params, output_fc, client_encoding = self._client_encoding, integer_datetimes = self._integer_datetimes)) + # We need to describe the portal after bind, since the return + # format codes will be different (hopefully, always what we + # requested). + self._send(DescribePortal(portal)) + self._send(Flush()) + self._flush() + + # Read responses from server... + reader = MessageReader(self) + + # BindComplete is good -- just ignore + reader.add_message(BindComplete, lambda msg: 0) + + # NoData in this case means we're not executing a query. As a + # result, we won't be fetching rows, so we'll never execute the + # portal we just created... unless we execute it right away, which + # we'll do. + reader.add_message(NoData, self._bind_nodata, portal, reader, copy_stream) + + # Return the new row desc, since it will have the format types we + # asked the server for + reader.add_message(RowDescription, lambda msg: (msg, None)) + + return reader.handle_messages() + + def _copy_in_response(self, copyin, fileobj, old_reader): + if fileobj == None: + raise CopyQueryWithoutStreamError() + while True: + data = fileobj.read(self._block_size) + if not data: + break + self._send(CopyData(data)) + self._flush() + self._send(CopyDone()) + self._send(Sync()) + self._flush() + + def _copy_out_response(self, copyout, fileobj, old_reader): + if fileobj == None: + raise CopyQueryWithoutStreamError() + reader = MessageReader(self) + reader.add_message(CopyData, self._copy_data, fileobj) + reader.add_message(CopyDone, lambda msg: 1) + reader.handle_messages() + + def _copy_data(self, copydata, fileobj): + fileobj.write(copydata.data) + + def _bind_nodata(self, msg, portal, old_reader, copy_stream): + # Bind message returned NoData, causing us to execute the command. + self._send(Execute(portal, 0)) + self._send(Sync()) + self._flush() + + output = {} + reader = MessageReader(self) + reader.add_message(CopyOutResponse, self._copy_out_response, copy_stream, reader) + reader.add_message(CopyInResponse, self._copy_in_response, copy_stream, reader) + reader.add_message(CommandComplete, lambda msg, out: out.setdefault('msg', msg) and False, output) + reader.add_message(ReadyForQuery, lambda msg: 1) + reader.delay_raising_exception = True + reader.handle_messages() + + old_reader.return_value((None, output['msg'])) + + @sync_on_error + def send_simple_query(self, query_string, copy_stream=None): + "Submit a simple query (PQsendQuery)" + + # Only use this for trivial queries, as its use is discouraged because: + # CONS: + # - Parameter are "injected" (they should be escaped by the app) + # - Exesive memory usage (allways returns all rows on completion) + # - Inneficient transmission of data in plain text (except for FETCH) + # - No Prepared Statement support, each query is parsed every time + # - Basic implementation: minimal error recovery and type support + # PROS: + # - compact: equivalent to Parse, Bind, Describe, Execute, Close, Sync + # - doesn't returns ParseComplete, BindComplete, CloseComplete, NoData + # - it supports multiple statements in a single query string + # - it is available when the Streaming Replication Protocol is actived + # NOTE: this is the protocol used by psycopg2 + # (they also uses named cursors to overcome some drawbacks) + + self.verifyState("ready") + + if isinstance(query_string, unicode): + query_string = query_string.encode(self._client_encoding) + + self._send(SimpleQuery(query_string)) + self._flush() + + # define local storage for message handlers: + output = {} + rows = [] + + # create and add handlers for all the possible messages: + reader = MessageReader(self) + + # read row description but continue processing messages... (return false) + reader.add_message(RowDescription, lambda msg, out: out.setdefault('row_desc', msg) and False, output) + reader.add_message(DataRow, lambda msg: self._fetch_datarow(msg, rows, output['row_desc'])) + reader.add_message(EmptyQueryResponse, lambda msg: False) + reader.add_message(CommandComplete, lambda msg, out: out.setdefault('complete', msg) and False, output) + reader.add_message(CopyInResponse, self._copy_in_response, copy_stream, reader) + reader.add_message(CopyOutResponse, self._copy_out_response, copy_stream, reader) + # messages indicating that we've hit the end of the available data for this command + reader.add_message(ReadyForQuery, lambda msg: 1) + # process all messages and then raise exceptions (if any) + reader.delay_raising_exception = True + # start processing the messages from the backend: + retval = reader.handle_messages() + + # return a dict with command complete / row description message values + return output.get('row_desc'), output.get('complete'), rows + + @sync_on_error + def fetch_rows(self, portal, row_count, row_desc): + self.verifyState("ready") + + self._send(Execute(portal, row_count)) + self._send(Flush()) + self._flush() + rows = [] + + reader = MessageReader(self) + reader.add_message(DataRow, self._fetch_datarow, rows, row_desc) + reader.add_message(PortalSuspended, lambda msg: 1) + reader.add_message(CommandComplete, self._fetch_commandcomplete, portal) + retval = reader.handle_messages() + + # retval = 2 when command complete, indicating that we've hit the + # end of the available data for this command + return (retval == 2), rows + + def _fetch_datarow(self, msg, rows, row_desc): + rows.append( + [ + types.py_value( + msg.fields[i], + row_desc.fields[i], + client_encoding=self._client_encoding, + integer_datetimes=self._integer_datetimes, + ) + for i in range(len(msg.fields)) + ] + ) + + def _fetch_commandcomplete(self, msg, portal): + self._send(ClosePortal(portal)) + self._send(Sync()) + self._flush() + + reader = MessageReader(self) + reader.add_message(ReadyForQuery, self._fetch_commandcomplete_rfq) + reader.add_message(CloseComplete, lambda msg: False) + reader.handle_messages() + + return 2 # signal end-of-data + + def _fetch_commandcomplete_rfq(self, msg): + self._state = "ready" + return True + + # Send a Sync message, then read and discard all messages until we + # receive a ReadyForQuery message. + def _sync(self): + # it is assumed _sync is called from sync_on_error, which holds + # a _sock_lock throughout the call + self._send(Sync()) + self._flush() + reader = MessageReader(self) + reader.ignore_unhandled_messages = True + reader.add_message(ReadyForQuery, lambda msg: True) + reader.handle_messages() + + def close_statement(self, statement): + if self._state == "closed": + return + self.verifyState("ready") + self._sock_lock.acquire() + try: + self._send(ClosePreparedStatement(statement)) + self._send(Sync()) + self._flush() + + reader = MessageReader(self) + reader.add_message(CloseComplete, lambda msg: 0) + reader.add_message(ReadyForQuery, lambda msg: 1) + reader.handle_messages() + finally: + self._sock_lock.release() + + def close_portal(self, portal): + if self._state == "closed": + return + self.verifyState("ready") + self._sock_lock.acquire() + try: + self._send(ClosePortal(portal)) + self._send(Sync()) + self._flush() + + reader = MessageReader(self) + reader.add_message(CloseComplete, lambda msg: 0) + reader.add_message(ReadyForQuery, lambda msg: 1) + reader.handle_messages() + finally: + self._sock_lock.release() + + def close(self): + self._sock_lock.acquire() + try: + self._send(Terminate()) + self._flush() + self._sock.close() + self._state = "closed" + finally: + self._sock_lock.release() + + def _onParameterStatusReceived(self, msg): + if msg.key == "client_encoding": + self._client_encoding = types.encoding_convert(msg.value) + ##print "_onParameterStatusReceived client_encoding", self._client_encoding + elif msg.key == "integer_datetimes": + self._integer_datetimes = (msg.value == "on") + elif msg.key == "server_version": + self._server_version = msg.value + else: + ##print "_onParameterStatusReceived ", msg.key, msg.value + pass + + def handleNoticeResponse(self, msg): + self.NoticeReceived(msg) + + def handleParameterStatus(self, msg): + self.ParameterStatusReceived(msg) + + def handleNotificationResponse(self, msg): + self.NotificationReceived(msg) + + def fileno(self): + # This should be safe to do without a lock + return self._sock.fileno() + + def isready(self): + self._sock_lock.acquire() + try: + rlst, _wlst, _xlst = select.select([self], [], [], 0) + if not rlst: + return False + + self._sync() + return True + finally: + self._sock_lock.release() + + def server_version(self): + self.verifyState("ready") + if not self._server_version: + raise InterfaceError("Server did not provide server_version parameter.") + return self._server_version + + def encoding(self): + return self._client_encoding + + +message_types = { + "N": NoticeResponse, + "R": AuthenticationRequest, + "S": ParameterStatus, + "K": BackendKeyData, + "Z": ReadyForQuery, + "T": RowDescription, + "E": ErrorResponse, + "D": DataRow, + "C": CommandComplete, + "1": ParseComplete, + "2": BindComplete, + "3": CloseComplete, + "s": PortalSuspended, + "n": NoData, + "I": EmptyQueryResponse, + "t": ParameterDescription, + "A": NotificationResponse, + "c": CopyDone, + "d": CopyData, + "G": CopyInResponse, + "H": CopyOutResponse, + } + + diff --git a/gluon/contrib/pg8000/types.py b/gluon/contrib/pg8000/types.py new file mode 100644 index 00000000..37cf0595 --- /dev/null +++ b/gluon/contrib/pg8000/types.py @@ -0,0 +1,708 @@ +# vim: sw=4:expandtab:foldmethod=marker +# +# Copyright (c) 2007-2009, Mathieu Fenniak +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * The name of the author may not be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +__author__ = "Mathieu Fenniak" + +import datetime +import decimal +import struct +import math +from errors import (NotSupportedError, ArrayDataParseError, InternalError, + ArrayContentEmptyError, ArrayContentNotHomogenousError, + ArrayContentNotSupportedError, ArrayDimensionsNotConsistentError) + +try: + from pytz import utc +except ImportError: + ZERO = datetime.timedelta(0) + class UTC(datetime.tzinfo): + def utcoffset(self, dt): + return ZERO + def tzname(self, dt): + return "UTC" + def dst(self, dt): + return ZERO + utc = UTC() + +class Bytea(str): + pass + +class Interval(object): + def __init__(self, microseconds=0, days=0, months=0): + self.microseconds = microseconds + self.days = days + self.months = months + + def _setMicroseconds(self, value): + if not isinstance(value, int) and not isinstance(value, long): + raise TypeError("microseconds must be an int or long") + elif not (min_int8 < value < max_int8): + raise OverflowError("microseconds must be representable as a 64-bit integer") + else: + self._microseconds = value + + def _setDays(self, value): + if not isinstance(value, int) and not isinstance(value, long): + raise TypeError("days must be an int or long") + elif not (min_int4 < value < max_int4): + raise OverflowError("days must be representable as a 32-bit integer") + else: + self._days = value + + def _setMonths(self, value): + if not isinstance(value, int) and not isinstance(value, long): + raise TypeError("months must be an int or long") + elif not (min_int4 < value < max_int4): + raise OverflowError("months must be representable as a 32-bit integer") + else: + self._months = value + + microseconds = property(lambda self: self._microseconds, _setMicroseconds) + days = property(lambda self: self._days, _setDays) + months = property(lambda self: self._months, _setMonths) + + def __repr__(self): + return "" % (self.months, self.days, self.microseconds) + + def __cmp__(self, other): + if other == None: return -1 + c = cmp(self.months, other.months) + if c != 0: return c + c = cmp(self.days, other.days) + if c != 0: return c + return cmp(self.microseconds, other.microseconds) + +def pg_type_info(typ): + value = None + if isinstance(typ, dict): + value = typ["value"] + typ = typ["type"] + + data = py_types.get(typ) + if data == None: + raise NotSupportedError("type %r not mapped to pg type" % typ) + + # permit the type data to be determined by the value, if provided + inspect_func = data.get("inspect") + if value != None and inspect_func != None: + data = inspect_func(value) + + type_oid = data.get("typeoid") + if type_oid == None: + raise InternalError("type %r has no type_oid" % typ) + elif type_oid == -1: + # special case: NULL values + return type_oid, 0 + + # prefer bin, but go with whatever exists + if data.get("bin_out"): + format = 1 + elif data.get("txt_out"): + format = 0 + else: + raise InternalError("no conversion fuction for type %r" % typ) + + return type_oid, format + +def pg_value(value, fc, **kwargs): + typ = type(value) + data = py_types.get(typ) + if data == None: + raise NotSupportedError("type %r not mapped to pg type" % typ) + + # permit the type conversion to be determined by the value, if provided + inspect_func = data.get("inspect") + if value != None and inspect_func != None: + data = inspect_func(value) + + # special case: NULL values + if data.get("typeoid") == -1: + return None + + if fc == 0: + func = data.get("txt_out") + elif fc == 1: + func = data.get("bin_out") + else: + raise InternalError("unrecognized format code %r" % fc) + if func == None: + raise NotSupportedError("type %r, format code %r not supported" % (typ, fc)) + return func(value, **kwargs) + +def py_type_info(description): + type_oid = description['type_oid'] + data = pg_types.get(type_oid) + if data == None: + raise NotSupportedError("type oid %r not mapped to py type" % type_oid) + # prefer bin, but go with whatever exists + if data.get("bin_in"): + format = 1 + elif data.get("txt_in"): + format = 0 + else: + raise InternalError("no conversion fuction for type oid %r" % type_oid) + return format + +def py_value(v, description, **kwargs): + if v == None: + # special case - NULL value + return None + type_oid = description['type_oid'] + format = description['format'] + data = pg_types.get(type_oid) + if data == None: + raise NotSupportedError("type oid %r not supported" % type_oid) + if format == 0: + func = data.get("txt_in") + elif format == 1: + func = data.get("bin_in") + else: + raise NotSupportedError("format code %r not supported" % format) + if func == None: + raise NotSupportedError("data response format %r, type %r not supported" % (format, type_oid)) + return func(v, **kwargs) + +def boolrecv(data, **kwargs): + return data == "\x01" + +def boolsend(v, **kwargs): + if v: + return "\x01" + else: + return "\x00" + +min_int2, max_int2 = -2 ** 15, 2 ** 15 +min_int4, max_int4 = -2 ** 31, 2 ** 31 +min_int8, max_int8 = -2 ** 63, 2 ** 63 + +def int_inspect(value): + if min_int2 < value < max_int2: + return {"typeoid": 21, "bin_out": int2send} + elif min_int4 < value < max_int4: + return {"typeoid": 23, "bin_out": int4send} + elif min_int8 < value < max_int8: + return {"typeoid": 20, "bin_out": int8send} + else: + return {"typeoid": 1700, "bin_out": numeric_send} + +def int2recv(data, **kwargs): + return struct.unpack("!h", data)[0] + +def int2send(v, **kwargs): + return struct.pack("!h", v) + +def int4recv(data, **kwargs): + return struct.unpack("!i", data)[0] + +def int4send(v, **kwargs): + return struct.pack("!i", v) + +def int8recv(data, **kwargs): + return struct.unpack("!q", data)[0] + +def int8send(v, **kwargs): + return struct.pack("!q", v) + +def float4recv(data, **kwargs): + return struct.unpack("!f", data)[0] + +def float8recv(data, **kwargs): + return struct.unpack("!d", data)[0] + +def float8send(v, **kwargs): + return struct.pack("!d", v) + +def datetime_inspect(value): + if value.tzinfo != None: + # send as timestamptz if timezone is provided + return {"typeoid": 1184, "bin_out": timestamptz_send} + else: + # otherwise send as timestamp + return {"typeoid": 1114, "bin_out": timestamp_send} + +def timestamp_recv(data, integer_datetimes, **kwargs): + if integer_datetimes: + # data is 64-bit integer representing milliseconds since 2000-01-01 + val = struct.unpack("!q", data)[0] + return datetime.datetime(2000, 1, 1) + datetime.timedelta(microseconds = val) + else: + # data is double-precision float representing seconds since 2000-01-01 + val = struct.unpack("!d", data)[0] + return datetime.datetime(2000, 1, 1) + datetime.timedelta(seconds = val) + +# return a timezone-aware datetime instance if we're reading from a +# "timestamp with timezone" type. The timezone returned will always be UTC, +# but providing that additional information can permit conversion to local. +def timestamptz_recv(data, **kwargs): + return timestamp_recv(data, **kwargs).replace(tzinfo=utc) + +def timestamp_send(v, integer_datetimes, **kwargs): + delta = v - datetime.datetime(2000, 1, 1) + val = delta.microseconds + (delta.seconds * 1000000) + (delta.days * 86400000000) + if integer_datetimes: + # data is 64-bit integer representing milliseconds since 2000-01-01 + return struct.pack("!q", val) + else: + # data is double-precision float representing seconds since 2000-01-01 + return struct.pack("!d", val / 1000.0 / 1000.0) + +def timestamptz_send(v, **kwargs): + # timestamps should be sent as UTC. If they have zone info, + # convert them. + return timestamp_send(v.astimezone(utc).replace(tzinfo=None), **kwargs) + +def date_in(data, **kwargs): + year = int(data[0:4]) + month = int(data[5:7]) + day = int(data[8:10]) + return datetime.date(year, month, day) + +def date_out(v, **kwargs): + return v.isoformat() + +def time_in(data, **kwargs): + hour = int(data[0:2]) + minute = int(data[3:5]) + sec = decimal.Decimal(data[6:]) + return datetime.time(hour, minute, int(sec), int((sec - int(sec)) * 1000000)) + +def time_out(v, **kwargs): + return v.isoformat() + +def numeric_in(data, **kwargs): + if data.find(".") == -1: + return int(data) + else: + return decimal.Decimal(data) + +def numeric_recv(data, **kwargs): + num_digits, weight, sign, scale = struct.unpack("!hhhh", data[:8]) + data = data[8:] + digits = struct.unpack("!" + ("h" * num_digits), data) + weight = decimal.Decimal(weight) + retval = 0 + for d in digits: + d = decimal.Decimal(d) + retval += d * (10000 ** weight) + weight -= 1 + if sign: + retval *= -1 + return retval + +DEC_DIGITS = 4 +def numeric_send(d, **kwargs): + # This is a very straight port of src/backend/utils/adt/numeric.c set_var_from_str() + s = str(d) + pos = 0 + sign = 0 + if s[0] == '-': + sign = 0x4000 # NEG + pos=1 + elif s[0] == '+': + sign = 0 # POS + pos=1 + have_dp = False + decdigits = [0, 0, 0, 0] + dweight = -1 + dscale = 0 + for char in s[pos:]: + if char.isdigit(): + decdigits.append(int(char)) + if not have_dp: + dweight += 1 + else: + dscale += 1 + pos+=1 + elif char == '.': + have_dp = True + pos+=1 + else: + break + + if len(s) > pos: + char = s[pos] + if char == 'e' or char == 'E': + pos+=1 + exponent = int(s[pos:]) + dweight += exponent + dscale -= exponent + if dscale < 0: dscale = 0 + + if dweight >= 0: + weight = (dweight + 1 + DEC_DIGITS - 1) / DEC_DIGITS - 1 + else: + weight = -((-dweight - 1) / DEC_DIGITS + 1) + offset = (weight + 1) * DEC_DIGITS - (dweight + 1) + ndigits = (len(decdigits)-DEC_DIGITS + offset + DEC_DIGITS - 1) / DEC_DIGITS + + i = DEC_DIGITS - offset + decdigits.extend([0, 0, 0]) + ndigits_ = ndigits + digits = '' + while ndigits_ > 0: + # ifdef DEC_DIGITS == 4 + digits += struct.pack("!h", ((decdigits[i] * 10 + decdigits[i + 1]) * 10 + decdigits[i + 2]) * 10 + decdigits[i + 3]) + ndigits_ -= 1 + i += DEC_DIGITS + + # strip_var() + if ndigits == 0: + sign = 0x4000 # pos + weight = 0 + # ---------- + + retval = struct.pack("!hhhh", ndigits, weight, sign, dscale) + digits + return retval + +def numeric_out(v, **kwargs): + return str(v) + +# PostgreSQL encodings: +# http://www.postgresql.org/docs/8.3/interactive/multibyte.html +# Python encodings: +# http://www.python.org/doc/2.4/lib/standard-encodings.html +# +# Commented out encodings don't require a name change between PostgreSQL and +# Python. If the py side is None, then the encoding isn't supported. +pg_to_py_encodings = { + # Not supported: + "mule_internal": None, + "euc_tw": None, + + # Name fine as-is: + #"euc_jp", + #"euc_jis_2004", + #"euc_kr", + #"gb18030", + #"gbk", + #"johab", + #"sjis", + #"shift_jis_2004", + #"uhc", + #"utf8", + + # Different name: + "euc_cn": "gb2312", + "iso_8859_5": "is8859_5", + "iso_8859_6": "is8859_6", + "iso_8859_7": "is8859_7", + "iso_8859_8": "is8859_8", + "koi8": "koi8_r", + "latin1": "iso8859-1", + "latin2": "iso8859_2", + "latin3": "iso8859_3", + "latin4": "iso8859_4", + "latin5": "iso8859_9", + "latin6": "iso8859_10", + "latin7": "iso8859_13", + "latin8": "iso8859_14", + "latin9": "iso8859_15", + "sql_ascii": "ascii", + "win866": "cp886", + "win874": "cp874", + "win1250": "cp1250", + "win1251": "cp1251", + "win1252": "cp1252", + "win1253": "cp1253", + "win1254": "cp1254", + "win1255": "cp1255", + "win1256": "cp1256", + "win1257": "cp1257", + "win1258": "cp1258", +} + +def encoding_convert(encoding): + return pg_to_py_encodings.get(encoding.lower(), encoding) + +def varcharin(data, client_encoding, **kwargs): + return unicode(data, encoding_convert(client_encoding)) + +def textout(v, client_encoding, **kwargs): + if isinstance(v, unicode): + return v.encode(encoding_convert(client_encoding)) + else: + return v + +def byteasend(v, **kwargs): + return str(v) + +def bytearecv(data, **kwargs): + return Bytea(data) + +# interval support does not provide a Python-usable interval object yet +def interval_recv(data, integer_datetimes, **kwargs): + if integer_datetimes: + microseconds, days, months = struct.unpack("!qii", data) + else: + seconds, days, months = struct.unpack("!dii", data) + microseconds = int(seconds * 1000 * 1000) + return Interval(microseconds, days, months) + +def interval_send(data, integer_datetimes, **kwargs): + if integer_datetimes: + return struct.pack("!qii", data.microseconds, data.days, data.months) + else: + return struct.pack("!dii", data.microseconds / 1000.0 / 1000.0, data.days, data.months) + +def array_recv(data, **kwargs): + dim, hasnull, typeoid = struct.unpack("!iii", data[:12]) + data = data[12:] + + # get type conversion method for typeoid + conversion = pg_types[typeoid]["bin_in"] + + # Read dimension info + dim_lengths = [] + element_count = 1 + for idim in range(dim): + dim_len, dim_lbound = struct.unpack("!ii", data[:8]) + data = data[8:] + dim_lengths.append(dim_len) + element_count *= dim_len + + # Read all array values + array_values = [] + for i in range(element_count): + if len(data): + element_len, = struct.unpack("!i", data[:4]) + data = data[4:] + if element_len == -1: + array_values.append(None) + else: + array_values.append(conversion(data[:element_len], **kwargs)) + data = data[element_len:] + if data != "": + raise ArrayDataParseError("unexpected data left over after array read") + + # at this point, {{1,2,3},{4,5,6}}::int[][] looks like [1,2,3,4,5,6]. + # go through the dimensions and fix up the array contents to match + # expected dimensions + for dim_length in reversed(dim_lengths[1:]): + val = [] + while array_values: + val.append(array_values[:dim_length]) + array_values = array_values[dim_length:] + array_values = val + + return array_values + +def array_inspect(value): + # Check if array has any values. If not, we can't determine the proper + # array typeoid. + first_element = array_find_first_element(value) + if first_element == None: + raise ArrayContentEmptyError("array has no values") + + # supported array output + typ = type(first_element) + if issubclass(typ, int) or issubclass(typ, long): + # special int array support -- send as smallest possible array type + special_int_support = True + int2_ok, int4_ok, int8_ok = True, True, True + for v in array_flatten(value): + if v == None: + continue + if min_int2 < v < max_int2: + continue + int2_ok = False + if min_int4 < v < max_int4: + continue + int4_ok = False + if min_int8 < v < max_int8: + continue + int8_ok = False + if int2_ok: + array_typeoid = 1005 # INT2[] + elif int4_ok: + array_typeoid = 1007 # INT4[] + elif int8_ok: + array_typeoid = 1016 # INT8[] + else: + raise ArrayContentNotSupportedError("numeric not supported as array contents") + else: + special_int_support = False + array_typeoid = py_array_types.get(typ) + if array_typeoid == None: + raise ArrayContentNotSupportedError("type %r not supported as array contents" % typ) + + # check for homogenous array + for v in array_flatten(value): + if v != None and not (isinstance(v, typ) or (typ == long and isinstance(v, int)) or (typ == int and isinstance(v, long))): + raise ArrayContentNotHomogenousError("not all array elements are of type %r" % typ) + + # check that all array dimensions are consistent + array_check_dimensions(value) + + type_data = py_types[typ] + if special_int_support: + if array_typeoid == 1005: + type_data = {"typeoid": 21, "bin_out": int2send} + elif array_typeoid == 1007: + type_data = {"typeoid": 23, "bin_out": int4send} + elif array_typeoid == 1016: + type_data = {"typeoid": 20, "bin_out": int8send} + else: + type_data = py_types[typ] + return { + "typeoid": array_typeoid, + "bin_out": array_send(type_data["typeoid"], type_data["bin_out"]) + } + +def array_find_first_element(arr): + for v in array_flatten(arr): + if v != None: + return v + return None + +def array_flatten(arr): + for v in arr: + if isinstance(v, list): + for v2 in array_flatten(v): + yield v2 + else: + yield v + +def array_check_dimensions(arr): + v0 = arr[0] + if isinstance(v0, list): + req_len = len(v0) + req_inner_lengths = array_check_dimensions(v0) + for v in arr: + inner_lengths = array_check_dimensions(v) + if len(v) != req_len or inner_lengths != req_inner_lengths: + raise ArrayDimensionsNotConsistentError("array dimensions not consistent") + retval = [req_len] + retval.extend(req_inner_lengths) + return retval + else: + # make sure nothing else at this level is a list + for v in arr: + if isinstance(v, list): + raise ArrayDimensionsNotConsistentError("array dimensions not consistent") + return [] + +def array_has_null(arr): + for v in array_flatten(arr): + if v == None: + return True + return False + +def array_dim_lengths(arr): + v0 = arr[0] + if isinstance(v0, list): + retval = [len(v0)] + retval.extend(array_dim_lengths(v0)) + else: + return [len(arr)] + return retval + +class array_send(object): + def __init__(self, typeoid, bin_out_func): + self.typeoid = typeoid + self.bin_out_func = bin_out_func + + def __call__(self, arr, **kwargs): + has_null = array_has_null(arr) + dim_lengths = array_dim_lengths(arr) + data = struct.pack("!iii", len(dim_lengths), has_null, self.typeoid) + for i in dim_lengths: + data += struct.pack("!ii", i, 1) + for v in array_flatten(arr): + if v == None: + data += struct.pack("!i", -1) + else: + inner_data = self.bin_out_func(v, **kwargs) + data += struct.pack("!i", len(inner_data)) + data += inner_data + return data + +py_types = { + bool: {"typeoid": 16, "bin_out": boolsend}, + int: {"inspect": int_inspect}, + long: {"inspect": int_inspect}, + str: {"typeoid": 25, "bin_out": textout}, + unicode: {"typeoid": 25, "bin_out": textout}, + float: {"typeoid": 701, "bin_out": float8send}, + decimal.Decimal: {"typeoid": 1700, "bin_out": numeric_send}, + Bytea: {"typeoid": 17, "bin_out": byteasend}, + datetime.datetime: {"typeoid": 1114, "bin_out": timestamp_send, "inspect": datetime_inspect}, + datetime.date: {"typeoid": 1082, "txt_out": date_out}, + datetime.time: {"typeoid": 1083, "txt_out": time_out}, + Interval: {"typeoid": 1186, "bin_out": interval_send}, + type(None): {"typeoid": -1}, + list: {"inspect": array_inspect}, +} + +# py type -> pg array typeoid +py_array_types = { + float: 1022, + bool: 1000, + str: 1009, # TEXT[] + unicode: 1009, # TEXT[] + decimal.Decimal: 1231, # NUMERIC[] +} + +pg_types = { + 16: {"bin_in": boolrecv}, + 17: {"bin_in": bytearecv}, + 19: {"bin_in": varcharin}, # name type + 20: {"bin_in": int8recv}, + 21: {"bin_in": int2recv}, + 23: {"bin_in": int4recv, "txt_in": numeric_in}, + 25: {"bin_in": varcharin, "txt_in": varcharin}, # TEXT type + 26: {"txt_in": numeric_in}, # oid type + 142: {"bin_in": varcharin, "txt_in": varcharin}, # XML + 194: {"bin_in": varcharin}, # "string representing an internal node tree" + 700: {"bin_in": float4recv}, + 701: {"bin_in": float8recv}, + 705: {"txt_in": varcharin}, # UNKNOWN + 829: {"txt_in": varcharin}, # MACADDR type + 1000: {"bin_in": array_recv}, # BOOL[] + 1003: {"bin_in": array_recv}, # NAME[] + 1005: {"bin_in": array_recv}, # INT2[] + 1007: {"bin_in": array_recv, "txt_in": varcharin}, # INT4[] + 1009: {"bin_in": array_recv}, # TEXT[] + 1014: {"bin_in": array_recv}, # CHAR[] + 1015: {"bin_in": array_recv}, # VARCHAR[] + 1016: {"bin_in": array_recv}, # INT8[] + 1021: {"bin_in": array_recv}, # FLOAT4[] + 1022: {"bin_in": array_recv}, # FLOAT8[] + 1042: {"bin_in": varcharin}, # CHAR type + 1043: {"bin_in": varcharin}, # VARCHAR type + 1082: {"txt_in": date_in}, + 1083: {"txt_in": time_in}, + 1114: {"bin_in": timestamp_recv}, + 1184: {"bin_in": timestamptz_recv}, # timestamp w/ tz + 1186: {"bin_in": interval_recv}, + 1231: {"bin_in": array_recv}, # NUMERIC[] + 1263: {"bin_in": array_recv}, # cstring[] + 1700: {"bin_in": numeric_recv}, + 2275: {"bin_in": varcharin}, # cstring +} + diff --git a/gluon/contrib/pg8000/util.py b/gluon/contrib/pg8000/util.py new file mode 100644 index 00000000..d99421e1 --- /dev/null +++ b/gluon/contrib/pg8000/util.py @@ -0,0 +1,20 @@ + +class MulticastDelegate(object): + def __init__(self): + self.delegates = [] + + def __iadd__(self, delegate): + self.add(delegate) + return self + + def add(self, delegate): + self.delegates.append(delegate) + + def __isub__(self, delegate): + self.delegates.remove(delegate) + return self + + def __call__(self, *args, **kwargs): + for d in self.delegates: + d(*args, **kwargs) +