diff --git a/gluon/contrib/pypyodbc.py b/gluon/contrib/pypyodbc.py index fc2b2d12..7960a5bb 100644 --- a/gluon/contrib/pypyodbc.py +++ b/gluon/contrib/pypyodbc.py @@ -4,28 +4,55 @@ # The MIT License (MIT) # -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# Copyright (c) 2013 Henry Zhou and PyPyODBC contributors +# Copyright (c) 2004 Michele Petrazzo + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the "Software"), to deal in the Software without restriction, including without limitation # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: # -# The above copyright notice and this permission notice shall be included in all copies or substantial portions +# The above copyright notice and this permission notice shall be included in all copies or substantial portions # of the Software. # -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO #EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF -# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO #EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF +# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. +pooling = True +apilevel = '2.0' +paramstyle = 'qmark' +threadsafety = 1 +version = '1.2.0' +lowercase=True + +DEBUG = 0 +# Comment out all "if DEBUG:" statements like below for production +#if DEBUG:print 'DEBUGGING' + import sys, os, datetime, ctypes, threading from decimal import Decimal -try: - bytearray -except NameError: - # pre version 2.6 python does not have the bytearray type + +py_ver = sys.version[:3] +py_v3 = py_ver >= '3.0' + +if py_v3: + long = int + unicode = str + str_8b = bytes + buffer = memoryview + BYTE_1 = bytes('1','ascii') + use_unicode = True +else: + str_8b = str + BYTE_1 = '1' + use_unicode = False +if py_ver < '2.6': bytearray = str + if not hasattr(ctypes, 'c_ssize_t'): if ctypes.sizeof(ctypes.c_uint) == ctypes.sizeof(ctypes.c_void_p): @@ -35,23 +62,14 @@ if not hasattr(ctypes, 'c_ssize_t'): elif ctypes.sizeof(ctypes.c_ulonglong) == ctypes.sizeof(ctypes.c_void_p): ctypes.c_ssize_t = ctypes.c_longlong -DEBUG = 0 -# Comment out all "if DEBUG:" statements like below for production -if DEBUG: print 'DEBUGGING' -pooling = True lock = threading.Lock() shared_env_h = None -apilevel = '2.0' -paramstyle = 'qmark' -threadsafety = 1 -version = '0.9.3' -lowercase=True SQLWCHAR_SIZE = ctypes.sizeof(ctypes.c_wchar) #determin the size of Py_UNICODE #sys.maxunicode > 65536 and 'UCS4' or 'UCS2' -UNICODE_SIZE = sys.maxunicode > 65536 and 4 or 2 +UNICODE_SIZE = sys.maxunicode > 65536 and 4 or 2 # Define ODBC constants. They are widly used in ODBC documents and programs @@ -63,7 +81,7 @@ SQL_ATTR_CONNECTION_POOLING = 201; SQL_CP_ONE_PER_HENV = 2 SQL_FETCH_NEXT, SQL_FETCH_FIRST, SQL_FETCH_LAST = 0x01, 0x02, 0x04 SQL_NULL_HANDLE, SQL_HANDLE_ENV, SQL_HANDLE_DBC, SQL_HANDLE_STMT = 0, 1, 2, 3 -SQL_SUCCESS, SQL_SUCCESS_WITH_INFO = 0, 1 +SQL_SUCCESS, SQL_SUCCESS_WITH_INFO, SQL_ERROR = 0, 1, -1 SQL_NO_DATA = 100; SQL_NO_TOTAL = -4 SQL_ATTR_ACCESS_MODE = SQL_ACCESS_MODE = 101 SQL_ATTR_AUTOCOMMIT = SQL_AUTOCOMMIT = 102 @@ -93,56 +111,7 @@ SQL_RESET_PARAMS = 3 SQL_UNBIND = 2 SQL_CLOSE = 0 -SQL_TYPE_NULL = 0 -SQL_DECIMAL = 3 -SQL_FLOAT = 6 -SQL_DATE = 9 -SQL_TIME = 10 -SQL_TIMESTAMP = 11 -SQL_VARCHAR = 12 -SQL_LONGVARCHAR = -1 -SQL_VARBINARY = -3 -SQL_LONGVARBINARY = -4 -SQL_BIGINT = -5 -SQL_WVARCHAR = -9 -SQL_WLONGVARCHAR = -10 -SQL_ALL_TYPES = 0 -SQL_SIGNED_OFFSET = -20 -SQL_C_CHAR = SQL_CHAR = 1 -SQL_C_NUMERIC = SQL_NUMERIC = 2 -SQL_C_LONG = SQL_INTEGER = 4 -SQL_C_SLONG = SQL_C_LONG + SQL_SIGNED_OFFSET -SQL_C_SHORT = SQL_SMALLINT = 5 -SQL_C_FLOAT = SQL_REAL = 7 -SQL_C_DOUBLE = SQL_DOUBLE = 8 -SQL_C_TYPE_DATE = SQL_TYPE_DATE = 91 -SQL_C_TYPE_TIME = SQL_TYPE_TIME = 92 -SQL_C_BINARY = SQL_BINARY = -2 -SQL_C_SBIGINT = SQL_BIGINT + SQL_SIGNED_OFFSET -SQL_C_TINYINT = SQL_TINYINT = -6 -SQL_C_BIT = SQL_BIT = -7 -SQL_C_WCHAR = SQL_WCHAR = -8 -SQL_C_GUID = SQL_GUID = -11 -SQL_C_TYPE_TIMESTAMP = SQL_TYPE_TIMESTAMP = 93 -SQL_C_DEFAULT = 99 - -SQL_SS_TIME2 = -154 - -SQL_DESC_DISPLAY_SIZE = SQL_COLUMN_DISPLAY_SIZE - - -def dttm_cvt(x): - if x == '': return None - else: return datetime.datetime(int(x[0:4]),int(x[5:7]),int(x[8:10]),int(x[10:13]),int(x[14:16]),int(x[17:19]),int(x[20:].ljust(6,'0'))) - -def tm_cvt(x): - if x == '': return None - else: return datetime.time(int(x[0:2]),int(x[3:5]),int(x[6:8]),int(x[9:].ljust(6,'0'))) - -def dt_cvt(x): - if x == '': return None - else: return datetime.date(int(x[0:4]),int(x[5:7]),int(x[8:10])) # Below defines The constants for sqlgetinfo method, and their coresponding return types @@ -303,18 +272,16 @@ SQL_ALTER_TABLE : 'GI_UINTEGER',SQL_ASYNC_MODE : 'GI_UINTEGER',SQL_BATCH_ROW_COU SQL_BATCH_SUPPORT : 'GI_UINTEGER',SQL_BOOKMARK_PERSISTENCE : 'GI_UINTEGER',SQL_CATALOG_LOCATION : 'GI_USMALLINT', SQL_CATALOG_NAME : 'GI_YESNO',SQL_CATALOG_NAME_SEPARATOR : 'GI_STRING',SQL_CATALOG_TERM : 'GI_STRING', SQL_CATALOG_USAGE : 'GI_UINTEGER',SQL_COLLATION_SEQ : 'GI_STRING',SQL_COLUMN_ALIAS : 'GI_YESNO', -SQL_CONCAT_NULL_BEHAVIOR : 'GI_USMALLINT',SQL_CONVERT_FUNCTIONS : 'GI_UINTEGER', -SQL_CONVERT_VARCHAR : 'GI_UINTEGER',SQL_CORRELATION_NAME : 'GI_USMALLINT', -SQL_CREATE_ASSERTION : 'GI_UINTEGER',SQL_CREATE_CHARACTER_SET : 'GI_UINTEGER', +SQL_CONCAT_NULL_BEHAVIOR : 'GI_USMALLINT',SQL_CONVERT_FUNCTIONS : 'GI_UINTEGER',SQL_CONVERT_VARCHAR : 'GI_UINTEGER', +SQL_CORRELATION_NAME : 'GI_USMALLINT',SQL_CREATE_ASSERTION : 'GI_UINTEGER',SQL_CREATE_CHARACTER_SET : 'GI_UINTEGER', SQL_CREATE_COLLATION : 'GI_UINTEGER',SQL_CREATE_DOMAIN : 'GI_UINTEGER',SQL_CREATE_SCHEMA : 'GI_UINTEGER', SQL_CREATE_TABLE : 'GI_UINTEGER',SQL_CREATE_TRANSLATION : 'GI_UINTEGER',SQL_CREATE_VIEW : 'GI_UINTEGER', SQL_CURSOR_COMMIT_BEHAVIOR : 'GI_USMALLINT',SQL_CURSOR_ROLLBACK_BEHAVIOR : 'GI_USMALLINT',SQL_DATABASE_NAME : 'GI_STRING', SQL_DATA_SOURCE_NAME : 'GI_STRING',SQL_DATA_SOURCE_READ_ONLY : 'GI_YESNO',SQL_DATETIME_LITERALS : 'GI_UINTEGER', SQL_DBMS_NAME : 'GI_STRING',SQL_DBMS_VER : 'GI_STRING',SQL_DDL_INDEX : 'GI_UINTEGER', SQL_DEFAULT_TXN_ISOLATION : 'GI_UINTEGER',SQL_DESCRIBE_PARAMETER : 'GI_YESNO',SQL_DM_VER : 'GI_STRING', -SQL_DRIVER_NAME : 'GI_STRING',SQL_DRIVER_ODBC_VER : 'GI_STRING',SQL_DRIVER_VER : 'GI_STRING', -SQL_DROP_ASSERTION : 'GI_UINTEGER',SQL_DROP_CHARACTER_SET : 'GI_UINTEGER', -SQL_DROP_COLLATION : 'GI_UINTEGER',SQL_DROP_DOMAIN : 'GI_UINTEGER', +SQL_DRIVER_NAME : 'GI_STRING',SQL_DRIVER_ODBC_VER : 'GI_STRING',SQL_DRIVER_VER : 'GI_STRING',SQL_DROP_ASSERTION : 'GI_UINTEGER', +SQL_DROP_CHARACTER_SET : 'GI_UINTEGER', SQL_DROP_COLLATION : 'GI_UINTEGER',SQL_DROP_DOMAIN : 'GI_UINTEGER', SQL_DROP_SCHEMA : 'GI_UINTEGER',SQL_DROP_TABLE : 'GI_UINTEGER',SQL_DROP_TRANSLATION : 'GI_UINTEGER', SQL_DROP_VIEW : 'GI_UINTEGER',SQL_DYNAMIC_CURSOR_ATTRIBUTES1 : 'GI_UINTEGER',SQL_DYNAMIC_CURSOR_ATTRIBUTES2 : 'GI_UINTEGER', SQL_EXPRESSIONS_IN_ORDERBY : 'GI_YESNO',SQL_FILE_USAGE : 'GI_USMALLINT', @@ -394,71 +361,67 @@ class OdbcGenericError(Exception): self.value = value def __str__(self): return repr(self.value) - - -class Warning(StandardError): +class Warning(Exception): def __init__(self, error_code, error_desc): self.value = (error_code, error_desc) self.args = (error_code, error_desc) - - -class Error(StandardError): +class Error(Exception): def __init__(self, error_code, error_desc): self.value = (error_code, error_desc) self.args = (error_code, error_desc) - class InterfaceError(Error): def __init__(self, error_code, error_desc): self.value = (error_code, error_desc) self.args = (error_code, error_desc) - - class DatabaseError(Error): def __init__(self, error_code, error_desc): self.value = (error_code, error_desc) self.args = (error_code, error_desc) - - class InternalError(DatabaseError): def __init__(self, error_code, error_desc): self.value = (error_code, error_desc) self.args = (error_code, error_desc) - - class ProgrammingError(DatabaseError): def __init__(self, error_code, error_desc): self.value = (error_code, error_desc) self.args = (error_code, error_desc) - class DataError(DatabaseError): def __init__(self, error_code, error_desc): self.value = (error_code, error_desc) self.args = (error_code, error_desc) - class IntegrityError(DatabaseError): def __init__(self, error_code, error_desc): self.value = (error_code, error_desc) self.args = (error_code, error_desc) - class NotSupportedError(Error): def __init__(self, error_code, error_desc): self.value = (error_code, error_desc) self.args = (error_code, error_desc) - class OperationalError(DatabaseError): def __init__(self, error_code, error_desc): self.value = (error_code, error_desc) self.args = (error_code, error_desc) + + + +############################################################################ +# +# Find the ODBC library on the platform and connect to it using ctypes +# +############################################################################ +# Get the References of the platform's ODBC functions via ctypes +odbc_decoding = 'utf_16' +odbc_encoding = 'utf_16_le' +ucs_length = 2 -# Get the References of the platform's ODBC functions via ctypes if sys.platform in ('win32','cli'): ODBC_API = ctypes.windll.odbc32 # On Windows, the size of SQLWCHAR is hardcoded to 2-bytes. SQLWCHAR_SIZE = ctypes.sizeof(ctypes.c_ushort) else: - # Set load the library on linux + # Set load the library on linux try: # First try direct loading libodbc.so ODBC_API = ctypes.cdll.LoadLibrary('libodbc.so') @@ -470,10 +433,10 @@ else: if library is None: # If find_library still can not find the library # we try finding it manually from where libodbc.so usually appears - lib_paths = ("/usr/lib/libodbc.so","/usr/lib/i386-linux-gnu/libodbc.so","/usr/lib/x86_64-linux-gnu/libodbc.so") + lib_paths = ("/usr/lib/libodbc.so","/usr/lib/i386-linux-gnu/libodbc.so","/usr/lib/x86_64-linux-gnu/libodbc.so","/usr/lib/libiodbc.dylib") lib_paths = [path for path in lib_paths if os.path.exists(path)] if len(lib_paths) == 0 : - raise OdbcNoLibrary, 'ODBC Library is not found' + raise OdbcNoLibrary('ODBC Library is not found. Is LD_LIBRARY_PATH set?') else: library = lib_paths[0] @@ -482,14 +445,26 @@ else: ODBC_API = ctypes.cdll.LoadLibrary(library) except: # If still fail loading, abort. - raise OdbcLibraryError, 'Error while loading %s' % library + raise OdbcLibraryError('Error while loading ' + library) + + # only iODBC uses utf-32 / UCS4 encoding data, others normally use utf-16 / UCS2 + # So we set those for handling. + if 'libiodbc.dylib' in library: + odbc_decoding = 'utf_32' + odbc_encoding = 'utf_32_le' + ucs_length = 4 + # unixODBC defaults to 2-bytes SQLWCHAR, unless "-DSQL_WCHART_CONVERT" was # added to CFLAGS, in which case it will be the size of wchar_t. # Note that using 4-bytes SQLWCHAR will break most ODBC drivers, as driver # development mostly targets the Windows platform. - import commands - status, output = commands.getstatusoutput('odbc_config --cflags') + if py_v3: + from subprocess import getstatusoutput + else: + from commands import getstatusoutput + + status, output = getstatusoutput('odbc_config --cflags') if status == 0 and 'SQL_WCHART_CONVERT' in output: SQLWCHAR_SIZE = ctypes.sizeof(ctypes.c_wchar) else: @@ -498,48 +473,130 @@ else: create_buffer_u = ctypes.create_unicode_buffer create_buffer = ctypes.create_string_buffer -wchar_type = ctypes.c_wchar_p -to_unicode = lambda s: s +wchar_pointer = ctypes.c_wchar_p +UCS_buf = lambda s: s +def UCS_dec(buffer): + i = 0 + uchars = [] + while True: + uchar = buffer.raw[i:i + ucs_length].decode(odbc_decoding) + if uchar == unicode('\x00'): + break + uchars.append(uchar) + i += ucs_length + return ''.join(uchars) from_buffer_u = lambda buffer: buffer.value # This is the common case on Linux, which uses wide Python build together with # the default unixODBC without the "-DSQL_WCHART_CONVERT" CFLAGS. -if UNICODE_SIZE > SQLWCHAR_SIZE: - # We can only use unicode buffer if the size of wchar_t (UNICODE_SIZE) is - # the same as the size expected by the driver manager (SQLWCHAR_SIZE). - create_buffer_u = create_buffer - wchar_type = ctypes.c_char_p +if sys.platform not in ('win32','cli'): + if UNICODE_SIZE >= SQLWCHAR_SIZE: + # We can only use unicode buffer if the size of wchar_t (UNICODE_SIZE) is + # the same as the size expected by the driver manager (SQLWCHAR_SIZE). + create_buffer_u = create_buffer + wchar_pointer = ctypes.c_char_p - def to_unicode(s): - return s.encode('UTF-16LE') + def UCS_buf(s): + return s.encode(odbc_encoding) - def from_buffer_u(buffer): - i = 0 - uchars = [] - while True: - uchar = buffer.raw[i:i + 2].decode('UTF-16') - if uchar == u'\x00': - break - uchars.append(uchar) - i += 2 - return ''.join(uchars) + from_buffer_u = UCS_dec -# Exoteric case, don't really care. -elif UNICODE_SIZE < SQLWCHAR_SIZE: - raise OdbcLibraryError('Using narrow Python build with ODBC library ' - 'expecting wide unicode is not supported.') + # Exoteric case, don't really care. + elif UNICODE_SIZE < SQLWCHAR_SIZE: + raise OdbcLibraryError('Using narrow Python build with ODBC library ' + 'expecting wide unicode is not supported.') + + + + + + + + + + + + +############################################################ +# Database value to Python data type mappings +SQL_TYPE_NULL = 0 +SQL_DECIMAL = 3 +SQL_FLOAT = 6 +SQL_DATE = 9 +SQL_TIME = 10 +SQL_TIMESTAMP = 11 +SQL_VARCHAR = 12 +SQL_LONGVARCHAR = -1 +SQL_VARBINARY = -3 +SQL_LONGVARBINARY = -4 +SQL_BIGINT = -5 +SQL_WVARCHAR = -9 +SQL_WLONGVARCHAR = -10 +SQL_ALL_TYPES = 0 +SQL_SIGNED_OFFSET = -20 +SQL_SS_VARIANT = -150 +SQL_SS_UDT = -151 +SQL_SS_XML = -152 +SQL_SS_TIME2 = -154 + +SQL_C_CHAR = SQL_CHAR = 1 +SQL_C_NUMERIC = SQL_NUMERIC = 2 +SQL_C_LONG = SQL_INTEGER = 4 +SQL_C_SLONG = SQL_C_LONG + SQL_SIGNED_OFFSET +SQL_C_SHORT = SQL_SMALLINT = 5 +SQL_C_FLOAT = SQL_REAL = 7 +SQL_C_DOUBLE = SQL_DOUBLE = 8 +SQL_C_TYPE_DATE = SQL_TYPE_DATE = 91 +SQL_C_TYPE_TIME = SQL_TYPE_TIME = 92 +SQL_C_BINARY = SQL_BINARY = -2 +SQL_C_SBIGINT = SQL_BIGINT + SQL_SIGNED_OFFSET +SQL_C_TINYINT = SQL_TINYINT = -6 +SQL_C_BIT = SQL_BIT = -7 +SQL_C_WCHAR = SQL_WCHAR = -8 +SQL_C_GUID = SQL_GUID = -11 +SQL_C_TYPE_TIMESTAMP = SQL_TYPE_TIMESTAMP = 93 +SQL_C_DEFAULT = 99 + +SQL_DESC_DISPLAY_SIZE = SQL_COLUMN_DISPLAY_SIZE + +def dttm_cvt(x): + if py_v3: + x = x.decode('ascii') + if x == '': return None + else: return datetime.datetime(int(x[0:4]),int(x[5:7]),int(x[8:10]),int(x[10:13]),int(x[14:16]),int(x[17:19]),int(x[20:].ljust(6,'0'))) + +def tm_cvt(x): + if py_v3: + x = x.decode('ascii') + if x == '': return None + else: return datetime.time(int(x[0:2]),int(x[3:5]),int(x[6:8]),int(x[9:].ljust(6,'0'))) + +def dt_cvt(x): + if py_v3: + x = x.decode('ascii') + if x == '': return None + else: return datetime.date(int(x[0:4]),int(x[5:7]),int(x[8:10])) + +def Decimal_cvt(x): + if py_v3: + x = x.decode('ascii') + return Decimal(x) + +bytearray_cvt = bytearray +if sys.platform == 'cli': + bytearray_cvt = lambda x: bytearray(buffer(x)) + # Below Datatype mappings referenced the document at # http://infocenter.sybase.com/help/index.jsp?topic=/com.sybase.help.sdk_12.5.1.aseodbc/html/aseodbc/CACFDIGH.htm - SQL_data_type_dict = { \ #SQL Data TYPE 0.Python Data Type 1.Default Output Converter 2.Buffer Type 3.Buffer Allocator 4.Default Buffer Size -SQL_TYPE_NULL : (None, lambda x: None, SQL_C_CHAR, create_buffer, 2 ), +SQL_TYPE_NULL : (None, lambda x: None, SQL_C_CHAR, create_buffer, 2 ), SQL_CHAR : (str, lambda x: x, SQL_C_CHAR, create_buffer, 2048 ), -SQL_NUMERIC : (Decimal, Decimal, SQL_C_CHAR, create_buffer, 150 ), -SQL_DECIMAL : (Decimal, Decimal, SQL_C_CHAR, create_buffer, 150 ), +SQL_NUMERIC : (Decimal, Decimal_cvt, SQL_C_CHAR, create_buffer, 150 ), +SQL_DECIMAL : (Decimal, Decimal_cvt, SQL_C_CHAR, create_buffer, 150 ), SQL_INTEGER : (int, int, SQL_C_CHAR, create_buffer, 150 ), SQL_SMALLINT : (int, int, SQL_C_CHAR, create_buffer, 150 ), SQL_FLOAT : (float, float, SQL_C_CHAR, create_buffer, 150 ), @@ -551,19 +608,22 @@ SQL_SS_TIME2 : (datetime.time, tm_cvt, SQL_C_CH SQL_TIMESTAMP : (datetime.datetime, dttm_cvt, SQL_C_CHAR, create_buffer, 30 ), SQL_VARCHAR : (str, lambda x: x, SQL_C_CHAR, create_buffer, 2048 ), SQL_LONGVARCHAR : (str, lambda x: x, SQL_C_CHAR, create_buffer, 20500 ), -SQL_BINARY : (bytearray, bytearray, SQL_C_BINARY, create_buffer, 5120 ), -SQL_VARBINARY : (bytearray, bytearray, SQL_C_BINARY, create_buffer, 5120 ), -SQL_LONGVARBINARY : (bytearray, bytearray, SQL_C_BINARY, create_buffer, 20500 ), +SQL_BINARY : (bytearray, bytearray_cvt, SQL_C_BINARY, create_buffer, 5120 ), +SQL_VARBINARY : (bytearray, bytearray_cvt, SQL_C_BINARY, create_buffer, 5120 ), +SQL_LONGVARBINARY : (bytearray, bytearray_cvt, SQL_C_BINARY, create_buffer, 20500 ), SQL_BIGINT : (long, long, SQL_C_CHAR, create_buffer, 150 ), SQL_TINYINT : (int, int, SQL_C_CHAR, create_buffer, 150 ), -SQL_BIT : (bool, lambda x:x=='1', SQL_C_CHAR, create_buffer, 2 ), +SQL_BIT : (bool, lambda x:x == BYTE_1, SQL_C_CHAR, create_buffer, 2 ), SQL_WCHAR : (unicode, lambda x: x, SQL_C_WCHAR, create_buffer_u, 2048 ), SQL_WVARCHAR : (unicode, lambda x: x, SQL_C_WCHAR, create_buffer_u, 2048 ), SQL_GUID : (str, str, SQL_C_CHAR, create_buffer, 50 ), SQL_WLONGVARCHAR : (unicode, lambda x: x, SQL_C_WCHAR, create_buffer_u, 20500 ), SQL_TYPE_DATE : (datetime.date, dt_cvt, SQL_C_CHAR, create_buffer, 30 ), SQL_TYPE_TIME : (datetime.time, tm_cvt, SQL_C_CHAR, create_buffer, 20 ), -SQL_TYPE_TIMESTAMP : (datetime.datetime, dttm_cvt, SQL_C_CHAR, create_buffer, 30 ), +SQL_TYPE_TIMESTAMP : (datetime.datetime, dttm_cvt, SQL_C_CHAR, create_buffer, 30 ), +SQL_SS_VARIANT : (str, lambda x: x, SQL_C_CHAR, create_buffer, 2048 ), +SQL_SS_XML : (unicode, lambda x: x, SQL_C_WCHAR, create_buffer_u, 20500 ), +SQL_SS_UDT : (bytearray, bytearray_cvt, SQL_C_BINARY, create_buffer, 5120 ), } @@ -599,6 +659,8 @@ funcs_with_ret = [ "SQLDisconnect", "SQLDriverConnect", "SQLDriverConnectW", + "SQLDrivers", + "SQLDriversW", "SQLEndTran", "SQLExecDirect", "SQLExecDirectW", @@ -611,7 +673,9 @@ funcs_with_ret = [ "SQLFreeStmt", "SQLGetData", "SQLGetDiagRec", + "SQLGetDiagRecW", "SQLGetInfo", + "SQLGetInfoW", "SQLGetTypeInfo", "SQLMoreResults", "SQLNumParams", @@ -639,282 +703,172 @@ for func_name in funcs_with_ret: if sys.platform not in ('cli'): #Seems like the IronPython can not declare ctypes.POINTER type arguments ODBC_API.SQLAllocHandle.argtypes = [ - ctypes.c_short, - ctypes.c_void_p, - ctypes.POINTER(ctypes.c_void_p), + ctypes.c_short, ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p), ] ODBC_API.SQLBindParameter.argtypes = [ - ctypes.c_void_p, - ctypes.c_ushort, - ctypes.c_short, - ctypes.c_short, - ctypes.c_short, - ctypes.c_size_t, - ctypes.c_short, - ctypes.c_void_p, - ctypes.c_ssize_t, - ctypes.POINTER(ctypes.c_ssize_t), + ctypes.c_void_p, ctypes.c_ushort, ctypes.c_short, + ctypes.c_short, ctypes.c_short, ctypes.c_size_t, + ctypes.c_short, ctypes.c_void_p, ctypes.c_ssize_t, ctypes.POINTER(ctypes.c_ssize_t), ] ODBC_API.SQLColAttribute.argtypes = [ - ctypes.c_void_p, - ctypes.c_ushort, - ctypes.c_ushort, - ctypes.c_void_p, - ctypes.c_short, - ctypes.POINTER(ctypes.c_short), - ctypes.POINTER(ctypes.c_ssize_t), + ctypes.c_void_p, ctypes.c_ushort, ctypes.c_ushort, + ctypes.c_void_p, ctypes.c_short, ctypes.POINTER(ctypes.c_short), ctypes.POINTER(ctypes.c_ssize_t), ] ODBC_API.SQLDataSources.argtypes = [ - ctypes.c_void_p, - ctypes.c_ushort, - ctypes.c_char_p, - ctypes.c_short, - ctypes.POINTER(ctypes.c_short), - ctypes.c_char_p, - ctypes.c_short, - ctypes.POINTER(ctypes.c_short), + ctypes.c_void_p, ctypes.c_ushort, ctypes.c_char_p, + ctypes.c_short, ctypes.POINTER(ctypes.c_short), + ctypes.c_char_p, ctypes.c_short, ctypes.POINTER(ctypes.c_short), ] ODBC_API.SQLDescribeCol.argtypes = [ - ctypes.c_void_p, - ctypes.c_ushort, - ctypes.c_char_p, - ctypes.c_short, - ctypes.POINTER(ctypes.c_short), - ctypes.POINTER(ctypes.c_short), - ctypes.POINTER(ctypes.c_size_t), - ctypes.POINTER(ctypes.c_short), - ctypes.POINTER(ctypes.c_short), + ctypes.c_void_p, ctypes.c_ushort, ctypes.c_char_p, ctypes.c_short, + ctypes.POINTER(ctypes.c_short), ctypes.POINTER(ctypes.c_short), + ctypes.POINTER(ctypes.c_size_t), ctypes.POINTER(ctypes.c_short), ctypes.POINTER(ctypes.c_short), ] ODBC_API.SQLDescribeParam.argtypes = [ - ctypes.c_void_p, - ctypes.c_ushort, - ctypes.POINTER(ctypes.c_short), - ctypes.POINTER(ctypes.c_size_t), - ctypes.POINTER(ctypes.c_short), - ctypes.POINTER(ctypes.c_short), + ctypes.c_void_p, ctypes.c_ushort, + ctypes.POINTER(ctypes.c_short), ctypes.POINTER(ctypes.c_size_t), + ctypes.POINTER(ctypes.c_short), ctypes.POINTER(ctypes.c_short), ] ODBC_API.SQLDriverConnect.argtypes = [ - ctypes.c_void_p, - ctypes.c_void_p, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, - ctypes.POINTER(ctypes.c_short), - ctypes.c_ushort, + ctypes.c_void_p, ctypes.c_void_p, ctypes.c_char_p, + ctypes.c_short, ctypes.c_char_p, ctypes.c_short, + ctypes.POINTER(ctypes.c_short), ctypes.c_ushort, + ] + + ODBC_API.SQLDrivers.argtypes = [ + ctypes.c_void_p, ctypes.c_ushort, + ctypes.c_char_p, ctypes.c_short, ctypes.POINTER(ctypes.c_short), + ctypes.c_char_p, ctypes.c_short, ctypes.POINTER(ctypes.c_short), ] ODBC_API.SQLGetData.argtypes = [ - ctypes.c_void_p, - ctypes.c_ushort, - ctypes.c_short, - ctypes.c_void_p, - ctypes.c_ssize_t, - ctypes.POINTER(ctypes.c_ssize_t), + ctypes.c_void_p, ctypes.c_ushort, ctypes.c_short, + ctypes.c_void_p, ctypes.c_ssize_t, ctypes.POINTER(ctypes.c_ssize_t), ] ODBC_API.SQLGetDiagRec.argtypes = [ - ctypes.c_short, - ctypes.c_void_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.POINTER(ctypes.c_int), - ctypes.c_char_p, - ctypes.c_short, - ctypes.POINTER(ctypes.c_short), + ctypes.c_short, ctypes.c_void_p, ctypes.c_short, + ctypes.c_char_p, ctypes.POINTER(ctypes.c_int), + ctypes.c_char_p, ctypes.c_short, ctypes.POINTER(ctypes.c_short), ] ODBC_API.SQLGetInfo.argtypes = [ - ctypes.c_void_p, - ctypes.c_ushort, - ctypes.c_void_p, - ctypes.c_short, - ctypes.POINTER(ctypes.c_short), + ctypes.c_void_p, ctypes.c_ushort, ctypes.c_void_p, + ctypes.c_short, ctypes.POINTER(ctypes.c_short), ] ODBC_API.SQLRowCount.argtypes = [ - ctypes.c_void_p, - ctypes.POINTER(ctypes.c_ssize_t), + ctypes.c_void_p, ctypes.POINTER(ctypes.c_ssize_t), ] ODBC_API.SQLNumParams.argtypes = [ - ctypes.c_void_p, - ctypes.POINTER(ctypes.c_short), + ctypes.c_void_p, ctypes.POINTER(ctypes.c_short), ] ODBC_API.SQLNumResultCols.argtypes = [ - ctypes.c_void_p, - ctypes.POINTER(ctypes.c_short), + ctypes.c_void_p, ctypes.POINTER(ctypes.c_short), ] ODBC_API.SQLCloseCursor.argtypes = [ctypes.c_void_p] ODBC_API.SQLColumns.argtypes = [ - ctypes.c_void_p, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, + ctypes.c_void_p, ctypes.c_char_p, ctypes.c_short, + ctypes.c_char_p, ctypes.c_short, ctypes.c_char_p, + ctypes.c_short, ctypes.c_char_p, ctypes.c_short, ] ODBC_API.SQLConnect.argtypes = [ - ctypes.c_void_p, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, + ctypes.c_void_p, ctypes.c_char_p, ctypes.c_short, + ctypes.c_char_p, ctypes.c_short, ctypes.c_char_p, ctypes.c_short, ] - - ODBC_API.SQLDisconnect.argtypes = [ctypes.c_void_p] - ODBC_API.SQLEndTran.argtypes = [ - ctypes.c_short, - ctypes.c_void_p, - ctypes.c_short, + ctypes.c_short, ctypes.c_void_p, ctypes.c_short, ] ODBC_API.SQLExecute.argtypes = [ctypes.c_void_p] ODBC_API.SQLExecDirect.argtypes = [ - ctypes.c_void_p, - ctypes.c_char_p, - ctypes.c_int, + ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int, ] ODBC_API.SQLFetch.argtypes = [ctypes.c_void_p] ODBC_API.SQLFetchScroll.argtypes = [ - ctypes.c_void_p, - ctypes.c_short, - ctypes.c_ssize_t, + ctypes.c_void_p, ctypes.c_short, ctypes.c_ssize_t, ] ODBC_API.SQLForeignKeys.argtypes = [ - ctypes.c_void_p, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, + ctypes.c_void_p, ctypes.c_char_p, ctypes.c_short, + ctypes.c_char_p, ctypes.c_short, ctypes.c_char_p, + ctypes.c_short, ctypes.c_char_p, ctypes.c_short, + ctypes.c_char_p, ctypes.c_short, ctypes.c_char_p, ctypes.c_short, ] ODBC_API.SQLFreeHandle.argtypes = [ - ctypes.c_short, - ctypes.c_void_p, + ctypes.c_short, ctypes.c_void_p, ] ODBC_API.SQLFreeStmt.argtypes = [ - ctypes.c_void_p, - ctypes.c_ushort, + ctypes.c_void_p, ctypes.c_ushort, ] ODBC_API.SQLGetTypeInfo.argtypes = [ - ctypes.c_void_p, - ctypes.c_short, + ctypes.c_void_p, ctypes.c_short, ] ODBC_API.SQLMoreResults.argtypes = [ctypes.c_void_p] ODBC_API.SQLPrepare.argtypes = [ - ctypes.c_void_p, - ctypes.c_char_p, - ctypes.c_int, + ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int, ] ODBC_API.SQLPrimaryKeys.argtypes = [ - ctypes.c_void_p, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, + ctypes.c_void_p, ctypes.c_char_p, ctypes.c_short, + ctypes.c_char_p, ctypes.c_short, ctypes.c_char_p, ctypes.c_short, ] ODBC_API.SQLProcedureColumns.argtypes = [ - ctypes.c_void_p, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, + ctypes.c_void_p, ctypes.c_char_p, ctypes.c_short, + ctypes.c_char_p, ctypes.c_short, ctypes.c_char_p, + ctypes.c_short, ctypes.c_char_p, ctypes.c_short, ] ODBC_API.SQLProcedures.argtypes = [ - ctypes.c_void_p, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, + ctypes.c_void_p, ctypes.c_char_p, ctypes.c_short, + ctypes.c_char_p, ctypes.c_short, ctypes.c_char_p, ctypes.c_short, ] ODBC_API.SQLSetConnectAttr.argtypes = [ - ctypes.c_void_p, - ctypes.c_int, - ctypes.c_void_p, - ctypes.c_int, + ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_int, ] ODBC_API.SQLSetEnvAttr.argtypes = [ - ctypes.c_void_p, - ctypes.c_int, - ctypes.c_void_p, - ctypes.c_int, + ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_int, ] ODBC_API.SQLStatistics.argtypes = [ - ctypes.c_void_p, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_ushort, - ctypes.c_ushort, + ctypes.c_void_p, ctypes.c_char_p, ctypes.c_short, + ctypes.c_char_p, ctypes.c_short, ctypes.c_char_p, + ctypes.c_short, ctypes.c_ushort, ctypes.c_ushort, ] ODBC_API.SQLTables.argtypes = [ - ctypes.c_void_p, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, - ctypes.c_char_p, - ctypes.c_short, + ctypes.c_void_p, ctypes.c_char_p, ctypes.c_short, + ctypes.c_char_p, ctypes.c_short, ctypes.c_char_p, + ctypes.c_short, ctypes.c_char_p, ctypes.c_short, ] def to_wchar(argtypes): @@ -922,7 +876,7 @@ def to_wchar(argtypes): result = [] for x in argtypes: if x == ctypes.c_char_p: - result.append(wchar_type) + result.append(wchar_pointer) else: result.append(x) return result @@ -934,6 +888,7 @@ ODBC_API.SQLConnectW.argtypes = to_wchar(ODBC_API.SQLConnect.argtypes) ODBC_API.SQLDataSourcesW.argtypes = to_wchar(ODBC_API.SQLDataSources.argtypes) ODBC_API.SQLDescribeColW.argtypes = to_wchar(ODBC_API.SQLDescribeCol.argtypes) ODBC_API.SQLDriverConnectW.argtypes = to_wchar(ODBC_API.SQLDriverConnect.argtypes) +ODBC_API.SQLDriversW.argtypes = to_wchar(ODBC_API.SQLDrivers.argtypes) ODBC_API.SQLExecDirectW.argtypes = to_wchar(ODBC_API.SQLExecDirect.argtypes) ODBC_API.SQLForeignKeysW.argtypes = to_wchar(ODBC_API.SQLForeignKeys.argtypes) ODBC_API.SQLPrepareW.argtypes = to_wchar(ODBC_API.SQLPrepare.argtypes) @@ -942,47 +897,66 @@ ODBC_API.SQLProcedureColumnsW.argtypes = to_wchar(ODBC_API.SQLProcedureColumns.a ODBC_API.SQLProceduresW.argtypes = to_wchar(ODBC_API.SQLProcedures.argtypes) ODBC_API.SQLStatisticsW.argtypes = to_wchar(ODBC_API.SQLStatistics.argtypes) ODBC_API.SQLTablesW.argtypes = to_wchar(ODBC_API.SQLTables.argtypes) - +ODBC_API.SQLGetDiagRecW.argtypes = to_wchar(ODBC_API.SQLGetDiagRec.argtypes) +ODBC_API.SQLGetInfoW.argtypes = to_wchar(ODBC_API.SQLGetInfo.argtypes) # Set the alias for the ctypes functions for beter code readbility or performance. ADDR = ctypes.byref +c_short = ctypes.c_short +c_ssize_t = ctypes.c_ssize_t SQLFetch = ODBC_API.SQLFetch SQLExecute = ODBC_API.SQLExecute SQLBindParameter = ODBC_API.SQLBindParameter +SQLGetData = ODBC_API.SQLGetData +SQLRowCount = ODBC_API.SQLRowCount +SQLNumResultCols = ODBC_API.SQLNumResultCols +SQLEndTran = ODBC_API.SQLEndTran +# Set alias for beter code readbility or performance. +NO_FREE_STATEMENT = 0 +FREE_STATEMENT = 1 +BLANK_BYTE = str_8b() - - - - -def ctrl_err(ht, h, val_ret): +def ctrl_err(ht, h, val_ret, ansi): """Classify type of ODBC error from (type of handle, handle, return value) , and raise with a list""" - state = create_buffer(5) + + if ansi: + state = create_buffer(22) + Message = create_buffer(1024*4) + ODBC_func = ODBC_API.SQLGetDiagRec + if py_v3: + raw_s = lambda s: bytes(s,'ascii') + else: + raw_s = str_8b + else: + state = create_buffer_u(22) + Message = create_buffer_u(1024*4) + ODBC_func = ODBC_API.SQLGetDiagRecW + raw_s = unicode NativeError = ctypes.c_int() - Message = create_buffer(1024*10) - Buffer_len = ctypes.c_short() + Buffer_len = c_short() err_list = [] number_errors = 1 - + while 1: - ret = ODBC_API.SQLGetDiagRec(ht, h, number_errors, state, \ - NativeError, Message, len(Message), ADDR(Buffer_len)) + ret = ODBC_func(ht, h, number_errors, state, \ + ADDR(NativeError), Message, 1024, ADDR(Buffer_len)) if ret == SQL_NO_DATA_FOUND: #No more data, I can raise - if DEBUG: print err_list[0][1] + #print(err_list[0][1]) state = err_list[0][0] - err_text = '['+state+'] '+err_list[0][1] - if state[:2] in ('24','25','42'): + err_text = raw_s('[')+state+raw_s('] ')+err_list[0][1] + if state[:2] in (raw_s('24'),raw_s('25'),raw_s('42')): raise ProgrammingError(state,err_text) - elif state[:2] in ('22'): + elif state[:2] in (raw_s('22')): raise DataError(state,err_text) - elif state[:2] in ('23') or state == '40002': + elif state[:2] in (raw_s('23')) or state == raw_s('40002'): raise IntegrityError(state,err_text) - elif state == '0A000': + elif state == raw_s('0A000'): raise NotSupportedError(state,err_text) - elif state in ('HYT00','HYT01'): + elif state in (raw_s('HYT00'),raw_s('HYT01')): raise OperationalError(state,err_text) - elif state[:2] in ('IM','HY'): + elif state[:2] in (raw_s('IM'),raw_s('HY')): raise Error(state,err_text) else: raise DatabaseError(state,err_text) @@ -991,35 +965,46 @@ def ctrl_err(ht, h, val_ret): #The handle passed is an invalid handle raise ProgrammingError('', 'SQL_INVALID_HANDLE') elif ret == SQL_SUCCESS: - err_list.append((state.value, Message.value, NativeError.value)) + if ansi: + err_list.append((state.value, Message.value, NativeError.value)) + else: + err_list.append((from_buffer_u(state), from_buffer_u(Message), NativeError.value)) number_errors += 1 + elif ret == SQL_ERROR: + raise ProgrammingError('', 'SQL_ERROR') -def validate(ret, handle_type, handle): + +def check_success(ODBC_obj, ret): """ Validate return value, if not success, raise exceptions based on the handle """ if ret not in (SQL_SUCCESS, SQL_SUCCESS_WITH_INFO, SQL_NO_DATA): - ctrl_err(handle_type, handle, ret) - - + if isinstance(ODBC_obj, Cursor): + ctrl_err(SQL_HANDLE_STMT, ODBC_obj.stmt_h, ret, ODBC_obj.ansi) + elif isinstance(ODBC_obj, Connection): + ctrl_err(SQL_HANDLE_DBC, ODBC_obj.dbc_h, ret, ODBC_obj.ansi) + else: + ctrl_err(SQL_HANDLE_ENV, ODBC_obj, ret, False) + + def AllocateEnv(): if pooling: ret = ODBC_API.SQLSetEnvAttr(SQL_NULL_HANDLE, SQL_ATTR_CONNECTION_POOLING, SQL_CP_ONE_PER_HENV, SQL_IS_UINTEGER) - validate(ret, SQL_HANDLE_ENV, SQL_NULL_HANDLE) + check_success(SQL_NULL_HANDLE, ret) - ''' + ''' Allocate an ODBC environment by initializing the handle shared_env_h ODBC enviroment needed to be created, so connections can be created under it connections pooling can be shared under one environment ''' - global shared_env_h + global shared_env_h shared_env_h = ctypes.c_void_p() ret = ODBC_API.SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, ADDR(shared_env_h)) - validate(ret, SQL_HANDLE_ENV, shared_env_h) + check_success(shared_env_h, ret) # Set the ODBC environment's compatibil leve to ODBC 3.0 ret = ODBC_API.SQLSetEnvAttr(shared_env_h, SQL_ATTR_ODBC_VERSION, SQL_OV_ODBC3, 0) - validate(ret, SQL_HANDLE_ENV, shared_env_h) - + check_success(shared_env_h, ret) + """ Here, we have a few callables that determine how a result row is returned. @@ -1036,7 +1021,20 @@ def TupleRow(cursor): """ class Row(tuple): cursor_description = cursor.description - + + def get(self, field): + if not hasattr(self, 'field_dict'): + self.field_dict = {} + for i,item in enumerate(self): + self.field_dict[self.cursor_description[i][0]] = item + return self.field_dict.get(field) + + def __getitem__(self, field): + if isinstance(field, (unicode,str)): + return self.get(field) + else: + return tuple.__getitem__(self,field) + return Row @@ -1088,35 +1086,71 @@ def MutableNamedTupleRow(cursor): return Row +# When Null is used in a binary parameter, database usually would not +# accept the None for a binary field, so the work around is to use a +# Specical None that the pypyodbc moudle would know this NULL is for +# a binary field. +class BinaryNullType(): pass +BinaryNull = BinaryNullType() -# The get_type function is used to determine if parameters need to be re-binded +# The get_type function is used to determine if parameters need to be re-binded # against the changed parameter types -def get_type(v): - t = type(v) - if isinstance(v, str): - if len(v) >= 255: - t = 's' +# 'b' for bool, 'U' for long unicode string, 'u' for short unicode string +# 'S' for long 8 bit string, 's' for short 8 bit string, 'l' for big integer, 'i' for normal integer +# 'f' for float, 'D' for Decimal, 't' for datetime.time, 'd' for datetime.datetime, 'dt' for datetime.datetime +# 'bi' for binary +def get_type(v): + + if isinstance(v, bool): + return ('b',) elif isinstance(v, unicode): if len(v) >= 255: - t = 'u' - elif isinstance(v, Decimal): - sv = str(v).replace('-','').strip('0').split('.') - if len(sv)>1: - t = (len(sv[0])+len(sv[1]),len(sv[1])) + return ('U',(len(v)//1000+1)*1000) else: - t = (len(sv[0]),0) - return t + return ('u',) + elif isinstance(v, (str_8b,str)): + if len(v) >= 255: + return ('S',(len(v)//1000+1)*1000) + else: + return ('s',) + elif isinstance(v, (int, long)): + #SQL_BIGINT defination: http://msdn.microsoft.com/en-us/library/ms187745.aspx + if v > 2147483647 or v < -2147483648: + return ('l',) + else: + return ('i',) + elif isinstance(v, float): + return ('f',) + elif isinstance(v, BinaryNullType): + return ('BN',) + elif v is None: + return ('N',) + elif isinstance(v, Decimal): + t = v.as_tuple() #1.23 -> (1,2,3),-2 , 1.23*E7 -> (1,2,3),5 + return ('D',(len(t[1]),0 - t[2])) # number of digits, and number of decimal digits + + elif isinstance (v, datetime.datetime): + return ('dt',) + elif isinstance (v, datetime.date): + return ('d',) + elif isinstance(v, datetime.time): + return ('t',) + elif isinstance (v, (bytearray, buffer)): + return ('bi',(len(v)//1000+1)*1000) + + return type(v) # The Cursor Class. class Cursor: def __init__(self, conx, row_type_callable=None): - """ Initialize self._stmt_h, which is the handle of a statement + """ Initialize self.stmt_h, which is the handle of a statement A statement is actually the basis of a python"cursor" object """ - self._stmt_h = ctypes.c_void_p() + self.stmt_h = ctypes.c_void_p() self.connection = conx + self.ansi = conx.ansi self.row_type_callable = row_type_callable or TupleRow self.statement = None self._last_param_types = None @@ -1131,69 +1165,347 @@ class Cursor: self._outputsize = {} self._inputsizers = [] self.arraysize = 1 - ret = ODBC_API.SQLAllocHandle(SQL_HANDLE_STMT, self.connection.dbc_h, ADDR(self._stmt_h)) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - self.closed = False + ret = ODBC_API.SQLAllocHandle(SQL_HANDLE_STMT, self.connection.dbc_h, ADDR(self.stmt_h)) + check_success(self, ret) + self._PARAM_SQL_TYPE_LIST = [] + self.closed = False + + + def prepare(self, query_string): + """prepare a query""" + + #self._free_results(FREE_STATEMENT) + + if type(query_string) == unicode: + c_query_string = wchar_pointer(UCS_buf(query_string)) + ret = ODBC_API.SQLPrepareW(self.stmt_h, c_query_string, len(query_string)) + else: + c_query_string = ctypes.c_char_p(query_string) + ret = ODBC_API.SQLPrepare(self.stmt_h, c_query_string, len(query_string)) + if ret != SQL_SUCCESS: + check_success(self, ret) + + + self._PARAM_SQL_TYPE_LIST = [] + + if self.connection.support_SQLDescribeParam: + # SQLServer's SQLDescribeParam only supports DML SQL, so avoid the SELECT statement + if True:# 'SELECT' not in query_string.upper(): + #self._free_results(NO_FREE_STATEMENT) + NumParams = c_short() + ret = ODBC_API.SQLNumParams(self.stmt_h, ADDR(NumParams)) + if ret != SQL_SUCCESS: + check_success(self, ret) + + for col_num in range(NumParams.value): + ParameterNumber = ctypes.c_ushort(col_num + 1) + DataType = c_short() + ParameterSize = ctypes.c_size_t() + DecimalDigits = c_short() + Nullable = c_short() + ret = ODBC_API.SQLDescribeParam( + self.stmt_h, + ParameterNumber, + ADDR(DataType), + ADDR(ParameterSize), + ADDR(DecimalDigits), + ADDR(Nullable), + ) + if ret != SQL_SUCCESS: + try: + check_success(self, ret) + except DatabaseError: + if sys.exc_info()[1].value[0] == '07009': + self._PARAM_SQL_TYPE_LIST = [] + break + else: + raise sys.exc_info()[1] + except: + raise sys.exc_info()[1] + + self._PARAM_SQL_TYPE_LIST.append((DataType.value,DecimalDigits.value)) + + self.statement = query_string + def _BindParams(self, param_types, pram_io_list = []): + """Create parameter buffers based on param types, and bind them to the statement""" + # Clear the old Parameters + #self._free_results(NO_FREE_STATEMENT) + + # Get the number of query parameters judged by database. + NumParams = c_short() + ret = ODBC_API.SQLNumParams(self.stmt_h, ADDR(NumParams)) + if ret != SQL_SUCCESS: + check_success(self, ret) + + if len(param_types) != NumParams.value: + # In case number of parameters provided do not same as number required + error_desc = "The SQL contains %d parameter markers, but %d parameters were supplied" \ + %(NumParams.value,len(param_types)) + raise ProgrammingError('HY000',error_desc) + + + # Every parameter needs to be binded to a buffer + ParamBufferList = [] + # Temporary holder since we can only call SQLDescribeParam before + # calling SQLBindParam. + temp_holder = [] + for col_num in range(NumParams.value): + dec_num = 0 + buf_size = 512 + + if param_types[col_num][0] == 'u': + sql_c_type = SQL_C_WCHAR + sql_type = SQL_WVARCHAR + buf_size = 255 + ParameterBuffer = create_buffer_u(buf_size) + + elif param_types[col_num][0] == 's': + sql_c_type = SQL_C_CHAR + sql_type = SQL_VARCHAR + buf_size = 255 + ParameterBuffer = create_buffer(buf_size) + + + elif param_types[col_num][0] == 'U': + sql_c_type = SQL_C_WCHAR + sql_type = SQL_WLONGVARCHAR + buf_size = param_types[col_num][1]#len(self._inputsizers)>col_num and self._inputsizers[col_num] or 20500 + ParameterBuffer = create_buffer_u(buf_size) + + elif param_types[col_num][0] == 'S': + sql_c_type = SQL_C_CHAR + sql_type = SQL_LONGVARCHAR + buf_size = param_types[col_num][1]#len(self._inputsizers)>col_num and self._inputsizers[col_num] or 20500 + ParameterBuffer = create_buffer(buf_size) + + # bool subclasses int, thus has to go first + elif param_types[col_num][0] == 'b': + sql_c_type = SQL_C_CHAR + sql_type = SQL_BIT + buf_size = SQL_data_type_dict[sql_type][4] + ParameterBuffer = create_buffer(buf_size) + + elif param_types[col_num][0] == 'i': + sql_c_type = SQL_C_CHAR + sql_type = SQL_INTEGER + buf_size = SQL_data_type_dict[sql_type][4] + ParameterBuffer = create_buffer(buf_size) + + elif param_types[col_num][0] == 'l': + sql_c_type = SQL_C_CHAR + sql_type = SQL_BIGINT + buf_size = SQL_data_type_dict[sql_type][4] + ParameterBuffer = create_buffer(buf_size) + + + elif param_types[col_num][0] == 'D': #Decimal + sql_c_type = SQL_C_CHAR + sql_type = SQL_NUMERIC + digit_num, dec_num = param_types[col_num][1] + if dec_num > 0: + # has decimal + buf_size = digit_num + dec_num = dec_num + else: + # no decimal + buf_size = digit_num - dec_num + dec_num = 0 + + ParameterBuffer = create_buffer(buf_size + 4)# add extra length for sign and dot + + + elif param_types[col_num][0] == 'f': + sql_c_type = SQL_C_CHAR + sql_type = SQL_DOUBLE + buf_size = SQL_data_type_dict[sql_type][4] + ParameterBuffer = create_buffer(buf_size) + + + # datetime subclasses date, thus has to go first + elif param_types[col_num][0] == 'dt': + sql_c_type = SQL_C_CHAR + sql_type = SQL_TYPE_TIMESTAMP + buf_size = self.connection.type_size_dic[SQL_TYPE_TIMESTAMP][0] + ParameterBuffer = create_buffer(buf_size) + dec_num = self.connection.type_size_dic[SQL_TYPE_TIMESTAMP][1] + + + elif param_types[col_num][0] == 'd': + sql_c_type = SQL_C_CHAR + if SQL_TYPE_DATE in self.connection.type_size_dic: + #if DEBUG:print('conx.type_size_dic.has_key(SQL_TYPE_DATE)') + sql_type = SQL_TYPE_DATE + buf_size = self.connection.type_size_dic[SQL_TYPE_DATE][0] + + ParameterBuffer = create_buffer(buf_size) + dec_num = self.connection.type_size_dic[SQL_TYPE_DATE][1] + + else: + # SQL Sever <2008 doesn't have a DATE type. + sql_type = SQL_TYPE_TIMESTAMP + buf_size = 10 + ParameterBuffer = create_buffer(buf_size) + + + elif param_types[col_num][0] == 't': + sql_c_type = SQL_C_CHAR + if SQL_TYPE_TIME in self.connection.type_size_dic: + sql_type = SQL_TYPE_TIME + buf_size = self.connection.type_size_dic[SQL_TYPE_TIME][0] + ParameterBuffer = create_buffer(buf_size) + dec_num = self.connection.type_size_dic[SQL_TYPE_TIME][1] + elif SQL_SS_TIME2 in self.connection.type_size_dic: + # TIME type added in SQL Server 2008 + sql_type = SQL_SS_TIME2 + buf_size = self.connection.type_size_dic[SQL_SS_TIME2][0] + ParameterBuffer = create_buffer(buf_size) + dec_num = self.connection.type_size_dic[SQL_SS_TIME2][1] + else: + # SQL Sever <2008 doesn't have a TIME type. + sql_type = SQL_TYPE_TIMESTAMP + buf_size = self.connection.type_size_dic[SQL_TYPE_TIMESTAMP][0] + ParameterBuffer = create_buffer(buf_size) + dec_num = 3 + + elif param_types[col_num][0] == 'BN': + sql_c_type = SQL_C_BINARY + sql_type = SQL_VARBINARY + buf_size = 1 + ParameterBuffer = create_buffer(buf_size) + + elif param_types[col_num][0] == 'N': + if len(self._PARAM_SQL_TYPE_LIST) > 0: + sql_c_type = SQL_C_DEFAULT + sql_type = self._PARAM_SQL_TYPE_LIST[col_num][0] + buf_size = 1 + ParameterBuffer = create_buffer(buf_size) + else: + sql_c_type = SQL_C_CHAR + sql_type = SQL_CHAR + buf_size = 1 + ParameterBuffer = create_buffer(buf_size) + elif param_types[col_num][0] == 'bi': + sql_c_type = SQL_C_BINARY + sql_type = SQL_LONGVARBINARY + buf_size = param_types[col_num][1]#len(self._inputsizers)>col_num and self._inputsizers[col_num] or 20500 + ParameterBuffer = create_buffer(buf_size) + + + else: + sql_c_type = SQL_C_CHAR + sql_type = SQL_LONGVARCHAR + buf_size = len(self._inputsizers)>col_num and self._inputsizers[col_num] or 20500 + ParameterBuffer = create_buffer(buf_size) + + temp_holder.append((sql_c_type, sql_type, buf_size, dec_num, ParameterBuffer)) + + for col_num, (sql_c_type, sql_type, buf_size, dec_num, ParameterBuffer) in enumerate(temp_holder): + BufferLen = c_ssize_t(buf_size) + LenOrIndBuf = c_ssize_t() + + + InputOutputType = SQL_PARAM_INPUT + if len(pram_io_list) > col_num: + InputOutputType = pram_io_list[col_num] + + ret = SQLBindParameter(self.stmt_h, col_num + 1, InputOutputType, sql_c_type, sql_type, buf_size,\ + dec_num, ADDR(ParameterBuffer), BufferLen,ADDR(LenOrIndBuf)) + if ret != SQL_SUCCESS: + check_success(self, ret) + # Append the value buffer and the lenth buffer to the array + ParamBufferList.append((ParameterBuffer,LenOrIndBuf,sql_type)) + + self._last_param_types = param_types + self._ParamBufferList = ParamBufferList + + def execute(self, query_string, params=None, many_mode=False, call_mode=False): """ Execute the query string, with optional parameters. If parameters are provided, the query would first be prepared, then executed with parameters; - If parameters are not provided, only th query sting, it would be executed directly + If parameters are not provided, only th query sting, it would be executed directly """ - - self._free_results('FREE_STATEMENT') - + self._free_stmt(SQL_CLOSE) if params: # If parameters exist, first prepare the query then executed with parameters - if not type(params) in (tuple, list, set): - raise TypeError("Params must be in a list, tuple, or set") + + if not isinstance(params, (tuple, list)): + raise TypeError("Params must be in a list, tuple, or Row") - if not many_mode: - if query_string != self.statement: - # if the query is not same as last query, then it is not prepared - self.prepare(query_string) - - - param_types = map(get_type, params) + + if query_string != self.statement: + # if the query is not same as last query, then it is not prepared + self.prepare(query_string) + + + param_types = list(map(get_type, params)) if call_mode: + self._free_stmt(SQL_RESET_PARAMS) self._BindParams(param_types, self._pram_io_list) else: - if param_types != self._last_param_types: + if self._last_param_types is None: + self._free_stmt(SQL_RESET_PARAMS) self._BindParams(param_types) - - + elif len(param_types) != len(self._last_param_types): + self._free_stmt(SQL_RESET_PARAMS) + self._BindParams(param_types) + elif sum([p_type[0] != 'N' and p_type != self._last_param_types[i] for i,p_type in enumerate(param_types)]) > 0: + self._free_stmt(SQL_RESET_PARAMS) + self._BindParams(param_types) + + # With query prepared, now put parameters into buffers col_num = 0 for param_buffer, param_buffer_len, sql_type in self._ParamBufferList: c_char_buf, c_buf_len = '', 0 param_val = params[col_num] - if param_val is None: - c_buf_len = SQL_NULL_DATA - - elif isinstance(param_val, datetime.datetime): + if param_types[col_num][0] in ('N','BN'): + param_buffer_len.value = SQL_NULL_DATA + col_num += 1 + continue + elif param_types[col_num][0] in ('i','l','f'): + if py_v3: + c_char_buf = bytes(str(param_val),'ascii') + else: + c_char_buf = str(param_val) + c_buf_len = len(c_char_buf) + + elif param_types[col_num][0] in ('s','S'): + c_char_buf = param_val + c_buf_len = len(c_char_buf) + elif param_types[col_num][0] in ('u','U'): + c_char_buf = UCS_buf(param_val) + c_buf_len = len(c_char_buf) + + elif param_types[col_num][0] == 'dt': max_len = self.connection.type_size_dic[SQL_TYPE_TIMESTAMP][0] datetime_str = param_val.strftime('%Y-%m-%d %H:%M:%S.%f') c_char_buf = datetime_str[:max_len] + if py_v3: + c_char_buf = bytes(c_char_buf,'ascii') + c_buf_len = len(c_char_buf) # print c_buf_len, c_char_buf - - elif isinstance(param_val, datetime.date): - if self.connection.type_size_dic.has_key(SQL_TYPE_DATE): + + elif param_types[col_num][0] == 'd': + if SQL_TYPE_DATE in self.connection.type_size_dic: max_len = self.connection.type_size_dic[SQL_TYPE_DATE][0] else: max_len = 10 c_char_buf = param_val.isoformat()[:max_len] + if py_v3: + c_char_buf = bytes(c_char_buf,'ascii') c_buf_len = len(c_char_buf) #print c_char_buf - - elif isinstance(param_val, datetime.time): - if self.connection.type_size_dic.has_key(SQL_TYPE_TIME): + + elif param_types[col_num][0] == 't': + if SQL_TYPE_TIME in self.connection.type_size_dic: max_len = self.connection.type_size_dic[SQL_TYPE_TIME][0] c_char_buf = param_val.isoformat()[:max_len] c_buf_len = len(c_char_buf) - elif self.connection.type_size_dic.has_key(SQL_SS_TIME2): + elif SQL_SS_TIME2 in self.connection.type_size_dic: max_len = self.connection.type_size_dic[SQL_SS_TIME2][0] c_char_buf = param_val.isoformat()[:max_len] c_buf_len = len(c_char_buf) @@ -1203,114 +1515,116 @@ class Cursor: if len(time_str) == 8: time_str += '.000' c_char_buf = '1900-01-01 '+time_str[0:c_buf_len - 11] + if py_v3: + c_char_buf = bytes(c_char_buf,'ascii') #print c_buf_len, c_char_buf - - elif isinstance(param_val, bool): + + elif param_types[col_num][0] == 'b': if param_val == True: c_char_buf = '1' else: c_char_buf = '0' + if py_v3: + c_char_buf = bytes(c_char_buf,'ascii') c_buf_len = 1 + + elif param_types[col_num][0] == 'D': #Decimal + sign = param_val.as_tuple()[0] == 0 and '+' or '-' + digit_string = ''.join([str(x) for x in param_val.as_tuple()[1]]) + digit_num, dec_num = param_types[col_num][1] + if dec_num > 0: + # has decimal + left_part = digit_string[:digit_num - dec_num] + right_part = digit_string[0-dec_num:] + else: + # no decimal + left_part = digit_string + '0'*(0-dec_num) + right_part = '' + v = ''.join((sign, left_part,'.', right_part)) - elif isinstance(param_val, (int, long, float, Decimal)): + if py_v3: + c_char_buf = bytes(v,'ascii') + else: + c_char_buf = v + c_buf_len = len(c_char_buf) + + elif param_types[col_num][0] == 'bi': c_char_buf = str(param_val) c_buf_len = len(c_char_buf) - - elif isinstance(param_val, str): - c_char_buf = param_val - c_buf_len = len(c_char_buf) - elif isinstance(param_val, unicode): - c_char_buf = to_unicode(param_val) - c_buf_len = len(c_char_buf) - elif isinstance(param_val, (bytearray, buffer)): - c_char_buf = str(param_val) - c_buf_len = len(c_char_buf) - + else: c_char_buf = param_val - - - if isinstance(param_val, (bytearray, buffer)): - param_buffer.raw = c_char_buf - + + + if param_types[col_num][0] == 'bi': + param_buffer.raw = str_8b(param_val) + else: + #print (type(param_val),param_buffer, param_buffer.value) param_buffer.value = c_char_buf - #print param_buffer, param_buffer.value - if isinstance(param_val, (unicode, str)): + if param_types[col_num][0] in ('U','u','S','s'): #ODBC driver will find NUL in unicode and string to determine their length param_buffer_len.value = SQL_NTS else: param_buffer_len.value = c_buf_len - + col_num += 1 - ret = SQLExecute(self._stmt_h) + ret = SQLExecute(self.stmt_h) if ret != SQL_SUCCESS: - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - + #print param_valparam_buffer, param_buffer.value + check_success(self, ret) + if not many_mode: self._NumOfRows() self._UpdateDesc() #self._BindCols() - + else: self.execdirect(query_string) - return (self) - - + return self + + def _SQLExecute(self): - ret = SQLExecute(self._stmt_h) + ret = SQLExecute(self.stmt_h) if ret != SQL_SUCCESS: - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - - - - - def prepare(self, query_string): - """prepare a query""" - if type(query_string) == unicode: - c_query_string = wchar_type(to_unicode(query_string)) - ret = ODBC_API.SQLPrepareW(self._stmt_h, c_query_string, len(query_string)) - else: - c_query_string = ctypes.c_char_p(query_string) - ret = ODBC_API.SQLPrepare(self._stmt_h, c_query_string, len(query_string)) - if ret != SQL_SUCCESS: - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - self.statement = query_string - - + check_success(self, ret) + + def execdirect(self, query_string): """Execute a query directly""" + self._free_stmt() + self._last_param_types = None + self.statement = None if type(query_string) == unicode: - c_query_string = wchar_type(to_unicode(query_string)) - ret = ODBC_API.SQLExecDirectW(self._stmt_h, c_query_string, len(query_string)) + c_query_string = wchar_pointer(UCS_buf(query_string)) + ret = ODBC_API.SQLExecDirectW(self.stmt_h, c_query_string, len(query_string)) else: c_query_string = ctypes.c_char_p(query_string) - ret = ODBC_API.SQLExecDirect(self._stmt_h, c_query_string, len(query_string)) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) + ret = ODBC_API.SQLExecDirect(self.stmt_h, c_query_string, len(query_string)) + check_success(self, ret) self._NumOfRows() self._UpdateDesc() #self._BindCols() - self.statement = None - return (self) - - + return self + + def callproc(self, procname, args): raise Warning('', 'Still not fully implemented') self._pram_io_list = [row[4] for row in self.procedurecolumns(procedure = procname).fetchall() if row[4] not in (SQL_RESULT_COL, SQL_RETURN_VALUE)] + + print('pram_io_list: '+str(self._pram_io_list)) - print 'pram_io_list: '+str(self._pram_io_list) - - - + + call_escape = '{CALL '+procname if args: call_escape += '(' + ','.join(['?' for params in args]) + ')' call_escape += '}' self.execute(call_escape, args, call_mode = True) - + result = [] for buf, buf_len, sql_type in self._ParamBufferList: @@ -1318,348 +1632,127 @@ class Cursor: result.append(None) else: result.append(self.connection.output_converter[sql_type](buf.value)) - return (result) - - + return result + + def executemany(self, query_string, params_list = [None]): - self.prepare(query_string) for params in params_list: self.execute(query_string, params, many_mode = True) self._NumOfRows() self.rowcount = -1 self._UpdateDesc() #self._BindCols() - - - def _BindParams(self, param_types, pram_io_list = []): - """Create parameter buffers based on param types, and bind them to the statement""" - # Get the number of query parameters judged by database. - NumParams = ctypes.c_short() - ret = ODBC_API.SQLNumParams(self._stmt_h, ADDR(NumParams)) - if ret != SQL_SUCCESS: - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - - if len(param_types) != NumParams.value: - # In case number of parameters provided do not same as number required - error_desc = "The SQL contains %d parameter markers, but %d parameters were supplied" \ - %(NumParams.value,len(param_types)) - raise ProgrammingError('HY000',error_desc) - - - # Every parameter needs to be binded to a buffer - ParamBufferList = [] - # Temporary holder since we can only call SQLDescribeParam before - # calling SQLBindParam. - temp_holder = [] - for col_num in range(NumParams.value): - col_size = 0 - buf_size = 512 - - if param_types[col_num] == type(None): - ParameterNumber = ctypes.c_ushort(col_num + 1) - DataType = ctypes.c_short() - ParameterSize = ctypes.c_size_t() - DecimalDigits = ctypes.c_short() - Nullable = ctypes.c_short() - ret = ODBC_API.SQLDescribeParam( - self._stmt_h, - ParameterNumber, - ADDR(DataType), - ADDR(ParameterSize), - ADDR(DecimalDigits), - ADDR(Nullable), - ) - if ret != SQL_SUCCESS: - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - - sql_c_type = SQL_C_DEFAULT - sql_type = DataType.value - buf_size = 1 - ParameterBuffer = create_buffer(buf_size) - - elif param_types[col_num] == 'u': - sql_c_type = SQL_C_WCHAR - sql_type = SQL_WLONGVARCHAR - buf_size = len(self._inputsizers)>col_num and self._inputsizers[col_num] or 20500 - ParameterBuffer = create_buffer_u(buf_size) - - elif param_types[col_num] == 's': - sql_c_type = SQL_C_CHAR - sql_type = SQL_LONGVARCHAR - buf_size = len(self._inputsizers)>col_num and self._inputsizers[col_num] or 20500 - ParameterBuffer = create_buffer(buf_size) - - elif type(param_types[col_num]) == tuple: #Decimal - sql_c_type = SQL_C_CHAR - sql_type = SQL_NUMERIC - buf_size = param_types[col_num][0] - - ParameterBuffer = create_buffer(buf_size+4) - col_size = param_types[col_num][1] - if DEBUG: print param_types[col_num][0],param_types[col_num][1] - - # bool subclasses int, thus has to go first - elif issubclass(param_types[col_num], bool): - sql_c_type = SQL_C_CHAR - sql_type = SQL_BIT - buf_size = SQL_data_type_dict[sql_type][4] - ParameterBuffer = create_buffer(buf_size) - - elif issubclass(param_types[col_num], int): - sql_c_type = SQL_C_CHAR - sql_type = SQL_INTEGER - buf_size = SQL_data_type_dict[sql_type][4] - ParameterBuffer = create_buffer(buf_size) - - elif issubclass(param_types[col_num], long): - sql_c_type = SQL_C_CHAR - sql_type = SQL_BIGINT - buf_size = SQL_data_type_dict[sql_type][4] - ParameterBuffer = create_buffer(buf_size) - - - elif issubclass(param_types[col_num], float): - sql_c_type = SQL_C_CHAR - sql_type = SQL_DOUBLE - buf_size = SQL_data_type_dict[sql_type][4] - ParameterBuffer = create_buffer(buf_size) - - - # datetime subclasses date, thus has to go first - elif issubclass(param_types[col_num], datetime.datetime): - sql_c_type = SQL_C_CHAR - sql_type = SQL_TYPE_TIMESTAMP - buf_size = self.connection.type_size_dic[SQL_TYPE_TIMESTAMP][0] - ParameterBuffer = create_buffer(buf_size) - col_size = self.connection.type_size_dic[SQL_TYPE_TIMESTAMP][1] - - - elif issubclass(param_types[col_num], datetime.date): - sql_c_type = SQL_C_CHAR - if self.connection.type_size_dic.has_key(SQL_TYPE_DATE): - if DEBUG: print 'conx.type_size_dic.has_key(SQL_TYPE_DATE)' - sql_type = SQL_TYPE_DATE - buf_size = self.connection.type_size_dic[SQL_TYPE_DATE][0] - - ParameterBuffer = create_buffer(buf_size) - col_size = self.connection.type_size_dic[SQL_TYPE_DATE][1] - - else: - # SQL Sever <2008 doesn't have a DATE type. - sql_type = SQL_TYPE_TIMESTAMP - buf_size = 10 - ParameterBuffer = create_buffer(buf_size) - - - elif issubclass(param_types[col_num], datetime.time): - sql_c_type = SQL_C_CHAR - if self.connection.type_size_dic.has_key(SQL_TYPE_TIME): - sql_type = SQL_TYPE_TIME - buf_size = self.connection.type_size_dic[SQL_TYPE_TIME][0] - ParameterBuffer = create_buffer(buf_size) - col_size = self.connection.type_size_dic[SQL_TYPE_TIME][1] - elif self.connection.type_size_dic.has_key(SQL_SS_TIME2): - # TIME type added in SQL Server 2008 - sql_type = SQL_SS_TIME2 - buf_size = self.connection.type_size_dic[SQL_SS_TIME2][0] - ParameterBuffer = create_buffer(buf_size) - col_size = self.connection.type_size_dic[SQL_SS_TIME2][1] - else: - # SQL Sever <2008 doesn't have a TIME type. - sql_type = SQL_TYPE_TIMESTAMP - buf_size = self.connection.type_size_dic[SQL_TYPE_TIMESTAMP][0] - ParameterBuffer = create_buffer(buf_size) - col_size = 3 - - elif issubclass(param_types[col_num], unicode): - sql_c_type = SQL_C_WCHAR - sql_type = SQL_WVARCHAR - buf_size = 255 - ParameterBuffer = create_buffer_u(buf_size) - - elif issubclass(param_types[col_num], str): - sql_c_type = SQL_C_CHAR - sql_type = SQL_VARCHAR - buf_size = 255 - ParameterBuffer = create_buffer(buf_size) - - elif issubclass(param_types[col_num], (bytearray, buffer)): - sql_c_type = SQL_C_BINARY - sql_type = SQL_LONGVARBINARY - buf_size = len(self._inputsizers)>col_num and self._inputsizers[col_num] or 20500 - ParameterBuffer = create_buffer(buf_size) - - - else: - sql_c_type = SQL_C_CHAR - sql_type = SQL_LONGVARCHAR - buf_size = len(self._inputsizers)>col_num and self._inputsizers[col_num] or 20500 - ParameterBuffer = create_buffer(buf_size) - - temp_holder.append((sql_c_type, sql_type, buf_size, col_size, ParameterBuffer)) - - for col_num, (sql_c_type, sql_type, buf_size, col_size, ParameterBuffer) in enumerate(temp_holder): - BufferLen = ctypes.c_ssize_t(buf_size) - LenOrIndBuf = ctypes.c_ssize_t() - - - InputOutputType = SQL_PARAM_INPUT - if len(pram_io_list) > col_num: - InputOutputType = pram_io_list[col_num] - - ret = SQLBindParameter(self._stmt_h, col_num + 1, InputOutputType, sql_c_type, sql_type, buf_size,\ - col_size, ADDR(ParameterBuffer), BufferLen,ADDR(LenOrIndBuf)) - if ret != SQL_SUCCESS: - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - # Append the value buffer and the lenth buffer to the array - ParamBufferList.append((ParameterBuffer,LenOrIndBuf,sql_type)) - - self._last_param_types = param_types - self._ParamBufferList = ParamBufferList - - + + def _CreateColBuf(self): + self._free_stmt(SQL_UNBIND) NOC = self._NumOfCols() self._ColBufferList = [] - self._row_type = None for col_num in range(NOC): - col_name = self.description[col_num][0] - - col_sql_data_type = self._ColTypeCodeList[col_num] + col_name = self.description[col_num][0] + + col_sql_data_type = self._ColTypeCodeList[col_num] # set default size base on the column's sql data type - total_buf_len = SQL_data_type_dict[col_sql_data_type][4] + total_buf_len = SQL_data_type_dict[col_sql_data_type][4] # over-write if there's preset size value for "large columns" - if total_buf_len >= 20500: + if total_buf_len >= 20500: total_buf_len = self._outputsize.get(None,total_buf_len) - # over-write if there's preset size value for the "col_num" column + # over-write if there's preset size value for the "col_num" column total_buf_len = self._outputsize.get(col_num, total_buf_len) alloc_buffer = SQL_data_type_dict[col_sql_data_type][3](total_buf_len) - used_buf_len = ctypes.c_ssize_t() - + used_buf_len = c_ssize_t() + target_type = SQL_data_type_dict[col_sql_data_type][2] force_unicode = self.connection.unicode_results - + if force_unicode and col_sql_data_type in (SQL_CHAR,SQL_VARCHAR,SQL_LONGVARCHAR): target_type = SQL_C_WCHAR alloc_buffer = create_buffer_u(total_buf_len) - + buf_cvt_func = self.connection.output_converter[self._ColTypeCodeList[col_num]] - - self._ColBufferList.append([col_name, target_type, used_buf_len, alloc_buffer, total_buf_len, buf_cvt_func]) - - - def _GetData(self): - '''Bind buffers for the record set columns''' - - # Lazily create the row type on first fetch. - if self._row_type is None: - self._row_type = self.row_type_callable(self) - - value_list = [] - col_num = 0 - for col_name, target_type, used_buf_len, alloc_buffer, total_buf_len, buf_cvt_func in self._ColBufferList: - - blocks = [] - while True: - ret = ODBC_API.SQLGetData(self._stmt_h, col_num + 1, target_type, ADDR(alloc_buffer), total_buf_len,\ - ADDR(used_buf_len)) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - - if ret == SQL_SUCCESS: - if used_buf_len.value == SQL_NULL_DATA: - blocks.append(None) - else: - if target_type == SQL_C_BINARY: - blocks.append(alloc_buffer.raw[:used_buf_len.value]) - elif target_type == SQL_C_WCHAR: - blocks.append(from_buffer_u(alloc_buffer)) - else: - #print col_name, target_type, alloc_buffer.value - blocks.append(alloc_buffer.value) - - break - - if ret == SQL_SUCCESS_WITH_INFO: - if target_type == SQL_C_BINARY: - blocks.append(alloc_buffer.raw) - else: - blocks.append(alloc_buffer.value) - - if ret == SQL_NO_DATA: - break - - - if len(blocks) == 1: - raw_value = blocks[0] - else: - raw_value = ''.join(blocks) - - if raw_value == None: - value_list.append(None) - else: - value_list.append(buf_cvt_func(raw_value)) - col_num += 1 - - return self._row_type(value_list) - - + ADDR(alloc_buffer) + ADDR(used_buf_len) + self._ColBufferList.append([col_name, target_type, used_buf_len, ADDR(used_buf_len), alloc_buffer, ADDR(alloc_buffer), total_buf_len, buf_cvt_func]) + + + def _UpdateDesc(self): - "Get the information of (name, type_code, display_size, internal_size, col_precision, scale, null_ok)" - Cname = create_buffer(1024) - Cname_ptr = ctypes.c_short() - Ctype_code = ctypes.c_short() + "Get the information of (name, type_code, display_size, internal_size, col_precision, scale, null_ok)" + force_unicode = self.connection.unicode_results + if force_unicode: + Cname = create_buffer_u(1024) + else: + Cname = create_buffer(1024) + + Cname_ptr = c_short() + Ctype_code = c_short() Csize = ctypes.c_size_t() - Cdisp_size = ctypes.c_ssize_t(0) - CDecimalDigits = ctypes.c_short() - Cnull_ok = ctypes.c_short() + Cdisp_size = c_ssize_t(0) + CDecimalDigits = c_short() + Cnull_ok = c_short() ColDescr = [] self._ColTypeCodeList = [] NOC = self._NumOfCols() for col in range(1, NOC+1): - ret = ODBC_API.SQLColAttribute(self._stmt_h, col, SQL_DESC_DISPLAY_SIZE, ADDR(create_buffer(10)), - 10, ADDR(ctypes.c_short()),ADDR(Cdisp_size)) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - - ret = ODBC_API.SQLDescribeCol(self._stmt_h, col, Cname, len(Cname), ADDR(Cname_ptr),\ - ADDR(Ctype_code),ADDR(Csize),ADDR(CDecimalDigits), ADDR(Cnull_ok)) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - + + ret = ODBC_API.SQLColAttribute(self.stmt_h, col, SQL_DESC_DISPLAY_SIZE, ADDR(create_buffer(10)), + 10, ADDR(c_short()),ADDR(Cdisp_size)) + if ret != SQL_SUCCESS: + check_success(self, ret) + + if force_unicode: + + ret = ODBC_API.SQLDescribeColW(self.stmt_h, col, Cname, len(Cname), ADDR(Cname_ptr),\ + ADDR(Ctype_code),ADDR(Csize),ADDR(CDecimalDigits), ADDR(Cnull_ok)) + if ret != SQL_SUCCESS: + check_success(self, ret) + else: + + ret = ODBC_API.SQLDescribeCol(self.stmt_h, col, Cname, len(Cname), ADDR(Cname_ptr),\ + ADDR(Ctype_code),ADDR(Csize),ADDR(CDecimalDigits), ADDR(Cnull_ok)) + if ret != SQL_SUCCESS: + check_success(self, ret) + col_name = Cname.value if lowercase: - col_name = str.lower(col_name) - #(name, type_code, display_size, - # internal_size, col_precision, scale, null_ok) - ColDescr.append((col_name, SQL_data_type_dict.get(Ctype_code.value,(Ctype_code.value))[0],Cdisp_size.value,\ + col_name = col_name.lower() + #(name, type_code, display_size, + + ColDescr.append((col_name, SQL_data_type_dict.get(Ctype_code.value,(Ctype_code.value,))[0],Cdisp_size.value,\ Csize.value, Csize.value,CDecimalDigits.value,Cnull_ok.value == 1 and True or False)) self._ColTypeCodeList.append(Ctype_code.value) - + if len(ColDescr) > 0: self.description = ColDescr + # Create the row type before fetching. + self._row_type = self.row_type_callable(self) else: self.description = None self._CreateColBuf() - - + + def _NumOfRows(self): """Get the number of rows""" - NOR = ctypes.c_ssize_t() - ret = ODBC_API.SQLRowCount(self._stmt_h, ADDR(NOR)) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) + NOR = c_ssize_t() + ret = SQLRowCount(self.stmt_h, ADDR(NOR)) + if ret != SQL_SUCCESS: + check_success(self, ret) self.rowcount = NOR.value - return self.rowcount - + return self.rowcount + def _NumOfCols(self): """Get the number of cols""" - NOC = ctypes.c_short() - ret = ODBC_API.SQLNumResultCols(self._stmt_h, ADDR(NOC)) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) + NOC = c_short() + ret = SQLNumResultCols(self.stmt_h, ADDR(NOC)) + if ret != SQL_SUCCESS: + check_success(self, ret) return NOC.value @@ -1667,308 +1760,436 @@ class Cursor: rows = [] while True: row = self.fetchone() - if row == None: + if row is None: break rows.append(row) return rows def fetchmany(self, num = None): - if num == None: + if num is None: num = self.arraysize - rows, row_num = [], 0 - - while row_num < num: + rows = [] + + while len(rows) < num: row = self.fetchone() - if row == None: + if row is None: break rows.append(row) - row_num += 1 return rows def fetchone(self): - ret = SQLFetch(self._stmt_h) - if ret == SQL_SUCCESS: - return self._GetData() + ret = SQLFetch(self.stmt_h) + if ret == SQL_SUCCESS: + '''Bind buffers for the record set columns''' + + value_list = [] + col_num = 1 + for col_name, target_type, used_buf_len, ADDR_used_buf_len, alloc_buffer, ADDR_alloc_buffer, total_buf_len, buf_cvt_func in self._ColBufferList: + + blocks = [] + while 1: + ret = SQLGetData(self.stmt_h, col_num, target_type, ADDR_alloc_buffer, total_buf_len, ADDR_used_buf_len) + if ret == SQL_SUCCESS: + if used_buf_len.value == SQL_NULL_DATA: + value_list.append(None) + else: + if blocks == []: + if target_type == SQL_C_BINARY: + value_list.append(buf_cvt_func(alloc_buffer.raw[:used_buf_len.value])) + elif target_type == SQL_C_WCHAR: + value_list.append(buf_cvt_func(from_buffer_u(alloc_buffer))) + else: + #print col_name, target_type, alloc_buffer.value + value_list.append(buf_cvt_func(alloc_buffer.value)) + else: + if target_type == SQL_C_BINARY: + blocks.append(alloc_buffer.raw[:used_buf_len.value]) + elif target_type == SQL_C_WCHAR: + blocks.append(from_buffer_u(alloc_buffer)) + else: + #print col_name, target_type, alloc_buffer.value + blocks.append(alloc_buffer.value) + break + + elif ret == SQL_SUCCESS_WITH_INFO: + if target_type == SQL_C_BINARY: + blocks.append(alloc_buffer.raw) + else: + blocks.append(alloc_buffer.value) + + elif ret == SQL_NO_DATA: + break + else: + check_success(self, ret) + + if blocks != []: + if py_v3: + if target_type != SQL_C_BINARY: + raw_value = ''.join(blocks) + else: + raw_value = BLANK_BYTE.join(blocks) + else: + raw_value = ''.join(blocks) + + value_list.append(buf_cvt_func(raw_value)) + col_num += 1 + + return self._row_type(value_list) + else: if ret == SQL_NO_DATA_FOUND: return None else: - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - + check_success(self, ret) + + def __next__(self): + self.next() + def next(self): row = self.fetchone() - if row == None: + if row is None: raise(StopIteration) return row - + def __iter__(self): return self - + def skip(self, count = 0): - for i in xrange(count): - ret = ODBC_API.SQLFetchScroll(self._stmt_h, SQL_FETCH_NEXT, 0) + for i in range(count): + ret = ODBC_API.SQLFetchScroll(self.stmt_h, SQL_FETCH_NEXT, 0) if ret != SQL_SUCCESS: - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - return None - - - + check_success(self, ret) + return None + + + def nextset(self): - ret = ODBC_API.SQLMoreResults(self._stmt_h) + ret = ODBC_API.SQLMoreResults(self.stmt_h) if ret not in (SQL_SUCCESS, SQL_NO_DATA): - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - + check_success(self, ret) + if ret == SQL_NO_DATA: - self._free_results('FREE_STATEMENT') + self._free_stmt() return False else: self._NumOfRows() self._UpdateDesc() #self._BindCols() return True - - - def _free_results(self, free_statement): + + + def _free_stmt(self, free_type = None): if not self.connection.connected: raise ProgrammingError('HY000','Attempt to use a closed connection.') - - self.description = None - if free_statement == 'FREE_STATEMENT': - ret = ODBC_API.SQLFreeStmt(self._stmt_h, SQL_CLOSE) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - else: - ret = ODBC_API.SQLFreeStmt(self._stmt_h, SQL_UNBIND) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - - ret = ODBC_API.SQLFreeStmt(self._stmt_h, SQL_RESET_PARAMS) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - - self.rowcount = -1 - - - + + #self.description = None + #self.rowcount = -1 + if free_type in (SQL_CLOSE, None): + ret = ODBC_API.SQLFreeStmt(self.stmt_h, SQL_CLOSE) + if ret != SQL_SUCCESS: + check_success(self, ret) + if free_type in (SQL_UNBIND, None): + ret = ODBC_API.SQLFreeStmt(self.stmt_h, SQL_UNBIND) + if ret != SQL_SUCCESS: + check_success(self, ret) + if free_type in (SQL_RESET_PARAMS, None): + ret = ODBC_API.SQLFreeStmt(self.stmt_h, SQL_RESET_PARAMS) + if ret != SQL_SUCCESS: + check_success(self, ret) + + + def getTypeInfo(self, sqlType = None): - if sqlType == None: + if sqlType is None: type = SQL_ALL_TYPES else: type = sqlType - ret = ODBC_API.SQLGetTypeInfo(self._stmt_h, type) + ret = ODBC_API.SQLGetTypeInfo(self.stmt_h, type) if ret in (SQL_SUCCESS, SQL_SUCCESS_WITH_INFO): self._NumOfRows() self._UpdateDesc() #self._BindCols() return self.fetchone() - - + + def tables(self, table=None, catalog=None, schema=None, tableType=None): - """Return a list with all tables""" + """Return a list with all tables""" l_catalog = l_schema = l_table = l_tableType = 0 - - if catalog != None: + + if unicode in [type(x) for x in (table, catalog, schema,tableType)]: + string_p = lambda x:wchar_pointer(UCS_buf(x)) + API_f = ODBC_API.SQLTablesW + else: + string_p = ctypes.c_char_p + API_f = ODBC_API.SQLTables + + + + if catalog is not None: l_catalog = len(catalog) - catalog = ctypes.c_char_p(catalog) + catalog = string_p(catalog) - if schema != None: + if schema is not None: l_schema = len(schema) - schema = ctypes.c_char_p(schema) - - if table != None: + schema = string_p(schema) + + if table is not None: l_table = len(table) - table = ctypes.c_char_p(table) - - if tableType != None: + table = string_p(table) + + if tableType is not None: l_tableType = len(tableType) - tableType = ctypes.c_char_p(tableType) - - self._free_results('FREE_STATEMENT') + tableType = string_p(tableType) + + self._free_stmt() + self._last_param_types = None self.statement = None - ret = ODBC_API.SQLTables(self._stmt_h, + ret = API_f(self.stmt_h, catalog, l_catalog, - schema, l_schema, + schema, l_schema, table, l_table, tableType, l_tableType) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - + check_success(self, ret) + self._NumOfRows() self._UpdateDesc() #self._BindCols() - return (self) - - + return self + + def columns(self, table=None, catalog=None, schema=None, column=None): - """Return a list with all columns""" + """Return a list with all columns""" l_catalog = l_schema = l_table = l_column = 0 - if catalog != None: + + if unicode in [type(x) for x in (table, catalog, schema,column)]: + string_p = lambda x:wchar_pointer(UCS_buf(x)) + API_f = ODBC_API.SQLColumnsW + else: + string_p = ctypes.c_char_p + API_f = ODBC_API.SQLColumns + + + + if catalog is not None: l_catalog = len(catalog) - catalog = ctypes.c_char_p(catalog) - if schema != None: + catalog = string_p(catalog) + if schema is not None: l_schema = len(schema) - schema = ctypes.c_char_p(schema) - if table != None: + schema = string_p(schema) + if table is not None: l_table = len(table) - table = ctypes.c_char_p(table) - if column != None: + table = string_p(table) + if column is not None: l_column = len(column) - column = ctypes.c_char_p(column) - - self._free_results('FREE_STATEMENT') + column = string_p(column) + + self._free_stmt() + self._last_param_types = None self.statement = None - - ret = ODBC_API.SQLColumns(self._stmt_h, - catalog, l_catalog, - schema, l_schema, - table, l_table, - column, l_column) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) + + ret = API_f(self.stmt_h, + catalog, l_catalog, + schema, l_schema, + table, l_table, + column, l_column) + check_success(self, ret) self._NumOfRows() self._UpdateDesc() #self._BindCols() - return (self) - - + return self + + def primaryKeys(self, table=None, catalog=None, schema=None): l_catalog = l_schema = l_table = 0 - if catalog != None: + + if unicode in [type(x) for x in (table, catalog, schema)]: + string_p = lambda x:wchar_pointer(UCS_buf(x)) + API_f = ODBC_API.SQLPrimaryKeysW + else: + string_p = ctypes.c_char_p + API_f = ODBC_API.SQLPrimaryKeys + + + + if catalog is not None: l_catalog = len(catalog) - catalog = ctypes.c_char_p(catalog) - - if schema != None: + catalog = string_p(catalog) + + if schema is not None: l_schema = len(schema) - schema = ctypes.c_char_p(schema) - - if table != None: + schema = string_p(schema) + + if table is not None: l_table = len(table) - table = ctypes.c_char_p(table) - - self._free_results('FREE_STATEMENT') + table = string_p(table) + + self._free_stmt() + self._last_param_types = None self.statement = None - - ret = ODBC_API.SQLPrimaryKeys(self._stmt_h, - catalog, l_catalog, - schema, l_schema, - table, l_table) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - + + ret = API_f(self.stmt_h, + catalog, l_catalog, + schema, l_schema, + table, l_table) + check_success(self, ret) + self._NumOfRows() self._UpdateDesc() #self._BindCols() - return (self) - - + return self + + def foreignKeys(self, table=None, catalog=None, schema=None, foreignTable=None, foreignCatalog=None, foreignSchema=None): l_catalog = l_schema = l_table = l_foreignTable = l_foreignCatalog = l_foreignSchema = 0 - if catalog != None: + + if unicode in [type(x) for x in (table, catalog, schema,foreignTable,foreignCatalog,foreignSchema)]: + string_p = lambda x:wchar_pointer(UCS_buf(x)) + API_f = ODBC_API.SQLForeignKeysW + else: + string_p = ctypes.c_char_p + API_f = ODBC_API.SQLForeignKeys + + if catalog is not None: l_catalog = len(catalog) - catalog = ctypes.c_char_p(catalog) - if schema != None: + catalog = string_p(catalog) + if schema is not None: l_schema = len(schema) - schema = ctypes.c_char_p(schema) - if table != None: + schema = string_p(schema) + if table is not None: l_table = len(table) - table = ctypes.c_char_p(table) - if foreignTable != None: + table = string_p(table) + if foreignTable is not None: l_foreignTable = len(foreignTable) - foreignTable = ctypes.c_char_p(foreignTable) - if foreignCatalog != None: + foreignTable = string_p(foreignTable) + if foreignCatalog is not None: l_foreignCatalog = len(foreignCatalog) - foreignCatalog = ctypes.c_char_p(foreignCatalog) - if foreignSchema != None: + foreignCatalog = string_p(foreignCatalog) + if foreignSchema is not None: l_foreignSchema = len(foreignSchema) - foreignSchema = ctypes.c_char_p(foreignSchema) - - self._free_results('FREE_STATEMENT') + foreignSchema = string_p(foreignSchema) + + self._free_stmt() + self._last_param_types = None self.statement = None - - ret = ODBC_API.SQLForeignKeys(self._stmt_h, - catalog, l_catalog, - schema, l_schema, - table, l_table, - foreignCatalog, l_foreignCatalog, - foreignSchema, l_foreignSchema, - foreignTable, l_foreignTable) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - + + ret = API_f(self.stmt_h, + catalog, l_catalog, + schema, l_schema, + table, l_table, + foreignCatalog, l_foreignCatalog, + foreignSchema, l_foreignSchema, + foreignTable, l_foreignTable) + check_success(self, ret) + self._NumOfRows() self._UpdateDesc() #self._BindCols() - return (self) - - + return self + + def procedurecolumns(self, procedure=None, catalog=None, schema=None, column=None): l_catalog = l_schema = l_procedure = l_column = 0 - if catalog != None: + if unicode in [type(x) for x in (procedure, catalog, schema,column)]: + string_p = lambda x:wchar_pointer(UCS_buf(x)) + API_f = ODBC_API.SQLProcedureColumnsW + else: + string_p = ctypes.c_char_p + API_f = ODBC_API.SQLProcedureColumns + + + if catalog is not None: l_catalog = len(catalog) - catalog = ctypes.c_char_p(catalog) - if schema != None: + catalog = string_p(catalog) + if schema is not None: l_schema = len(schema) - schema = ctypes.c_char_p(schema) - if procedure != None: + schema = string_p(schema) + if procedure is not None: l_procedure = len(procedure) - procedure = ctypes.c_char_p(procedure) - if column != None: + procedure = string_p(procedure) + if column is not None: l_column = len(column) - column = ctypes.c_char_p(column) - - - self._free_results('FREE_STATEMENT') + column = string_p(column) + + + self._free_stmt() + self._last_param_types = None self.statement = None - - ret = ODBC_API.SQLProcedureColumns(self._stmt_h, - catalog, l_catalog, - schema, l_schema, - procedure, l_procedure, - column, l_column) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - + + ret = API_f(self.stmt_h, + catalog, l_catalog, + schema, l_schema, + procedure, l_procedure, + column, l_column) + check_success(self, ret) + self._NumOfRows() self._UpdateDesc() - return (self) - - + return self + + def procedures(self, procedure=None, catalog=None, schema=None): l_catalog = l_schema = l_procedure = 0 - if catalog != None: + + if unicode in [type(x) for x in (procedure, catalog, schema)]: + string_p = lambda x:wchar_pointer(UCS_buf(x)) + API_f = ODBC_API.SQLProceduresW + else: + string_p = ctypes.c_char_p + API_f = ODBC_API.SQLProcedures + + + + if catalog is not None: l_catalog = len(catalog) - catalog = ctypes.c_char_p(catalog) - if schema != None: + catalog = string_p(catalog) + if schema is not None: l_schema = len(schema) - schema = ctypes.c_char_p(schema) - if procedure != None: + schema = string_p(schema) + if procedure is not None: l_procedure = len(procedure) - procedure = ctypes.c_char_p(procedure) - - - self._free_results('FREE_STATEMENT') + procedure = string_p(procedure) + + + self._free_stmt() + self._last_param_types = None self.statement = None - - ret = ODBC_API.SQLProcedures(self._stmt_h, - catalog, l_catalog, - schema, l_schema, - procedure, l_procedure) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - + + ret = API_f(self.stmt_h, + catalog, l_catalog, + schema, l_schema, + procedure, l_procedure) + check_success(self, ret) + self._NumOfRows() self._UpdateDesc() - return (self) + return self def statistics(self, table, catalog=None, schema=None, unique=False, quick=True): l_table = l_catalog = l_schema = 0 - - if catalog != None: + + if unicode in [type(x) for x in (table, catalog, schema)]: + string_p = lambda x:wchar_pointer(UCS_buf(x)) + API_f = ODBC_API.SQLStatisticsW + else: + string_p = ctypes.c_char_p + API_f = ODBC_API.SQLStatistics + + + if catalog is not None: l_catalog = len(catalog) - catalog = ctypes.c_char_p(catalog) - if schema != None: + catalog = string_p(catalog) + if schema is not None: l_schema = len(schema) - schema = ctypes.c_char_p(schema) - if table != None: + schema = string_p(schema) + if table is not None: l_table = len(table) - table = ctypes.c_char_p(table) - + table = string_p(table) + if unique: Unique = SQL_INDEX_UNIQUE else: @@ -1977,134 +2198,138 @@ class Cursor: Reserved = SQL_QUICK else: Reserved = SQL_ENSURE - - self._free_results('FREE_STATEMENT') + + self._free_stmt() + self._last_param_types = None self.statement = None - - ret = ODBC_API.SQLStatistics(self._stmt_h, - catalog, l_catalog, - schema, l_schema, - table, l_table, - Unique, Reserved) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - + + ret = API_f(self.stmt_h, + catalog, l_catalog, + schema, l_schema, + table, l_table, + Unique, Reserved) + check_success(self, ret) + self._NumOfRows() self._UpdateDesc() #self._BindCols() - return (self) - + return self + def commit(self): self.connection.commit() def rollback(self): self.connection.rollback() - + def setoutputsize(self, size, column = None): self._outputsize[column] = size - + def setinputsizes(self, sizes): self._inputsizers = [size for size in sizes] def close(self): """ Call SQLCloseCursor API to free the statement handle""" -# ret = ODBC_API.SQLCloseCursor(self._stmt_h) -# validate(ret, SQL_HANDLE_STMT, self._stmt_h) -# - ret = ODBC_API.SQLFreeStmt(self._stmt_h, SQL_CLOSE) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) +# ret = ODBC_API.SQLCloseCursor(self.stmt_h) +# check_success(self, ret) +# + ret = ODBC_API.SQLFreeStmt(self.stmt_h, SQL_CLOSE) + check_success(self, ret) - ret = ODBC_API.SQLFreeStmt(self._stmt_h, SQL_UNBIND) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) + ret = ODBC_API.SQLFreeStmt(self.stmt_h, SQL_UNBIND) + check_success(self, ret) - ret = ODBC_API.SQLFreeStmt(self._stmt_h, SQL_RESET_PARAMS) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) - - ret = ODBC_API.SQLFreeHandle(SQL_HANDLE_STMT, self._stmt_h) - validate(ret, SQL_HANDLE_STMT, self._stmt_h) + ret = ODBC_API.SQLFreeStmt(self.stmt_h, SQL_RESET_PARAMS) + check_success(self, ret) + ret = ODBC_API.SQLFreeHandle(SQL_HANDLE_STMT, self.stmt_h) + check_success(self, ret) + self.closed = True - - - def __del__(self): + + + def __del__(self): if not self.closed: - if DEBUG: print 'auto closing cursor: ', + #if DEBUG:print 'auto closing cursor: ', try: self.close() except: - if DEBUG: print 'failed' + #if DEBUG:print 'failed' pass else: - if DEBUG: print 'succeed' + #if DEBUG:print 'succeed' pass - + def __exit__(self, type, value, traceback): if value: self.rollback() else: self.commit() - + self.close() - - + + def __enter__(self): return self + +# This class implement a odbc connection. +# +# -# This class implement a odbc connection. -# -# class Connection: - def __init__(self, connectString = '', autocommit = False, ansi = False, timeout = 0, unicode_results = False, readonly = False, **kargs): + def __init__(self, connectString = '', autocommit = False, ansi = False, timeout = 0, unicode_results = use_unicode, readonly = False, **kargs): """Init variables and connect to the engine""" self.connected = 0 self.type_size_dic = {} + self.ansi = False self.unicode_results = False self.dbc_h = ctypes.c_void_p() self.autocommit = autocommit self.readonly = False self.timeout = 0 - - for key, value in kargs.items(): + self._cursors = [] + for key, value in list(kargs.items()): connectString = connectString + key + '=' + value + ';' self.connectString = connectString - + self.clear_output_converters() try: lock.acquire() - if shared_env_h == None: + if shared_env_h is None: #Initialize an enviroment if it is not created. AllocateEnv() finally: lock.release() - + # Allocate an DBC handle self.dbc_h under the environment shared_env_h # This DBC handle is actually the basis of a "connection" - # The handle of self.dbc_h will be used to connect to a certain source + # The handle of self.dbc_h will be used to connect to a certain source # in the self.connect and self.ConnectByDSN method - + ret = ODBC_API.SQLAllocHandle(SQL_HANDLE_DBC, shared_env_h, ADDR(self.dbc_h)) - validate(ret, SQL_HANDLE_DBC, self.dbc_h) - + check_success(self, ret) + self.connect(connectString, autocommit, ansi, timeout, unicode_results, readonly) - - - - def connect(self, connectString = '', autocommit = False, ansi = False, timeout = 0, unicode_results = False, readonly = False): + + + + def connect(self, connectString = '', autocommit = False, ansi = False, timeout = 0, unicode_results = use_unicode, readonly = False): """Connect to odbc, using connect strings and set the connection's attributes like autocommit and timeout by calling SQLSetConnectAttr - """ + """ # Before we establish the connection by the connection string # Set the connection's attribute of "timeout" (Actully LOGIN_TIMEOUT) if timeout != 0: + self.settimeout(timeout) ret = ODBC_API.SQLSetConnectAttr(self.dbc_h, SQL_ATTR_LOGIN_TIMEOUT, timeout, SQL_IS_UINTEGER); - validate(ret, SQL_HANDLE_DBC, self.dbc_h) + check_success(self, ret) # Create one connection with a connect string by calling SQLDriverConnect @@ -2112,12 +2337,12 @@ class Connection: # Convert the connetsytring to encoded string - # so it can be converted to a ctypes c_char array object - - + # so it can be converted to a ctypes c_char array object + + self.ansi = ansi if not ansi: - c_connectString = wchar_type(to_unicode(self.connectString)) + c_connectString = wchar_pointer(UCS_buf(self.connectString)) odbc_func = ODBC_API.SQLDriverConnectW else: c_connectString = ctypes.c_char_p(self.connectString) @@ -2139,45 +2364,44 @@ class Connection: lock.release() else: ret = odbc_func(self.dbc_h, 0, c_connectString, len(self.connectString), None, 0, None, SQL_DRIVER_NOPROMPT) - validate(ret, SQL_HANDLE_DBC, self.dbc_h) - - - # Set the connection's attribute of "autocommit" + check_success(self, ret) + + + # Set the connection's attribute of "autocommit" # self.autocommit = autocommit - + if self.autocommit == True: ret = ODBC_API.SQLSetConnectAttr(self.dbc_h, SQL_ATTR_AUTOCOMMIT, SQL_AUTOCOMMIT_ON, SQL_IS_UINTEGER) else: ret = ODBC_API.SQLSetConnectAttr(self.dbc_h, SQL_ATTR_AUTOCOMMIT, SQL_AUTOCOMMIT_OFF, SQL_IS_UINTEGER) - - validate(ret, SQL_HANDLE_DBC, self.dbc_h) - - # Set the connection's attribute of "readonly" + check_success(self, ret) + + # Set the connection's attribute of "readonly" # self.readonly = readonly - + ret = ODBC_API.SQLSetConnectAttr(self.dbc_h, SQL_ATTR_ACCESS_MODE, self.readonly and SQL_MODE_READ_ONLY or SQL_MODE_READ_WRITE, SQL_IS_UINTEGER) - validate(ret, SQL_HANDLE_DBC, self.dbc_h) - + check_success(self, ret) + self.unicode_results = unicode_results - self.update_type_size_info() self.connected = 1 - + self.update_db_special_info() + def clear_output_converters(self): self.output_converter = {} for sqltype, profile in SQL_data_type_dict.items(): self.output_converter[sqltype] = profile[1] - - + + def add_output_converter(self, sqltype, func): self.output_converter[sqltype] = func - + def settimeout(self, timeout): ret = ODBC_API.SQLSetConnectAttr(self.dbc_h, SQL_ATTR_CONNECTION_TIMEOUT, timeout, SQL_IS_UINTEGER); - validate(ret, SQL_HANDLE_DBC, self.dbc_h) + check_success(self, ret) self.timeout = timeout - + def ConnectByDSN(self, dsn, user, passwd = ''): """Connect to odbc, we need dsn, user and optionally password""" @@ -2186,25 +2410,25 @@ class Connection: self.passwd = passwd sn = create_buffer(dsn) - un = create_buffer(user) + un = create_buffer(user) pw = create_buffer(passwd) - + ret = ODBC_API.SQLConnect(self.dbc_h, sn, len(sn), un, len(un), pw, len(pw)) - validate(ret, SQL_HANDLE_DBC, self.dbc_h) + check_success(self, ret) - self.update_type_size_info() + self.update_db_special_info() self.connected = 1 - - - def cursor(self, row_type_callable=None): + + + def cursor(self, row_type_callable=None): #self.settimeout(self.timeout) if not self.connected: raise ProgrammingError('HY000','Attempt to use a closed connection.') + cur = Cursor(self, row_type_callable=row_type_callable) + self._cursors.append(cur) + return cur - - return Cursor(self, row_type_callable=row_type_callable) - - def update_type_size_info(self): + def update_db_special_info(self): for sql_type in ( SQL_TYPE_TIMESTAMP, SQL_TYPE_DATE, @@ -2212,103 +2436,124 @@ class Connection: SQL_SS_TIME2, ): cur = Cursor(self) - info_tuple = cur.getTypeInfo(sql_type) - if info_tuple != None: - self.type_size_dic[sql_type] = info_tuple[2], info_tuple[14] + try: + info_tuple = cur.getTypeInfo(sql_type) + if info_tuple is not None: + self.type_size_dic[sql_type] = info_tuple[2], info_tuple[14] + except: + pass cur.close() - - + + self.support_SQLDescribeParam = False + try: + driver_name = self.getinfo(SQL_DRIVER_NAME) + if any(x in driver_name for x in ('SQLSRV','ncli','libsqlncli')): + self.support_SQLDescribeParam = True + except: + pass + def commit(self): if not self.connected: raise ProgrammingError('HY000','Attempt to use a closed connection.') - - ret = ODBC_API.SQLEndTran(SQL_HANDLE_DBC, self.dbc_h, SQL_COMMIT); - validate(ret, SQL_HANDLE_DBC, self.dbc_h) + + ret = SQLEndTran(SQL_HANDLE_DBC, self.dbc_h, SQL_COMMIT) + if ret != SQL_SUCCESS: + check_success(self, ret) def rollback(self): if not self.connected: raise ProgrammingError('HY000','Attempt to use a closed connection.') - - ret = ODBC_API.SQLEndTran(SQL_HANDLE_DBC, self.dbc_h, SQL_ROLLBACK); - validate(ret, SQL_HANDLE_DBC, self.dbc_h) - - - + ret = SQLEndTran(SQL_HANDLE_DBC, self.dbc_h, SQL_ROLLBACK) + if ret != SQL_SUCCESS: + check_success(self, ret) + + + def getinfo(self,infotype): - if infotype not in aInfoTypes.keys(): - raise ProgrammingError('HY000','Invalid getinfo value: '+str(infotype)) - - + if infotype not in list(aInfoTypes.keys()): + raise ProgrammingError('HY000','Invalid getinfo value: '+str(infotype)) + + if aInfoTypes[infotype] == 'GI_UINTEGER': total_buf_len = 1000 alloc_buffer = ctypes.c_ulong() - used_buf_len = ctypes.c_short() + used_buf_len = c_short() ret = ODBC_API.SQLGetInfo(self.dbc_h,infotype,ADDR(alloc_buffer), total_buf_len,\ ADDR(used_buf_len)) - validate(ret, SQL_HANDLE_DBC, self.dbc_h) + check_success(self, ret) result = alloc_buffer.value - + elif aInfoTypes[infotype] == 'GI_USMALLINT': total_buf_len = 1000 alloc_buffer = ctypes.c_ushort() - used_buf_len = ctypes.c_short() + used_buf_len = c_short() ret = ODBC_API.SQLGetInfo(self.dbc_h,infotype,ADDR(alloc_buffer), total_buf_len,\ ADDR(used_buf_len)) - validate(ret, SQL_HANDLE_DBC, self.dbc_h) + check_success(self, ret) result = alloc_buffer.value else: total_buf_len = 1000 alloc_buffer = create_buffer(total_buf_len) - used_buf_len = ctypes.c_short() - ret = ODBC_API.SQLGetInfo(self.dbc_h,infotype,ADDR(alloc_buffer), total_buf_len,\ + used_buf_len = c_short() + if self.ansi: + API_f = ODBC_API.SQLGetInfo + else: + API_f = ODBC_API.SQLGetInfoW + ret = API_f(self.dbc_h,infotype,ADDR(alloc_buffer), total_buf_len,\ ADDR(used_buf_len)) - validate(ret, SQL_HANDLE_DBC, self.dbc_h) - result = alloc_buffer.value + check_success(self, ret) + if self.ansi: + result = alloc_buffer.value + else: + result = UCS_dec(alloc_buffer) if aInfoTypes[infotype] == 'GI_YESNO': - if result[0] == 'Y': + if unicode(result[0]) == unicode('Y'): result = True else: result = False - + return result - + def __exit__(self, type, value, traceback): if value: self.rollback() else: self.commit() - + if self.connected: self.close() - + def __enter__(self): return self def __del__(self): if self.connected: self.close() - + def close(self): if not self.connected: raise ProgrammingError('HY000','Attempt to close a closed connection.') - - + for cur in self._cursors: + if not cur is None: + if not cur.closed: + cur.close() + if self.connected: - if DEBUG: print 'disconnect' + #if DEBUG:print 'disconnect' if not self.autocommit: self.rollback() ret = ODBC_API.SQLDisconnect(self.dbc_h) - validate(ret, SQL_HANDLE_DBC, self.dbc_h) - if DEBUG: print 'free dbc' + check_success(self, ret) + #if DEBUG:print 'free dbc' ret = ODBC_API.SQLFreeHandle(SQL_HANDLE_DBC, self.dbc_h) - validate(ret, SQL_HANDLE_DBC, self.dbc_h) + check_success(self, ret) # if shared_env_h.value: -# if DEBUG: print 'env' +# #if DEBUG:print 'env' # ret = ODBC_API.SQLFreeHandle(SQL_HANDLE_ENV, shared_env_h) -# validate(ret, SQL_HANDLE_ENV, shared_env_h) +# check_success(shared_env_h, ret) self.connected = 0 - + odbc = Connection connect = odbc ''' @@ -2316,39 +2561,113 @@ def connect(connectString = '', autocommit = False, ansi = False, timeout = 0, u return odbc(connectString, autocommit, ansi, timeout, unicode_results, readonly, kargs) ''' +def drivers(): + if sys.platform not in ('win32','cli'): + raise Exception('This function is available for use in Windows only.') + try: + lock.acquire() + if shared_env_h is None: + AllocateEnv() + finally: + lock.release() + + DriverDescription = create_buffer_u(1000) + BufferLength1 = c_short(1000) + DescriptionLength = c_short() + DriverAttributes = create_buffer_u(1000) + BufferLength2 = c_short(1000) + AttributesLength = c_short() + ret = SQL_SUCCESS + DriverList = [] + Direction = SQL_FETCH_FIRST + while ret != SQL_NO_DATA: + ret = ODBC_API.SQLDriversW(shared_env_h, Direction , DriverDescription , BufferLength1 + , ADDR(DescriptionLength), DriverAttributes, BufferLength2, ADDR(AttributesLength)) + check_success(shared_env_h, ret) + DriverList.append(DriverDescription.value) + if Direction == SQL_FETCH_FIRST: + Direction = SQL_FETCH_NEXT + return DriverList + + + + def win_create_mdb(mdb_path, sort_order = "General\0\0"): if sys.platform not in ('win32','cli'): raise Exception('This function is available for use in Windows only.') + + mdb_driver = [d for d in drivers() if 'Microsoft Access Driver (*.mdb' in d] + if mdb_driver == []: + raise Exception('Access Driver is not found.') + else: + driver_name = mdb_driver[0].encode('mbcs') + + #CREATE_DB= ctypes.windll.ODBCCP32.SQLConfigDataSource.argtypes = [ctypes.c_void_p,ctypes.c_ushort,ctypes.c_char_p,ctypes.c_char_p] - c_Path = "CREATE_DB=" + mdb_path + " " + sort_order + + if py_v3: + c_Path = bytes("CREATE_DB=" + mdb_path + " " + sort_order,'mbcs') + else: + c_Path = "CREATE_DB=" + mdb_path + " " + sort_order ODBC_ADD_SYS_DSN = 1 - ret = ctypes.windll.ODBCCP32.SQLConfigDataSource(None,ODBC_ADD_SYS_DSN,"Microsoft Access Driver (*.mdb)", c_Path) + + + ret = ctypes.windll.ODBCCP32.SQLConfigDataSource(None,ODBC_ADD_SYS_DSN,driver_name, c_Path) if not ret: - raise Exception('Failed to create Access mdb file. Please check file path, permission and Access driver readiness.') - + raise Exception('Failed to create Access mdb file - "%s". Please check file path, permission and Access driver readiness.' %mdb_path) + + +def win_connect_mdb(mdb_path): + if sys.platform not in ('win32','cli'): + raise Exception('This function is available for use in Windows only.') + + mdb_driver = [d for d in drivers() if 'Microsoft Access Driver (*.mdb' in d] + if mdb_driver == []: + raise Exception('Access Driver is not found.') + else: + driver_name = mdb_driver[0] + return connect('Driver={'+driver_name+"};DBQ="+mdb_path, unicode_results = use_unicode, readonly = False) + + + def win_compact_mdb(mdb_path, compacted_mdb_path, sort_order = "General\0\0"): if sys.platform not in ('win32','cli'): raise Exception('This function is available for use in Windows only.') + + + mdb_driver = [d for d in drivers() if 'Microsoft Access Driver (*.mdb' in d] + if mdb_driver == []: + raise Exception('Access Driver is not found.') + else: + driver_name = mdb_driver[0].encode('mbcs') + #COMPACT_DB= - c_Path = "COMPACT_DB=" + mdb_path + " " + compacted_mdb_path + " " + sort_order - ODBC_ADD_SYS_DSN = 1 ctypes.windll.ODBCCP32.SQLConfigDataSource.argtypes = [ctypes.c_void_p,ctypes.c_ushort,ctypes.c_char_p,ctypes.c_char_p] - ret = ctypes.windll.ODBCCP32.SQLConfigDataSource(None,ODBC_ADD_SYS_DSN,"Microsoft Access Driver (*.mdb)", c_Path) + #driver_name = "Microsoft Access Driver (*.mdb)" + if py_v3: + c_Path = bytes("COMPACT_DB=" + mdb_path + " " + compacted_mdb_path + " " + sort_order,'mbcs') + #driver_name = bytes(driver_name,'mbcs') + else: + c_Path = "COMPACT_DB=" + mdb_path + " " + compacted_mdb_path + " " + sort_order + + ODBC_ADD_SYS_DSN = 1 + ret = ctypes.windll.ODBCCP32.SQLConfigDataSource(None,ODBC_ADD_SYS_DSN,driver_name, c_Path) if not ret: - raise Exception('Failed to compact Access mdb file. Please check file path, permission and Access driver readiness.') + raise Exception('Failed to compact Access mdb file - "%s". Please check file path, permission and Access driver readiness.' %compacted_mdb_path) + def dataSources(): """Return a list with [name, descrition]""" dsn = create_buffer(1024) desc = create_buffer(1024) - dsn_len = ctypes.c_short() - desc_len = ctypes.c_short() + dsn_len = c_short() + desc_len = c_short() dsn_list = {} try: lock.acquire() - if shared_env_h == None: + if shared_env_h is None: AllocateEnv() finally: lock.release() @@ -2361,4 +2680,4 @@ def dataSources(): ctrl_err(SQL_HANDLE_ENV, shared_env_h, ret) else: dsn_list[dsn.value] = desc.value - return dsn_list + return dsn_list \ No newline at end of file