upgraded to pymysql 0.5

This commit is contained in:
mdipierro
2012-07-12 15:31:24 -05:00
parent 0616d22aec
commit 2780f6f032
18 changed files with 809 additions and 192 deletions
+1 -1
View File
@@ -1 +1 @@
Version 2.00.0 (2012-07-12 15:23:22) dev
Version 2.00.0 (2012-07-12 15:31:21) dev
+37
View File
@@ -0,0 +1,37 @@
Changes
--------
0.4 -Miscellaneous bug fixes
-Implementation of SSL support
-Implementation of kill()
-Cleaned up charset functionality
-Fixed BIT type handling
-Connections raise exceptions after they are close()'d
-Full Py3k and unicode support
0.3 -Implemented most of the extended DBAPI 2.0 spec including callproc()
-Fixed error handling to include the message from the server and support
multiple protocol versions.
-Implemented ping()
-Implemented unicode support (probably needs better testing)
-Removed DeprecationWarnings
-Ran against the MySQLdb unit tests to check for bugs
-Added support for client_flag, charset, sql_mode, read_default_file,
use_unicode, cursorclass, init_command, and connect_timeout.
-Refactoring for some more compatibility with MySQLdb including a fake
pymysql.version_info attribute.
-Now runs with no warnings with the -3 command-line switch
-Added test cases for all outstanding tickets and closed most of them.
-Basic Jython support added.
-Fixed empty result sets bug.
-Integrated new unit tests and refactored the example into one.
-Fixed bug with decimal conversion.
-Fixed string encoding bug. Now unicode and binary data work!
-Added very basic docstrings.
0.2 -Changed connection parameter name 'password' to 'passwd'
to make it more plugin replaceable for the other mysql clients.
-Changed pack()/unpack() calls so it runs on 64 bit OSes too.
-Added support for unix_socket.
-Added support for no password.
-Renamed decorders to decoders.
-Better handling of non-existing decoder.
+6 -2
View File
@@ -7,8 +7,8 @@ PyMySQL Installation
This package contains a pure-Python MySQL client library.
Documentation on the MySQL client/server protocol can be found here:
http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
If you would like to run the test suite, create a ~/.my.cnf file and
a database called "test_pymysql". The goal of pymysql is to be a drop-in
If you would like to run the test suite, edit the config parameters in
pymysql/tests/base.py. The goal of pymysql is to be a drop-in
replacement for MySQLdb and work on CPython 2.3+, Jython, IronPython, PyPy
and Python 3. We test for compatibility by simply changing the import
statements in the Django MySQL backend and running its unit tests as well
@@ -34,4 +34,8 @@ Installation
# ... or ...
# python setup.py install
Python 3.0 Support
------------------
Simply run the build-py3k.sh script from the local directory. It will
build a working package in the ./py3k directory.
+5 -6
View File
@@ -23,13 +23,13 @@ THE SOFTWARE.
'''
VERSION = (0, 4, None)
VERSION = (0, 5, None)
from constants import FIELD_TYPE
from converters import escape_dict, escape_sequence, escape_string
from err import Warning, Error, InterfaceError, DataError, \
DatabaseError, OperationalError, IntegrityError, InternalError, \
NotSupportedError, ProgrammingError
NotSupportedError, ProgrammingError, MySQLError
from times import Date, Time, Timestamp, \
DateFromTicks, TimeFromTicks, TimestampFromTicks
@@ -110,7 +110,7 @@ def thread_safe():
def install_as_MySQLdb():
"""
After this function is called, any application that imports MySQLdb or
_mysql will unwittingly actually use
_mysql will unwittingly actually use
"""
sys.modules["MySQLdb"] = sys.modules["_mysql"] = sys.modules["pymysql"]
@@ -121,12 +121,11 @@ __all__ = [
'InterfaceError', 'InternalError', 'MySQLError', 'NULL', 'NUMBER',
'NotSupportedError', 'DBAPISet', 'OperationalError', 'ProgrammingError',
'ROWID', 'STRING', 'TIME', 'TIMESTAMP', 'Warning', 'apilevel', 'connect',
'connections', 'constants', 'converters', 'cursors', 'debug', 'escape',
'connections', 'constants', 'converters', 'cursors',
'escape_dict', 'escape_sequence', 'escape_string', 'get_client_info',
'paramstyle', 'string_literal', 'threadsafety', 'version_info',
'paramstyle', 'threadsafety', 'version_info',
"install_as_MySQLdb",
"NULL","__version__",
]
-1
View File
@@ -172,4 +172,3 @@ def charset_by_name(name):
def charset_by_id(id):
return _charsets.by_id(id)
+243 -107
View File
@@ -25,6 +25,12 @@ try:
except ImportError:
import StringIO
try:
import getpass
DEFAULT_USER = getpass.getuser()
except ImportError:
DEFAULT_USER = None
from charset import MBLENGTH, charset_by_name, charset_by_id
from cursors import Cursor
from constants import FIELD_TYPE, FLAG
@@ -50,22 +56,24 @@ UNSIGNED_INT24_LENGTH = 3
UNSIGNED_INT64_LENGTH = 8
DEFAULT_CHARSET = 'latin1'
MAX_PACKET_LENGTH = 256*256*256-1
def dump_packet(data):
def is_ascii(data):
if byte2int(data) >= 65 and byte2int(data) <= 122: #data.isalnum():
return data
return '.'
print "packet length %d" % len(data)
print "method call[1]: %s" % sys._getframe(1).f_code.co_name
print "method call[2]: %s" % sys._getframe(2).f_code.co_name
print "method call[3]: %s" % sys._getframe(3).f_code.co_name
print "method call[4]: %s" % sys._getframe(4).f_code.co_name
print "method call[5]: %s" % sys._getframe(5).f_code.co_name
print "-" * 88
try:
print "packet length %d" % len(data)
print "method call[1]: %s" % sys._getframe(1).f_code.co_name
print "method call[2]: %s" % sys._getframe(2).f_code.co_name
print "method call[3]: %s" % sys._getframe(3).f_code.co_name
print "method call[4]: %s" % sys._getframe(4).f_code.co_name
print "method call[5]: %s" % sys._getframe(5).f_code.co_name
print "-" * 88
except ValueError: pass
dump_data = [data[i:i+16] for i in xrange(len(data)) if i%16 == 0]
for d in dump_data:
print ' '.join(map(lambda x:"%02X" % byte2int(x), d)) + \
@@ -153,18 +161,28 @@ def unpack_uint16(n):
# TODO: stop using bit-shifting in these functions...
# TODO: rename to "uint" to make it clear they're unsigned...
def unpack_int24(n):
return struct.unpack('B',n[0])[0] + (struct.unpack('B', n[1])[0] << 8) +\
(struct.unpack('B',n[2])[0] << 16)
try:
return struct.unpack('B',n[0])[0] + (struct.unpack('B', n[1])[0] << 8) +\
(struct.unpack('B',n[2])[0] << 16)
except TypeError:
return n[0] + (n[1] << 8) + (n[2] << 16)
def unpack_int32(n):
return struct.unpack('B',n[0])[0] + (struct.unpack('B', n[1])[0] << 8) +\
(struct.unpack('B',n[2])[0] << 16) + (struct.unpack('B', n[3])[0] << 24)
try:
return struct.unpack('B',n[0])[0] + (struct.unpack('B', n[1])[0] << 8) +\
(struct.unpack('B',n[2])[0] << 16) + (struct.unpack('B', n[3])[0] << 24)
except TypeError:
return n[0] + (n[1] << 8) + (n[2] << 16) + (n[3] << 24)
def unpack_int64(n):
return struct.unpack('B',n[0])[0] + (struct.unpack('B', n[1])[0]<<8) +\
(struct.unpack('B',n[2])[0] << 16) + (struct.unpack('B',n[3])[0]<<24)+\
(struct.unpack('B',n[4])[0] << 32) + (struct.unpack('B',n[5])[0]<<40)+\
(struct.unpack('B',n[6])[0] << 48) + (struct.unpack('B',n[7])[0]<<56)
try:
return struct.unpack('B',n[0])[0] + (struct.unpack('B', n[1])[0]<<8) +\
(struct.unpack('B',n[2])[0] << 16) + (struct.unpack('B',n[3])[0]<<24)+\
(struct.unpack('B',n[4])[0] << 32) + (struct.unpack('B',n[5])[0]<<40)+\
(struct.unpack('B',n[6])[0] << 48) + (struct.unpack('B',n[7])[0]<<56)
except TypeError:
return n[0] + (n[1] << 8) + (n[2] << 16) + (n[3] << 24) +\
(n[4] << 32) + (n[5] << 40) + (n[6] << 48) + (n[7] << 56)
def defaulterrorhandler(connection, cursor, errorclass, errorvalue):
err = errorclass, errorvalue
@@ -189,19 +207,16 @@ class MysqlPacket(object):
from the network socket, removes packet header and provides an interface
for reading/parsing the packet results."""
def __init__(self, socket):
def __init__(self, connection):
self.connection = connection
self.__position = 0
self.__recv_packet(socket)
del socket
self.__recv_packet()
def __recv_packet(self, socket):
def __recv_packet(self):
"""Parse the packet header and read entire packet payload into buffer."""
packet_header = socket.recv(4)
while len(packet_header) < 4:
d = socket.recv(4 - len(packet_header))
if len(d) == 0:
raise OperationalError(2013, "Lost connection to MySQL server during query")
packet_header += d
packet_header = self.connection.rfile.read(4)
if len(packet_header) < 4:
raise OperationalError(2013, "Lost connection to MySQL server during query")
if DEBUG: dump_packet(packet_header)
packet_length_bin = packet_header[:3]
@@ -210,16 +225,11 @@ class MysqlPacket(object):
bin_length = packet_length_bin + int2byte(0) # pad little-endian number
bytes_to_read = struct.unpack('<I', bin_length)[0]
payload_buff = [] # this is faster than cStringIO
while bytes_to_read > 0:
recv_data = socket.recv(bytes_to_read)
if len(recv_data) == 0:
raise OperationalError(2013, "Lost connection to MySQL server during query")
if DEBUG: dump_packet(recv_data)
payload_buff.append(recv_data)
bytes_to_read -= len(recv_data)
self.__data = join_bytes(payload_buff)
recv_data = self.connection.rfile.read(bytes_to_read)
if len(recv_data) < bytes_to_read:
raise OperationalError(2013, "Lost connection to MySQL server during query")
if DEBUG: dump_packet(recv_data)
self.__data = recv_data
def packet_number(self): return self.__packet_number
@@ -354,7 +364,7 @@ class FieldDescriptorPacket(MysqlPacket):
self.db = self.read_length_coded_string()
self.table_name = self.read_length_coded_string()
self.org_table = self.read_length_coded_string()
self.name = self.read_length_coded_string()
self.name = self.read_length_coded_string().decode(self.connection.charset)
self.org_name = self.read_length_coded_string()
self.advance(1) # non-null filler
self.charsetnr = struct.unpack('<H', self.read(2))[0]
@@ -396,6 +406,58 @@ class FieldDescriptorPacket(MysqlPacket):
% (self.__class__, self.db, self.table_name, self.name,
self.type_code))
class OKPacketWrapper(object):
"""
OK Packet Wrapper. It uses an existing packet object, and wraps
around it, exposing useful variables while still providing access
to the original packet objects variables and methods.
"""
def __init__(self, from_packet):
if not from_packet.is_ok_packet():
raise ValueError('Cannot create ' + str(self.__class__.__name__)
+ ' object from invalid packet type')
self.packet = from_packet
self.packet.advance(1)
self.affected_rows = self.packet.read_length_coded_binary()
self.insert_id = self.packet.read_length_coded_binary()
self.server_status = struct.unpack('<H', self.packet.read(2))[0]
self.warning_count = struct.unpack('<H', self.packet.read(2))[0]
self.message = self.packet.read_all()
def __getattr__(self, key):
if hasattr(self.packet, key):
return getattr(self.packet, key)
raise AttributeError(str(self.__class__)
+ " instance has no attribute '" + key + "'")
class EOFPacketWrapper(object):
"""
EOF Packet Wrapper. It uses an existing packet object, and wraps
around it, exposing useful variables while still providing access
to the original packet objects variables and methods.
"""
def __init__(self, from_packet):
if not from_packet.is_eof_packet():
raise ValueError('Cannot create ' + str(self.__class__.__name__)
+ ' object from invalid packet type')
self.packet = from_packet
self.warning_count = self.packet.read(2)
server_status = struct.unpack('<h', self.packet.read(2))[0]
self.has_next = (server_status
& SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS)
def __getattr__(self, key):
if hasattr(self.packet, key):
return getattr(self.packet, key)
raise AttributeError(str(self.__class__)
+ " instance has no attribute '" + key + "'")
class Connection(object):
"""
@@ -482,12 +544,12 @@ class Connection(object):
host = _config("host", host)
db = _config("db",db)
unix_socket = _config("socket",unix_socket)
port = _config("port", port)
port = int(_config("port", port))
charset = _config("default-character-set", charset)
self.host = host
self.port = port
self.user = user
self.user = user or DEFAULT_USER
self.password = passwd
self.db = db
self.unix_socket = unix_socket
@@ -498,7 +560,7 @@ class Connection(object):
self.charset = DEFAULT_CHARSET
self.use_unicode = False
if use_unicode:
if use_unicode is not None:
self.use_unicode = use_unicode
client_flag |= CAPABILITIES
@@ -512,14 +574,15 @@ class Connection(object):
self._connect()
self._result = None
self._affected_rows = 0
self.host_info = "Not connected"
self.messages = []
self.set_charset(charset)
self.encoders = encoders
self.decoders = conv
self._affected_rows = 0
self.host_info = "Not connected"
self.autocommit(False)
if sql_mode is not None:
@@ -537,10 +600,16 @@ class Connection(object):
def close(self):
''' Send the quit message and close the socket '''
if self.socket is None:
raise Error("Already closed")
send_data = struct.pack('<i',1) + int2byte(COM_QUIT)
self.socket.send(send_data)
self.wfile.write(send_data)
self.wfile.close()
self.rfile.close()
self.socket.close()
self.socket = None
self.rfile = None
self.wfile = None
def autocommit(self, value):
''' Set whether or not to commit after every execute() '''
@@ -578,8 +647,10 @@ class Connection(object):
''' Alias for escape() '''
return escape_item(obj, self.charset)
def cursor(self):
def cursor(self, cursor=None):
''' Create a new cursor to execute queries with '''
if cursor:
return cursor(self)
return self.cursorclass(self)
def __enter__(self):
@@ -594,9 +665,11 @@ class Connection(object):
self.commit()
# The following methods are INTERNAL USE ONLY (called from Cursor)
def query(self, sql):
def query(self, sql, unbuffered=False):
if DEBUG:
print "sending query: %s" % sql
self._execute_command(COM_QUERY, sql)
self._affected_rows = self._read_query_result()
self._affected_rows = self._read_query_result(unbuffered=unbuffered)
return self._affected_rows
def next_result(self):
@@ -621,6 +694,8 @@ class Connection(object):
''' Check if the server is alive '''
try:
self._execute_command(COM_PING, "")
pkt = self.read_packet()
return pkt.is_ok_packet()
except:
if reconnect:
self._connect()
@@ -630,9 +705,6 @@ class Connection(object):
self.errorhandler(None, exc, value)
return
pkt = self.read_packet()
return pkt.is_ok_packet()
def set_charset(self, charset):
try:
if charset:
@@ -663,64 +735,67 @@ class Connection(object):
self.host_info = "socket %s:%d" % (self.host, self.port)
if DEBUG: print 'connected using socket'
self.socket = sock
self.rfile = self.socket.makefile("rb")
self.wfile = self.socket.makefile("wb")
self._get_server_information()
self._request_authentication()
except socket.error, e:
raise OperationalError(2003, "Can't connect to MySQL server on %r (%d)" % (self.host, e.args[0]))
raise OperationalError(2003, "Can't connect to MySQL server on %r (%s)" % (self.host, e.args[0]))
def read_packet(self, packet_type=MysqlPacket):
"""Read an entire "mysql packet" in its entirety from the network
and return a MysqlPacket type that represents the results."""
# TODO: is socket.recv(small_number) significantly slower than
# socket.recv(large_number)? if so, maybe we should buffer
# the socket.recv() (though that obviously makes memory management
# more complicated.
packet = packet_type(self.socket)
packet = packet_type(self)
packet.check_error()
return packet
def _read_query_result(self):
result = MySQLResult(self)
result.read()
def _read_query_result(self, unbuffered=False):
if unbuffered:
try:
result = MySQLResult(self)
result.init_unbuffered_query()
except:
result.unbuffered_active = False
raise
else:
result = MySQLResult(self)
result.read()
self._result = result
return result.affected_rows
def insert_id(self):
if self._result:
return self._result.insert_id
else:
return 0
def _send_command(self, command, sql):
#send_data = struct.pack('<i', len(sql) + 1) + command + sql
# could probably be more efficient, at least it's correct
if not self.socket:
self.errorhandler(None, InterfaceError, "(0, '')")
# If the last query was unbuffered, make sure it finishes before
# sending new commands
if self._result is not None and self._result.unbuffered_active:
self._result._finish_unbuffered_query()
if isinstance(sql, unicode):
sql = sql.encode(self.charset)
buf = int2byte(command) + sql
pckt_no = 0
while len(buf) >= MAX_PACKET_LENGTH:
header = struct.pack('<i', MAX_PACKET_LENGTH)[:-1]+int2byte(pckt_no)
send_data = header + buf[:MAX_PACKET_LENGTH]
self.socket.send(send_data)
if DEBUG: dump_packet(send_data)
buf = buf[MAX_PACKET_LENGTH:]
pckt_no += 1
header = struct.pack('<i', len(buf))[:-1]+int2byte(pckt_no)
self.socket.send(header+buf)
#sock = self.socket
#sock.send(send_data)
#
prelude = struct.pack('<i', len(sql)+1) + int2byte(command)
self.wfile.write(prelude + sql)
self.wfile.flush()
if DEBUG: dump_packet(prelude + sql)
def _execute_command(self, command, sql):
self._send_command(command, sql)
def _request_authentication(self):
self._send_authentication()
def _send_authentication(self):
sock = self.socket
self.client_flag |= CAPABILITIES
if self.server_version.startswith('5'):
self.client_flag |= MULTI_RESULTS
@@ -742,12 +817,15 @@ class Connection(object):
if DEBUG: dump_packet(data)
sock.send(data)
sock = self.socket = ssl.wrap_socket(sock, keyfile=self.key,
certfile=self.cert,
ssl_version=ssl.PROTOCOL_TLSv1,
cert_reqs=ssl.CERT_REQUIRED,
ca_certs=self.ca)
self.wfile.write(data)
self.wfile.flush()
self.socket = ssl.wrap_socket(self.socket, keyfile=self.key,
certfile=self.cert,
ssl_version=ssl.PROTOCOL_TLSv1,
cert_reqs=ssl.CERT_REQUIRED,
ca_certs=self.ca)
self.rfile = self.socket.makefile("rb")
self.wfile = self.socket.makefile("wb")
data = data_init + self.user+int2byte(0) + _scramble(self.password.encode(self.charset), self.salt)
@@ -760,9 +838,10 @@ class Connection(object):
if DEBUG: dump_packet(data)
sock.send(data)
self.wfile.write(data)
self.wfile.flush()
auth_packet = MysqlPacket(sock)
auth_packet = MysqlPacket(self)
auth_packet.check_error()
if DEBUG: auth_packet.dump()
@@ -776,8 +855,9 @@ class Connection(object):
data = _scramble_323(self.password.encode(self.charset), self.salt.encode(self.charset)) + int2byte(0)
data = pack_int24(len(data)) + int2byte(next_packet) + data
sock.send(data)
auth_packet = MysqlPacket(sock)
self.wfile.write(data)
self.wfile.flush()
auth_packet = MysqlPacket(self)
auth_packet.check_error()
if DEBUG: auth_packet.dump()
@@ -796,9 +876,8 @@ class Connection(object):
return self.protocol_version
def _get_server_information(self):
sock = self.socket
i = 0
packet = MysqlPacket(sock)
packet = MysqlPacket(self)
data = packet.get_all_data()
if DEBUG: dump_packet(data)
@@ -862,6 +941,11 @@ class MySQLResult(object):
self.description = None
self.rows = None
self.has_next = None
self.unbuffered_active = False
def __del__(self):
if self.unbuffered_active:
self._finish_unbuffered_query()
def read(self):
self.first_packet = self.connection.read_packet()
@@ -872,19 +956,78 @@ class MySQLResult(object):
else:
self._read_result_packet()
def init_unbuffered_query(self):
self.unbuffered_active = True
self.first_packet = self.connection.read_packet()
if self.first_packet.is_ok_packet():
self._read_ok_packet()
self.unbuffered_active = False
else:
self.field_count = byte2int(self.first_packet.read(1))
self._get_descriptions()
# Apparently, MySQLdb picks this number because it's the maximum
# value of a 64bit unsigned integer. Since we're emulating MySQLdb,
# we set it to this instead of None, which would be preferred.
self.affected_rows = 18446744073709551615
def _read_ok_packet(self):
self.first_packet.advance(1) # field_count (always '0')
self.affected_rows = self.first_packet.read_length_coded_binary()
self.insert_id = self.first_packet.read_length_coded_binary()
self.server_status = struct.unpack('<H', self.first_packet.read(2))[0]
self.warning_count = struct.unpack('<H', self.first_packet.read(2))[0]
self.message = self.first_packet.read_all()
ok_packet = OKPacketWrapper(self.first_packet)
self.affected_rows = ok_packet.affected_rows
self.insert_id = ok_packet.insert_id
self.server_status = ok_packet.server_status
self.warning_count = ok_packet.warning_count
self.message = ok_packet.message
def _check_packet_is_eof(self, packet):
if packet.is_eof_packet():
eof_packet = EOFPacketWrapper(packet)
self.warning_count = eof_packet.warning_count
self.has_next = eof_packet.has_next
return True
return False
def _read_result_packet(self):
self.field_count = byte2int(self.first_packet.read(1))
self._get_descriptions()
self._read_rowdata_packet()
def _read_rowdata_packet_unbuffered(self):
# Check if in an active query
if self.unbuffered_active == False: return
# EOF
packet = self.connection.read_packet()
if self._check_packet_is_eof(packet):
self.unbuffered_active = False
self.rows = None
return
row = []
for field in self.fields:
data = packet.read_length_coded_string()
converted = None
if field.type_code in self.connection.decoders:
converter = self.connection.decoders[field.type_code]
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
if data != None:
converted = converter(self.connection, field, data)
row.append(converted)
self.affected_rows = 1
self.rows = tuple((row))
if DEBUG: self.rows
def _finish_unbuffered_query(self):
# After much reading on the MySQL protocol, it appears that there is,
# in fact, no way to stop MySQL from sending all the data after
# executing a query, so we just spin, and wait for an EOF packet.
while self.unbuffered_active:
packet = self.connection.read_packet()
if self._check_packet_is_eof(packet):
self.unbuffered_active = False
# TODO: implement this as an iteratable so that it is more
# memory efficient and lower-latency to client...
def _read_rowdata_packet(self):
@@ -892,24 +1035,18 @@ class MySQLResult(object):
rows = []
while True:
packet = self.connection.read_packet()
if packet.is_eof_packet():
self.warning_count = packet.read(2)
server_status = struct.unpack('<h', packet.read(2))[0]
self.has_next = (server_status
& SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS)
if self._check_packet_is_eof(packet):
break
row = []
for field in self.fields:
data = packet.read_length_coded_string()
converted = None
if field.type_code in self.connection.decoders:
converter = self.connection.decoders[field.type_code]
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
data = packet.read_length_coded_string()
converted = None
if data != None:
converted = converter(self.connection, field, data)
row.append(converted)
rows.append(tuple(row))
@@ -930,4 +1067,3 @@ class MySQLResult(object):
eof_packet = self.connection.read_packet()
assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'
self.description = tuple(description)
+24 -15
View File
@@ -1,10 +1,13 @@
import re
import datetime
import time
import sys
from constants import FIELD_TYPE, FLAG
from charset import charset_by_id
PYTHON3 = sys.version_info[0] > 2
try:
set
except NameError:
@@ -22,12 +25,12 @@ def escape_item(val, charset):
return escape_sequence(val, charset)
if type(val) is dict:
return escape_dict(val, charset)
if hasattr(val, "decode") and not isinstance(val, unicode):
if PYTHON3 and hasattr(val, "decode") and not isinstance(val, unicode):
# deal with py3k bytes
val = val.decode(charset)
encoder = encoders[type(val)]
val = encoder(val)
if type(val) is str:
if type(val) in [str, int]:
return val
val = val.encode(charset)
return val
@@ -44,7 +47,7 @@ def escape_sequence(val, charset):
for item in val:
quoted = escape_item(item, charset)
n.append(quoted)
return tuple(n)
return "(" + ",".join(n) + ")"
def escape_set(val, charset):
val = map(lambda x: escape_item(x, charset), val)
@@ -56,7 +59,10 @@ def escape_bool(value):
def escape_object(value):
return str(value)
escape_int = escape_long = escape_object
def escape_int(value):
return value
escape_long = escape_object
def escape_float(value):
return ('%.15g' % value)
@@ -142,16 +148,19 @@ def convert_timedelta(connection, field, obj):
can accept values as (+|-)DD HH:MM:SS. The latter format will not
be parsed correctly by this function.
"""
from math import modf
try:
microseconds = 0
if not isinstance(obj, unicode):
obj = obj.decode(connection.charset)
hours, minutes, seconds = tuple([int(x) for x in obj.split(':')])
if "." in obj:
(obj, tail) = obj.split('.')
microseconds = int(tail)
hours, minutes, seconds = obj.split(':')
tdelta = datetime.timedelta(
hours = int(hours),
minutes = int(minutes),
seconds = int(seconds),
microseconds = int(modf(float(seconds))[0]*1000000),
microseconds = microseconds
)
return tdelta
except ValueError:
@@ -179,12 +188,14 @@ def convert_time(connection, field, obj):
to be treated as time-of-day and not a time offset, then you can
use set this function as the converter for FIELD_TYPE.TIME.
"""
from math import modf
try:
hour, minute, second = obj.split(':')
return datetime.time(hour=int(hour), minute=int(minute),
second=int(second),
microsecond=int(modf(float(second))[0]*1000000))
microseconds = 0
if "." in obj:
(obj, tail) = obj.split('.')
microseconds = int(tail)
hours, minutes, seconds = obj.split(':')
return datetime.time(hour=int(hours), minute=int(minutes),
second=int(seconds), microsecond=microseconds)
except ValueError:
return None
@@ -267,8 +278,6 @@ def convert_characters(connection, field, data):
elif connection.charset != field_charset:
data = data.decode(field_charset)
data = data.encode(connection.charset)
else:
data = data.decode(connection.charset)
return data
def convert_int(connection, field, data):
@@ -334,6 +343,7 @@ try:
# python version > 2.3
from decimal import Decimal
def convert_decimal(connection, field, data):
data = data.decode(connection.charset)
return Decimal(data)
decoders[FIELD_TYPE.DECIMAL] = convert_decimal
decoders[FIELD_TYPE.NEWDECIMAL] = convert_decimal
@@ -344,4 +354,3 @@ try:
except ImportError:
pass
+170 -11
View File
@@ -92,12 +92,21 @@ class Cursor(object):
# TODO: make sure that conn.escape is correct
if args is not None:
query = query % conn.escape(args)
if isinstance(query, unicode):
query = query.encode(charset)
if args is not None:
if isinstance(args, tuple) or isinstance(args, list):
escaped_args = tuple(conn.escape(arg) for arg in args)
elif isinstance(args, dict):
escaped_args = dict((key, conn.escape(val)) for (key, val) in args.items())
else:
#If it's not a dictionary let's try escaping it anyways.
#Worst case it will throw a Value error
escaped_args = conn.escape(args)
query = query % escaped_args
result = 0
try:
result = self._query(query)
@@ -113,12 +122,12 @@ class Cursor(object):
def executemany(self, query, args):
''' Run several data against one query '''
del self.messages[:]
conn = self._get_db()
#conn = self._get_db()
if not args:
return
charset = conn.charset
if isinstance(query, unicode):
query = query.encode(charset)
#charset = conn.charset
#if isinstance(query, unicode):
# query = query.encode(charset)
self.rowcount = sum([ self.execute(query, arg) for arg in args ])
return self.rowcount
@@ -231,12 +240,9 @@ class Cursor(object):
self.lastrowid = conn._result.insert_id
self._rows = conn._result.rows
self._has_next = conn._result.has_next
conn._result = None
def __iter__(self):
self._check_executed()
result = self.rownumber and self._rows[self.rownumber:] or self._rows
return iter(result)
return iter(self.fetchone, None)
Warning = Warning
Error = Error
@@ -249,3 +255,156 @@ class Cursor(object):
ProgrammingError = ProgrammingError
NotSupportedError = NotSupportedError
class DictCursor(Cursor):
"""A cursor which returns results as a dictionary"""
def execute(self, query, args=None):
result = super(DictCursor, self).execute(query, args)
if self.description:
self._fields = [ field[0] for field in self.description ]
return result
def fetchone(self):
''' Fetch the next row '''
self._check_executed()
if self._rows is None or self.rownumber >= len(self._rows):
return None
result = dict(zip(self._fields, self._rows[self.rownumber]))
self.rownumber += 1
return result
def fetchmany(self, size=None):
''' Fetch several rows '''
self._check_executed()
if self._rows is None:
return None
end = self.rownumber + (size or self.arraysize)
result = [ dict(zip(self._fields, r)) for r in self._rows[self.rownumber:end] ]
self.rownumber = min(end, len(self._rows))
return tuple(result)
def fetchall(self):
''' Fetch all the rows '''
self._check_executed()
if self._rows is None:
return None
if self.rownumber:
result = [ dict(zip(self._fields, r)) for r in self._rows[self.rownumber:] ]
else:
result = [ dict(zip(self._fields, r)) for r in self._rows ]
self.rownumber = len(self._rows)
return tuple(result)
class SSCursor(Cursor):
"""
Unbuffered Cursor, mainly useful for queries that return a lot of data,
or for connections to remote servers over a slow network.
Instead of copying every row of data into a buffer, this will fetch
rows as needed. The upside of this, is the client uses much less memory,
and rows are returned much faster when traveling over a slow network,
or if the result set is very big.
There are limitations, though. The MySQL protocol doesn't support
returning the total number of rows, so the only way to tell how many rows
there are is to iterate over every row returned. Also, it currently isn't
possible to scroll backwards, as only the current row is held in memory.
"""
def close(self):
conn = self._get_db()
conn._result._finish_unbuffered_query()
try:
if self._has_next:
while self.nextset(): pass
except: pass
def _query(self, q):
conn = self._get_db()
self._last_executed = q
conn.query(q, unbuffered=True)
self._do_get_result()
return self.rowcount
def read_next(self):
""" Read next row """
conn = self._get_db()
conn._result._read_rowdata_packet_unbuffered()
return conn._result.rows
def fetchone(self):
""" Fetch next row """
self._check_executed()
row = self.read_next()
if row is None:
return None
self.rownumber += 1
return row
def fetchall(self):
"""
Fetch all, as per MySQLdb. Pretty useless for large queries, as
it is buffered. See fetchall_unbuffered(), if you want an unbuffered
generator version of this method.
"""
rows = []
while True:
row = self.fetchone()
if row is None:
break
rows.append(row)
return tuple(rows)
def fetchall_unbuffered(self):
"""
Fetch all, implemented as a generator, which isn't to standard,
however, it doesn't make sense to return everything in a list, as that
would use ridiculous memory for large result sets.
"""
row = self.fetchone()
while row is not None:
yield row
row = self.fetchone()
def fetchmany(self, size=None):
""" Fetch many """
self._check_executed()
if size is None:
size = self.arraysize
rows = []
for i in range(0, size):
row = self.read_next()
if row is None:
break
rows.append(row)
self.rownumber += 1
return tuple(rows)
def scroll(self, value, mode='relative'):
self._check_executed()
if not mode == 'relative' and not mode == 'absolute':
self.errorhandler(self, ProgrammingError,
"unknown scroll mode %s" % mode)
if mode == 'relative':
if value < 0:
self.errorhandler(self, NotSupportedError,
"Backwards scrolling not supported by this cursor")
for i in range(0, value): self.read_next()
self.rownumber += value
else:
if value < self.rownumber:
self.errorhandler(self, NotSupportedError,
"Backwards scrolling not supported by this cursor")
end = value - self.rownumber
for i in range(0, end): self.read_next()
self.rownumber = value
+17 -11
View File
@@ -2,20 +2,21 @@ import struct
try:
Exception, Warning
StandardError, Warning
except ImportError:
try:
from exceptions import Exception, Warning
from exceptions import StandardError, Warning
except ImportError:
import sys
e = sys.modules['exceptions']
Exception = e.Exception
StandardError = e.StandardError
Warning = e.Warning
from constants import ER
import sys
class MySQLError(Exception):
class MySQLError(StandardError):
"""Exception related to operation with MySQL."""
@@ -106,13 +107,19 @@ _map_error(IntegrityError, ER.DUP_ENTRY, ER.NO_REFERENCED_ROW,
ER.CANNOT_ADD_FOREIGN)
_map_error(NotSupportedError, ER.WARNING_NOT_COMPLETE_ROLLBACK,
ER.NOT_SUPPORTED_YET, ER.FEATURE_DISABLED, ER.UNKNOWN_STORAGE_ENGINE)
_map_error(OperationalError, ER.DBACCESS_DENIED_ERROR, ER.ACCESS_DENIED_ERROR,
ER.TABLEACCESS_DENIED_ERROR, ER.COLUMNACCESS_DENIED_ERROR)
del _map_error, ER
def _get_error_info(data):
errno = struct.unpack('<h', data[1:3])[0]
if data[3] == "#":
if sys.version_info[0] == 3:
is_41 = data[3] == ord("#")
else:
is_41 = data[3] == "#"
if is_41:
# version 4.1
sqlstate = data[4:9].decode("utf8")
errorvalue = data[9:].decode("utf8")
@@ -122,7 +129,7 @@ def _get_error_info(data):
return (errno, None, data[3:].decode("utf8"))
def _check_mysql_exception(errinfo):
errno, sqlstate, errorvalue = errinfo
errno, sqlstate, errorvalue = errinfo
errorclass = error_map.get(errno, None)
if errorclass:
raise errorclass, (errno,errorvalue)
@@ -133,8 +140,7 @@ def _check_mysql_exception(errinfo):
def raise_mysql_exception(data):
errinfo = _get_error_info(data)
_check_mysql_exception(errinfo)
+9 -3
View File
@@ -1,6 +1,12 @@
from gluon.contrib.pymysql.tests.test_issues import *
from gluon.contrib.pymysql.tests.test_example import *
from gluon.contrib.pymysql.tests.test_basic import *
from pymysql.tests.test_issues import *
from pymysql.tests.test_example import *
from pymysql.tests.test_basic import *
from pymysql.tests.test_DictCursor import *
import sys
if sys.version_info[0] == 2:
# MySQLdb tests were designed for Python 3
from pymysql.tests.thirdparty import *
if __name__ == "__main__":
import unittest
+5 -7
View File
@@ -1,20 +1,18 @@
from gluon.contrib import pymysql
import pymysql
import unittest
class PyMySQLTestCase(unittest.TestCase):
# Edit this to suit your test environment.
databases = [
{"host":"localhost","user":"root",
"passwd":"","db":"test_pymysql", "use_unicode": True},
{"host":"localhost","user":"root","passwd":"","db":"test_pymysql2"}]
def setUp(self):
try:
self.connections = []
self.connections = []
for params in self.databases:
self.connections.append(pymysql.connect(**params))
except pymysql.err.OperationalError as e:
self.skipTest('Cannot connect to MySQL - skipping pymysql tests because of (%s) %s' % (type(e), e))
for params in self.databases:
self.connections.append(pymysql.connect(**params))
def tearDown(self):
for connection in self.connections:
+56
View File
@@ -0,0 +1,56 @@
from pymysql.tests import base
import pymysql.cursors
import datetime
class TestDictCursor(base.PyMySQLTestCase):
    """Exercise DictCursor: rows come back as dicts keyed by column name."""

    def test_DictCursor(self):
        # All assertions compare against the structure MySQLdb would produce.
        conn = self.connections[0]
        c = conn.cursor(pymysql.cursors.DictCursor)
        # create a table and some data to query
        c.execute("""CREATE TABLE dictcursor (name char(20), age int , DOB datetime)""")
        data = (("bob", 21, "1990-02-06 23:04:56"),
                ("jim", 56, "1955-05-09 13:12:45"),
                ("fred", 100, "1911-09-12 01:01:01"))
        # No leading zeros in the datetime arguments: ``02``/``05`` are octal
        # literals in Python 2 and a SyntaxError in Python 3.
        bob = {'name': 'bob', 'age': 21, 'DOB': datetime.datetime(1990, 2, 6, 23, 4, 56)}
        jim = {'name': 'jim', 'age': 56, 'DOB': datetime.datetime(1955, 5, 9, 13, 12, 45)}
        fred = {'name': 'fred', 'age': 100, 'DOB': datetime.datetime(1911, 9, 12, 1, 1, 1)}
        try:
            c.executemany("insert into dictcursor values (%s,%s,%s)", data)
            # try an update which should return no rows
            c.execute("update dictcursor set age=20 where name='bob'")
            bob['age'] = 20
            # pull back the single row dict for bob and check
            c.execute("SELECT * from dictcursor where name='bob'")
            r = c.fetchone()
            self.assertEqual(bob, r, "fetchone via DictCursor failed")
            # same again, but via fetchall => tuple of dicts
            c.execute("SELECT * from dictcursor where name='bob'")
            r = c.fetchall()
            self.assertEqual((bob,), r, "fetch a 1 row result via fetchall failed via DictCursor")
            # same test again but iterate over the cursor
            c.execute("SELECT * from dictcursor where name='bob'")
            for r in c:
                self.assertEqual(bob, r, "fetch a 1 row result via iteration failed via DictCursor")
            # get all 3 rows via fetchall
            c.execute("SELECT * from dictcursor")
            r = c.fetchall()
            self.assertEqual((bob, jim, fred), r, "fetchall failed via DictCursor")
            # same test again but do a list comprehension
            c.execute("SELECT * from dictcursor")
            r = [x for x in c]
            self.assertEqual([bob, jim, fred], r, "list comprehension failed via DictCursor")
            # get 2 rows via fetchmany
            c.execute("SELECT * from dictcursor")
            r = c.fetchmany(2)
            self.assertEqual((bob, jim), r, "fetchmany failed via DictCursor")
        finally:
            c.execute("drop table dictcursor")


__all__ = ["TestDictCursor"]

if __name__ == "__main__":
    import unittest
    unittest.main()
+100
View File
@@ -0,0 +1,100 @@
import sys

try:
    from pymysql.tests import base
    import pymysql.cursors
except ImportError:
    # Narrowed from a bare ``except:`` -- only a failed import should
    # trigger the fallback; a bare except would also hide SyntaxError,
    # KeyboardInterrupt, etc.
    # For local testing from top-level directory, without installing
    sys.path.append('../pymysql')
    from pymysql.tests import base
    import pymysql.cursors
class TestSSCursor(base.PyMySQLTestCase):
    """Exercise SSCursor, the unbuffered (server-side) cursor."""

    def test_SSCursor(self):
        # While rows are still streaming, SSCursor reports the DB-API
        # "unknown rowcount" sentinel (2**64 - 1).
        affected_rows = 18446744073709551615

        conn = self.connections[0]
        data = [
            ('America', '', 'America/Jamaica'),
            ('America', '', 'America/Los_Angeles'),
            ('America', '', 'America/Lima'),
            ('America', '', 'America/New_York'),
            ('America', '', 'America/Menominee'),
            ('America', '', 'America/Havana'),
            ('America', '', 'America/El_Salvador'),
            ('America', '', 'America/Costa_Rica'),
            ('America', '', 'America/Denver'),
            ('America', '', 'America/Detroit'),
        ]

        try:
            cursor = conn.cursor(pymysql.cursors.SSCursor)

            # Create table
            cursor.execute('CREATE TABLE tz_data ('
                           'region VARCHAR(64),'
                           'zone VARCHAR(64),'
                           'name VARCHAR(64))')

            # Test INSERT
            for i in data:
                cursor.execute('INSERT INTO tz_data VALUES (%s, %s, %s)', i)
                self.assertEqual(conn.affected_rows(), 1, 'affected_rows does not match')
            conn.commit()

            # Test fetchone(); ``row_count`` was previously named ``iter``,
            # shadowing the builtin.
            row_count = 0
            cursor.execute('SELECT * FROM tz_data')
            while True:
                row = cursor.fetchone()
                if row is None:
                    break
                row_count += 1
                # Test cursor.rowcount
                self.assertEqual(cursor.rowcount, affected_rows,
                                 'cursor.rowcount != %s' % (str(affected_rows)))
                # Test cursor.rownumber (message fixed: it previously
                # claimed to be testing rowcount)
                self.assertEqual(cursor.rownumber, row_count,
                                 'cursor.rownumber != %s' % (str(row_count)))
                # Test row came out the same as it went in
                self.assertEqual((row in data), True,
                                 'Row not found in source data')

            # Test fetchall
            cursor.execute('SELECT * FROM tz_data')
            self.assertEqual(len(cursor.fetchall()), len(data),
                             'fetchall failed. Number of rows does not match')

            # Test fetchmany
            cursor.execute('SELECT * FROM tz_data')
            self.assertEqual(len(cursor.fetchmany(2)), 2,
                             'fetchmany failed. Number of rows does not match')

            # Drain the remaining rows so MySQLdb won't throw
            # "Commands out of sync"
            while True:
                res = cursor.fetchone()
                if res is None:
                    break

            # Test update, affected_rows()
            cursor.execute('UPDATE tz_data SET zone = %s', ['Foo'])
            conn.commit()
            self.assertEqual(cursor.rowcount, len(data),
                             'Update failed. affected_rows != %s' % (str(len(data))))

            # Test executemany
            cursor.executemany('INSERT INTO tz_data VALUES (%s, %s, %s)', data)
            self.assertEqual(cursor.rowcount, len(data),
                             'executemany failed. cursor.rowcount != %s' % (str(len(data))))
        finally:
            cursor.execute('DROP TABLE tz_data')
            cursor.close()


__all__ = ["TestSSCursor"]

if __name__ == "__main__":
    import unittest
    unittest.main()
+74 -2
View File
@@ -1,5 +1,5 @@
from gluon.contrib.pymysql.tests import base
from gluon.contrib.pymysql import util
from pymysql.tests import base
from pymysql import util
import time
import datetime
@@ -55,6 +55,31 @@ class TestConversion(base.PyMySQLTestCase):
finally:
c.execute("drop table test_dict")
def test_string(self):
    """Round-trip a str value through a TEXT column."""
    cur = self.connections[0].cursor()
    cur.execute("create table test_dict (a text)")
    value = "I am a test string"
    try:
        cur.execute("insert into test_dict (a) values (%s)", value)
        cur.execute("select a from test_dict")
        self.assertEqual((value,), cur.fetchone())
    finally:
        cur.execute("drop table test_dict")
def test_integer(self):
    """Round-trip an int value through an INTEGER column."""
    cur = self.connections[0].cursor()
    cur.execute("create table test_dict (a integer)")
    value = 12345
    try:
        cur.execute("insert into test_dict (a) values (%s)", value)
        cur.execute("select a from test_dict")
        self.assertEqual((value,), cur.fetchone())
    finally:
        cur.execute("drop table test_dict")
def test_big_blob(self):
""" test tons of data """
conn = self.connections[0]
@@ -67,6 +92,26 @@ class TestConversion(base.PyMySQLTestCase):
self.assertEqual(data.encode(conn.charset), c.fetchone()[0])
finally:
c.execute("drop table test_big_blob")
def test_untyped(self):
    """ test conversion of null, empty string """
    cur = self.connections[0].cursor()
    # Both orderings, so neither column position masks a decoder bug.
    for query, expected in (("select null,''", (None, u'')),
                            ("select '',null", (u'', None))):
        cur.execute(query)
        self.assertEqual(expected, cur.fetchone())
def test_datetime(self):
    """ test conversion of TIME values to datetime.timedelta """
    # Docstring fixed: it previously said "null, empty string",
    # copy-pasted from test_untyped.
    conn = self.connections[0]
    c = conn.cursor()
    c.execute("select time('12:30'), time('23:12:59'), time('23:12:59.05100')")
    self.assertEqual((datetime.timedelta(0, 45000),
                      datetime.timedelta(0, 83579),
                      datetime.timedelta(0, 83579, 51000)),
                     c.fetchone())
class TestCursor(base.PyMySQLTestCase):
# this test case does not work quite right yet, however,
@@ -134,6 +179,33 @@ class TestCursor(base.PyMySQLTestCase):
finally:
c.execute("drop table test_nr")
def test_aggregates(self):
    """ test aggregate functions """
    cur = self.connections[0].cursor()
    try:
        cur.execute('create table test_aggregates (i integer)')
        for value in xrange(0, 10):
            cur.execute('insert into test_aggregates (i) values (%s)', (value,))
        cur.execute('select sum(i) from test_aggregates')
        total, = cur.fetchone()
        self.assertEqual(sum(range(0, 10)), total)
    finally:
        cur.execute('drop table test_aggregates')
def test_single_tuple(self):
    """ test a single tuple """
    cur = self.connections[0].cursor()
    try:
        cur.execute("create table mystuff (id integer primary key)")
        cur.execute("insert into mystuff (id) values (1)")
        cur.execute("insert into mystuff (id) values (2)")
        # One-element tuple as the IN (...) parameter.
        cur.execute("select id from mystuff where id in %s", ((1,),))
        rows = list(cur.fetchall())
        self.assertEqual([(1,)], rows)
    finally:
        cur.execute("drop table mystuff")
__all__ = ["TestConversion","TestCursor"]
if __name__ == "__main__":
+2 -2
View File
@@ -1,5 +1,5 @@
from gluon.contrib import pymysql
from gluon.contrib.pymysql.tests import base
import pymysql
from pymysql.tests import base
class TestExample(base.PyMySQLTestCase):
def test_example(self):
+60 -22
View File
@@ -1,5 +1,6 @@
from gluon.contrib import pymysql
from gluon.contrib.pymysql.tests import base
import pymysql
from pymysql.tests import base
import unittest
import sys
@@ -11,6 +12,10 @@ except AttributeError:
import datetime
# backwards compatibility:
if not hasattr(unittest, "skip"):
unittest.skip = lambda message: lambda f: f
class TestOldIssues(base.PyMySQLTestCase):
def test_issue_3(self):
""" undefined methods datetime_or_None, date_or_None """
@@ -89,15 +94,15 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
""" can't handle large result fields """
conn = self.connections[0]
cur = conn.cursor()
cur.execute("create table issue13 (t text)")
try:
cur.execute("create table issue13 (t text)")
# ticket says 18k
size = 18*1024
cur.execute("insert into issue13 (t) values (%s)", ("x" * size,))
cur.execute("select t from issue13")
# use assert_ so that obscenely huge error messages don't print
# use assertTrue so that obscenely huge error messages don't print
r = cur.fetchone()[0]
self.assert_("x" * size == r)
self.assertTrue("x" * size == r)
finally:
cur.execute("drop table issue13")
@@ -115,7 +120,7 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
c = conn.cursor()
c.execute("create table issue15 (t varchar(32))")
try:
c.execute("insert into issue15 (t) values (%s)", (u'\xe4\xf6\xfc'))
c.execute("insert into issue15 (t) values (%s)", (u'\xe4\xf6\xfc',))
c.execute("select t from issue15")
self.assertEqual(u'\xe4\xf6\xfc', c.fetchone()[0])
finally:
@@ -133,6 +138,7 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
finally:
c.execute("drop table issue16")
@unittest.skip("test_issue_17() requires a custom, legacy MySQL configuration and will not be run.")
def test_issue_17(self):
""" could not connect mysql use passwod """
conn = self.connections[0]
@@ -181,23 +187,22 @@ class TestNewIssues(base.PyMySQLTestCase):
finally:
c.execute(_uni("drop table hei\xc3\x9fe", "utf8"))
# Will fail without manual intervention:
#def test_issue_35(self):
#
# conn = self.connections[0]
# c = conn.cursor()
# print "sudo killall -9 mysqld within the next 10 seconds"
# try:
# c.execute("select sleep(10)")
# self.fail()
# except pymysql.OperationalError, e:
# self.assertEqual(2013, e.args[0])
@unittest.skip("This test requires manual intervention")
def test_issue_35(self):
    """Server dying mid-query should raise OperationalError.

    Requires an operator to kill mysqld while the query sleeps,
    hence the skip decorator.  Python 2 only (print statement,
    old-style except clause).
    """
    conn = self.connections[0]
    c = conn.cursor()
    print "sudo killall -9 mysqld within the next 10 seconds"
    try:
        c.execute("select sleep(10)")
        # If the sleep completes, the server was never killed.
        self.fail()
    except pymysql.OperationalError, e:
        # 2013: expected "lost connection" client error code --
        # NOTE(review): verify against the MySQL client error list.
        self.assertEqual(2013, e.args[0])
def test_issue_36(self):
conn = self.connections[0]
c = conn.cursor()
# kill connections[0]
original_count = c.execute("show processlist")
c.execute("show processlist")
kill_id = None
for id,user,host,db,command,time,state,info in c.fetchall():
if info == "show processlist":
@@ -212,8 +217,13 @@ class TestNewIssues(base.PyMySQLTestCase):
except:
pass
# check the process list from the other connection
self.assertEqual(original_count - 1, self.connections[1].cursor().execute("show processlist"))
del self.connections[0]
try:
c = self.connections[1].cursor()
c.execute("show processlist")
ids = [row[0] for row in c.fetchall()]
self.assertFalse(kill_id in ids)
finally:
del self.connections[0]
def test_issue_37(self):
conn = self.connections[0]
@@ -230,10 +240,38 @@ class TestNewIssues(base.PyMySQLTestCase):
try:
c.execute("create table issue38 (id integer, data mediumblob)")
c.execute("insert into issue38 values (1, %s)", datum)
c.execute("insert into issue38 values (1, %s)", (datum,))
finally:
c.execute("drop table issue38")
__all__ = ["TestOldIssues", "TestNewIssues"]
def disabled_test_issue_54(self):
    """Issue 54: execute a very long SQL statement.

    Prefixed ``disabled_`` so the test runner does not collect it --
    presumably because building and sending the huge statement is
    slow (NOTE(review): reason for disabling not recorded here).
    """
    conn = self.connections[0]
    c = conn.cursor()
    big_sql = "select * from issue54 where "
    # 100k "i=i" terms produces a statement of several megabytes.
    big_sql += " and ".join("%d=%d" % (i,i) for i in xrange(0, 100000))
    try:
        c.execute("create table issue54 (id integer primary key)")
        c.execute("insert into issue54 (id) values (7)")
        c.execute(big_sql)
        self.assertEqual(7, c.fetchone()[0])
    finally:
        c.execute("drop table issue54")
class TestGitHubIssues(base.PyMySQLTestCase):
    """Regression tests for issues reported on GitHub."""

    def test_issue_66(self):
        """insert_id() must track the last AUTO_INCREMENT value."""
        conn = self.connections[0]
        cur = conn.cursor()
        # Before any insert on this connection, insert_id() is 0.
        self.assertEqual(0, conn.insert_id())
        try:
            cur.execute("create table issue66 (id integer primary key auto_increment, x integer)")
            cur.execute("insert into issue66 (x) values (1)")
            cur.execute("insert into issue66 (x) values (1)")
            self.assertEqual(2, conn.insert_id())
        finally:
            cur.execute("drop table issue66")
__all__ = ["TestOldIssues", "TestNewIssues", "TestGitHubIssues"]
if __name__ == "__main__":
import unittest
-1
View File
@@ -14,4 +14,3 @@ def TimeFromTicks(ticks):
def TimestampFromTicks(ticks):
return datetime(*localtime(ticks)[:6])
-1
View File
@@ -17,4 +17,3 @@ def join_bytes(bs):
for b in bs[1:]:
rv += b
return rv