upgraded to pymysql 0.2
This commit is contained in:
@@ -1 +1 @@
|
||||
Version 2.00.0 (2012-07-12 15:23:22) dev
|
||||
Version 2.00.0 (2012-07-12 15:31:21) dev
|
||||
|
||||
Executable
+37
@@ -0,0 +1,37 @@
|
||||
Changes
|
||||
--------
|
||||
0.4 -Miscellaneous bug fixes
|
||||
-Implementation of SSL support
|
||||
-Implementation of kill()
|
||||
-Cleaned up charset functionality
|
||||
-Fixed BIT type handling
|
||||
-Connections raise exceptions after they are close()'d
|
||||
-Full Py3k and unicode support
|
||||
|
||||
0.3 -Implemented most of the extended DBAPI 2.0 spec including callproc()
|
||||
-Fixed error handling to include the message from the server and support
|
||||
multiple protocol versions.
|
||||
-Implemented ping()
|
||||
-Implemented unicode support (probably needs better testing)
|
||||
-Removed DeprecationWarnings
|
||||
-Ran against the MySQLdb unit tests to check for bugs
|
||||
-Added support for client_flag, charset, sql_mode, read_default_file,
|
||||
use_unicode, cursorclass, init_command, and connect_timeout.
|
||||
-Refactoring for some more compatibility with MySQLdb including a fake
|
||||
pymysql.version_info attribute.
|
||||
-Now runs with no warnings with the -3 command-line switch
|
||||
-Added test cases for all outstanding tickets and closed most of them.
|
||||
-Basic Jython support added.
|
||||
-Fixed empty result sets bug.
|
||||
-Integrated new unit tests and refactored the example into one.
|
||||
-Fixed bug with decimal conversion.
|
||||
-Fixed string encoding bug. Now unicode and binary data work!
|
||||
-Added very basic docstrings.
|
||||
|
||||
0.2 -Changed connection parameter name 'password' to 'passwd'
|
||||
to make it more plugin replaceable for the other mysql clients.
|
||||
-Changed pack()/unpack() calls so it runs on 64 bit OSes too.
|
||||
-Added support for unix_socket.
|
||||
-Added support for no password.
|
||||
-Renamed decorders to decoders.
|
||||
-Better handling of non-existing decoder.
|
||||
Regular → Executable
+6
-2
@@ -7,8 +7,8 @@ PyMySQL Installation
|
||||
This package contains a pure-Python MySQL client library.
|
||||
Documentation on the MySQL client/server protocol can be found here:
|
||||
http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
|
||||
If you would like to run the test suite, create a ~/.my.cnf file and
|
||||
a database called "test_pymysql". The goal of pymysql is to be a drop-in
|
||||
If you would like to run the test suite, edit the config parameters in
|
||||
pymysql/tests/base.py. The goal of pymysql is to be a drop-in
|
||||
replacement for MySQLdb and work on CPython 2.3+, Jython, IronPython, PyPy
|
||||
and Python 3. We test for compatibility by simply changing the import
|
||||
statements in the Django MySQL backend and running its unit tests as well
|
||||
@@ -34,4 +34,8 @@ Installation
|
||||
# ... or ...
|
||||
# python setup.py install
|
||||
|
||||
Python 3.0 Support
|
||||
------------------
|
||||
|
||||
Simply run the build-py3k.sh script from the local directory. It will
|
||||
build a working package in the ./py3k directory.
|
||||
@@ -23,13 +23,13 @@ THE SOFTWARE.
|
||||
|
||||
'''
|
||||
|
||||
VERSION = (0, 4, None)
|
||||
VERSION = (0, 5, None)
|
||||
|
||||
from constants import FIELD_TYPE
|
||||
from converters import escape_dict, escape_sequence, escape_string
|
||||
from err import Warning, Error, InterfaceError, DataError, \
|
||||
DatabaseError, OperationalError, IntegrityError, InternalError, \
|
||||
NotSupportedError, ProgrammingError
|
||||
NotSupportedError, ProgrammingError, MySQLError
|
||||
from times import Date, Time, Timestamp, \
|
||||
DateFromTicks, TimeFromTicks, TimestampFromTicks
|
||||
|
||||
@@ -110,7 +110,7 @@ def thread_safe():
|
||||
def install_as_MySQLdb():
|
||||
"""
|
||||
After this function is called, any application that imports MySQLdb or
|
||||
_mysql will unwittingly actually use
|
||||
_mysql will unwittingly actually use
|
||||
"""
|
||||
sys.modules["MySQLdb"] = sys.modules["_mysql"] = sys.modules["pymysql"]
|
||||
|
||||
@@ -121,12 +121,11 @@ __all__ = [
|
||||
'InterfaceError', 'InternalError', 'MySQLError', 'NULL', 'NUMBER',
|
||||
'NotSupportedError', 'DBAPISet', 'OperationalError', 'ProgrammingError',
|
||||
'ROWID', 'STRING', 'TIME', 'TIMESTAMP', 'Warning', 'apilevel', 'connect',
|
||||
'connections', 'constants', 'converters', 'cursors', 'debug', 'escape',
|
||||
'connections', 'constants', 'converters', 'cursors',
|
||||
'escape_dict', 'escape_sequence', 'escape_string', 'get_client_info',
|
||||
'paramstyle', 'string_literal', 'threadsafety', 'version_info',
|
||||
'paramstyle', 'threadsafety', 'version_info',
|
||||
|
||||
"install_as_MySQLdb",
|
||||
|
||||
"NULL","__version__",
|
||||
]
|
||||
|
||||
|
||||
@@ -172,4 +172,3 @@ def charset_by_name(name):
|
||||
def charset_by_id(id):
|
||||
return _charsets.by_id(id)
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,12 @@ try:
|
||||
except ImportError:
|
||||
import StringIO
|
||||
|
||||
try:
|
||||
import getpass
|
||||
DEFAULT_USER = getpass.getuser()
|
||||
except ImportError:
|
||||
DEFAULT_USER = None
|
||||
|
||||
from charset import MBLENGTH, charset_by_name, charset_by_id
|
||||
from cursors import Cursor
|
||||
from constants import FIELD_TYPE, FLAG
|
||||
@@ -50,22 +56,24 @@ UNSIGNED_INT24_LENGTH = 3
|
||||
UNSIGNED_INT64_LENGTH = 8
|
||||
|
||||
DEFAULT_CHARSET = 'latin1'
|
||||
MAX_PACKET_LENGTH = 256*256*256-1
|
||||
|
||||
|
||||
def dump_packet(data):
|
||||
|
||||
|
||||
def is_ascii(data):
|
||||
if byte2int(data) >= 65 and byte2int(data) <= 122: #data.isalnum():
|
||||
return data
|
||||
return '.'
|
||||
print "packet length %d" % len(data)
|
||||
print "method call[1]: %s" % sys._getframe(1).f_code.co_name
|
||||
print "method call[2]: %s" % sys._getframe(2).f_code.co_name
|
||||
print "method call[3]: %s" % sys._getframe(3).f_code.co_name
|
||||
print "method call[4]: %s" % sys._getframe(4).f_code.co_name
|
||||
print "method call[5]: %s" % sys._getframe(5).f_code.co_name
|
||||
print "-" * 88
|
||||
|
||||
try:
|
||||
print "packet length %d" % len(data)
|
||||
print "method call[1]: %s" % sys._getframe(1).f_code.co_name
|
||||
print "method call[2]: %s" % sys._getframe(2).f_code.co_name
|
||||
print "method call[3]: %s" % sys._getframe(3).f_code.co_name
|
||||
print "method call[4]: %s" % sys._getframe(4).f_code.co_name
|
||||
print "method call[5]: %s" % sys._getframe(5).f_code.co_name
|
||||
print "-" * 88
|
||||
except ValueError: pass
|
||||
dump_data = [data[i:i+16] for i in xrange(len(data)) if i%16 == 0]
|
||||
for d in dump_data:
|
||||
print ' '.join(map(lambda x:"%02X" % byte2int(x), d)) + \
|
||||
@@ -153,18 +161,28 @@ def unpack_uint16(n):
|
||||
# TODO: stop using bit-shifting in these functions...
|
||||
# TODO: rename to "uint" to make it clear they're unsigned...
|
||||
def unpack_int24(n):
|
||||
return struct.unpack('B',n[0])[0] + (struct.unpack('B', n[1])[0] << 8) +\
|
||||
(struct.unpack('B',n[2])[0] << 16)
|
||||
try:
|
||||
return struct.unpack('B',n[0])[0] + (struct.unpack('B', n[1])[0] << 8) +\
|
||||
(struct.unpack('B',n[2])[0] << 16)
|
||||
except TypeError:
|
||||
return n[0] + (n[1] << 8) + (n[2] << 16)
|
||||
|
||||
def unpack_int32(n):
|
||||
return struct.unpack('B',n[0])[0] + (struct.unpack('B', n[1])[0] << 8) +\
|
||||
(struct.unpack('B',n[2])[0] << 16) + (struct.unpack('B', n[3])[0] << 24)
|
||||
try:
|
||||
return struct.unpack('B',n[0])[0] + (struct.unpack('B', n[1])[0] << 8) +\
|
||||
(struct.unpack('B',n[2])[0] << 16) + (struct.unpack('B', n[3])[0] << 24)
|
||||
except TypeError:
|
||||
return n[0] + (n[1] << 8) + (n[2] << 16) + (n[3] << 24)
|
||||
|
||||
def unpack_int64(n):
|
||||
return struct.unpack('B',n[0])[0] + (struct.unpack('B', n[1])[0]<<8) +\
|
||||
(struct.unpack('B',n[2])[0] << 16) + (struct.unpack('B',n[3])[0]<<24)+\
|
||||
(struct.unpack('B',n[4])[0] << 32) + (struct.unpack('B',n[5])[0]<<40)+\
|
||||
(struct.unpack('B',n[6])[0] << 48) + (struct.unpack('B',n[7])[0]<<56)
|
||||
try:
|
||||
return struct.unpack('B',n[0])[0] + (struct.unpack('B', n[1])[0]<<8) +\
|
||||
(struct.unpack('B',n[2])[0] << 16) + (struct.unpack('B',n[3])[0]<<24)+\
|
||||
(struct.unpack('B',n[4])[0] << 32) + (struct.unpack('B',n[5])[0]<<40)+\
|
||||
(struct.unpack('B',n[6])[0] << 48) + (struct.unpack('B',n[7])[0]<<56)
|
||||
except TypeError:
|
||||
return n[0] + (n[1] << 8) + (n[2] << 16) + (n[3] << 24) +\
|
||||
(n[4] << 32) + (n[5] << 40) + (n[6] << 48) + (n[7] << 56)
|
||||
|
||||
def defaulterrorhandler(connection, cursor, errorclass, errorvalue):
|
||||
err = errorclass, errorvalue
|
||||
@@ -189,19 +207,16 @@ class MysqlPacket(object):
|
||||
from the network socket, removes packet header and provides an interface
|
||||
for reading/parsing the packet results."""
|
||||
|
||||
def __init__(self, socket):
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
self.__position = 0
|
||||
self.__recv_packet(socket)
|
||||
del socket
|
||||
self.__recv_packet()
|
||||
|
||||
def __recv_packet(self, socket):
|
||||
def __recv_packet(self):
|
||||
"""Parse the packet header and read entire packet payload into buffer."""
|
||||
packet_header = socket.recv(4)
|
||||
while len(packet_header) < 4:
|
||||
d = socket.recv(4 - len(packet_header))
|
||||
if len(d) == 0:
|
||||
raise OperationalError(2013, "Lost connection to MySQL server during query")
|
||||
packet_header += d
|
||||
packet_header = self.connection.rfile.read(4)
|
||||
if len(packet_header) < 4:
|
||||
raise OperationalError(2013, "Lost connection to MySQL server during query")
|
||||
|
||||
if DEBUG: dump_packet(packet_header)
|
||||
packet_length_bin = packet_header[:3]
|
||||
@@ -210,16 +225,11 @@ class MysqlPacket(object):
|
||||
|
||||
bin_length = packet_length_bin + int2byte(0) # pad little-endian number
|
||||
bytes_to_read = struct.unpack('<I', bin_length)[0]
|
||||
|
||||
payload_buff = [] # this is faster than cStringIO
|
||||
while bytes_to_read > 0:
|
||||
recv_data = socket.recv(bytes_to_read)
|
||||
if len(recv_data) == 0:
|
||||
raise OperationalError(2013, "Lost connection to MySQL server during query")
|
||||
if DEBUG: dump_packet(recv_data)
|
||||
payload_buff.append(recv_data)
|
||||
bytes_to_read -= len(recv_data)
|
||||
self.__data = join_bytes(payload_buff)
|
||||
recv_data = self.connection.rfile.read(bytes_to_read)
|
||||
if len(recv_data) < bytes_to_read:
|
||||
raise OperationalError(2013, "Lost connection to MySQL server during query")
|
||||
if DEBUG: dump_packet(recv_data)
|
||||
self.__data = recv_data
|
||||
|
||||
def packet_number(self): return self.__packet_number
|
||||
|
||||
@@ -354,7 +364,7 @@ class FieldDescriptorPacket(MysqlPacket):
|
||||
self.db = self.read_length_coded_string()
|
||||
self.table_name = self.read_length_coded_string()
|
||||
self.org_table = self.read_length_coded_string()
|
||||
self.name = self.read_length_coded_string()
|
||||
self.name = self.read_length_coded_string().decode(self.connection.charset)
|
||||
self.org_name = self.read_length_coded_string()
|
||||
self.advance(1) # non-null filler
|
||||
self.charsetnr = struct.unpack('<H', self.read(2))[0]
|
||||
@@ -396,6 +406,58 @@ class FieldDescriptorPacket(MysqlPacket):
|
||||
% (self.__class__, self.db, self.table_name, self.name,
|
||||
self.type_code))
|
||||
|
||||
class OKPacketWrapper(object):
|
||||
"""
|
||||
OK Packet Wrapper. It uses an existing packet object, and wraps
|
||||
around it, exposing useful variables while still providing access
|
||||
to the original packet objects variables and methods.
|
||||
"""
|
||||
|
||||
def __init__(self, from_packet):
|
||||
if not from_packet.is_ok_packet():
|
||||
raise ValueError('Cannot create ' + str(self.__class__.__name__)
|
||||
+ ' object from invalid packet type')
|
||||
|
||||
self.packet = from_packet
|
||||
self.packet.advance(1)
|
||||
|
||||
self.affected_rows = self.packet.read_length_coded_binary()
|
||||
self.insert_id = self.packet.read_length_coded_binary()
|
||||
self.server_status = struct.unpack('<H', self.packet.read(2))[0]
|
||||
self.warning_count = struct.unpack('<H', self.packet.read(2))[0]
|
||||
self.message = self.packet.read_all()
|
||||
|
||||
def __getattr__(self, key):
|
||||
if hasattr(self.packet, key):
|
||||
return getattr(self.packet, key)
|
||||
|
||||
raise AttributeError(str(self.__class__)
|
||||
+ " instance has no attribute '" + key + "'")
|
||||
|
||||
class EOFPacketWrapper(object):
|
||||
"""
|
||||
EOF Packet Wrapper. It uses an existing packet object, and wraps
|
||||
around it, exposing useful variables while still providing access
|
||||
to the original packet objects variables and methods.
|
||||
"""
|
||||
|
||||
def __init__(self, from_packet):
|
||||
if not from_packet.is_eof_packet():
|
||||
raise ValueError('Cannot create ' + str(self.__class__.__name__)
|
||||
+ ' object from invalid packet type')
|
||||
|
||||
self.packet = from_packet
|
||||
self.warning_count = self.packet.read(2)
|
||||
server_status = struct.unpack('<h', self.packet.read(2))[0]
|
||||
self.has_next = (server_status
|
||||
& SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if hasattr(self.packet, key):
|
||||
return getattr(self.packet, key)
|
||||
|
||||
raise AttributeError(str(self.__class__)
|
||||
+ " instance has no attribute '" + key + "'")
|
||||
|
||||
class Connection(object):
|
||||
"""
|
||||
@@ -482,12 +544,12 @@ class Connection(object):
|
||||
host = _config("host", host)
|
||||
db = _config("db",db)
|
||||
unix_socket = _config("socket",unix_socket)
|
||||
port = _config("port", port)
|
||||
port = int(_config("port", port))
|
||||
charset = _config("default-character-set", charset)
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.user = user
|
||||
self.user = user or DEFAULT_USER
|
||||
self.password = passwd
|
||||
self.db = db
|
||||
self.unix_socket = unix_socket
|
||||
@@ -498,7 +560,7 @@ class Connection(object):
|
||||
self.charset = DEFAULT_CHARSET
|
||||
self.use_unicode = False
|
||||
|
||||
if use_unicode:
|
||||
if use_unicode is not None:
|
||||
self.use_unicode = use_unicode
|
||||
|
||||
client_flag |= CAPABILITIES
|
||||
@@ -512,14 +574,15 @@ class Connection(object):
|
||||
|
||||
self._connect()
|
||||
|
||||
self._result = None
|
||||
self._affected_rows = 0
|
||||
self.host_info = "Not connected"
|
||||
|
||||
self.messages = []
|
||||
self.set_charset(charset)
|
||||
self.encoders = encoders
|
||||
self.decoders = conv
|
||||
|
||||
self._affected_rows = 0
|
||||
self.host_info = "Not connected"
|
||||
|
||||
self.autocommit(False)
|
||||
|
||||
if sql_mode is not None:
|
||||
@@ -537,10 +600,16 @@ class Connection(object):
|
||||
|
||||
def close(self):
|
||||
''' Send the quit message and close the socket '''
|
||||
if self.socket is None:
|
||||
raise Error("Already closed")
|
||||
send_data = struct.pack('<i',1) + int2byte(COM_QUIT)
|
||||
self.socket.send(send_data)
|
||||
self.wfile.write(send_data)
|
||||
self.wfile.close()
|
||||
self.rfile.close()
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
self.rfile = None
|
||||
self.wfile = None
|
||||
|
||||
def autocommit(self, value):
|
||||
''' Set whether or not to commit after every execute() '''
|
||||
@@ -578,8 +647,10 @@ class Connection(object):
|
||||
''' Alias for escape() '''
|
||||
return escape_item(obj, self.charset)
|
||||
|
||||
def cursor(self):
|
||||
def cursor(self, cursor=None):
|
||||
''' Create a new cursor to execute queries with '''
|
||||
if cursor:
|
||||
return cursor(self)
|
||||
return self.cursorclass(self)
|
||||
|
||||
def __enter__(self):
|
||||
@@ -594,9 +665,11 @@ class Connection(object):
|
||||
self.commit()
|
||||
|
||||
# The following methods are INTERNAL USE ONLY (called from Cursor)
|
||||
def query(self, sql):
|
||||
def query(self, sql, unbuffered=False):
|
||||
if DEBUG:
|
||||
print "sending query: %s" % sql
|
||||
self._execute_command(COM_QUERY, sql)
|
||||
self._affected_rows = self._read_query_result()
|
||||
self._affected_rows = self._read_query_result(unbuffered=unbuffered)
|
||||
return self._affected_rows
|
||||
|
||||
def next_result(self):
|
||||
@@ -621,6 +694,8 @@ class Connection(object):
|
||||
''' Check if the server is alive '''
|
||||
try:
|
||||
self._execute_command(COM_PING, "")
|
||||
pkt = self.read_packet()
|
||||
return pkt.is_ok_packet()
|
||||
except:
|
||||
if reconnect:
|
||||
self._connect()
|
||||
@@ -630,9 +705,6 @@ class Connection(object):
|
||||
self.errorhandler(None, exc, value)
|
||||
return
|
||||
|
||||
pkt = self.read_packet()
|
||||
return pkt.is_ok_packet()
|
||||
|
||||
def set_charset(self, charset):
|
||||
try:
|
||||
if charset:
|
||||
@@ -663,64 +735,67 @@ class Connection(object):
|
||||
self.host_info = "socket %s:%d" % (self.host, self.port)
|
||||
if DEBUG: print 'connected using socket'
|
||||
self.socket = sock
|
||||
self.rfile = self.socket.makefile("rb")
|
||||
self.wfile = self.socket.makefile("wb")
|
||||
self._get_server_information()
|
||||
self._request_authentication()
|
||||
except socket.error, e:
|
||||
raise OperationalError(2003, "Can't connect to MySQL server on %r (%d)" % (self.host, e.args[0]))
|
||||
raise OperationalError(2003, "Can't connect to MySQL server on %r (%s)" % (self.host, e.args[0]))
|
||||
|
||||
def read_packet(self, packet_type=MysqlPacket):
|
||||
"""Read an entire "mysql packet" in its entirety from the network
|
||||
and return a MysqlPacket type that represents the results."""
|
||||
|
||||
# TODO: is socket.recv(small_number) significantly slower than
|
||||
# socket.recv(large_number)? if so, maybe we should buffer
|
||||
# the socket.recv() (though that obviously makes memory management
|
||||
# more complicated.
|
||||
packet = packet_type(self.socket)
|
||||
packet = packet_type(self)
|
||||
packet.check_error()
|
||||
return packet
|
||||
|
||||
def _read_query_result(self):
|
||||
result = MySQLResult(self)
|
||||
result.read()
|
||||
def _read_query_result(self, unbuffered=False):
|
||||
if unbuffered:
|
||||
try:
|
||||
result = MySQLResult(self)
|
||||
result.init_unbuffered_query()
|
||||
except:
|
||||
result.unbuffered_active = False
|
||||
raise
|
||||
else:
|
||||
result = MySQLResult(self)
|
||||
result.read()
|
||||
self._result = result
|
||||
return result.affected_rows
|
||||
|
||||
def insert_id(self):
|
||||
if self._result:
|
||||
return self._result.insert_id
|
||||
else:
|
||||
return 0
|
||||
|
||||
def _send_command(self, command, sql):
|
||||
#send_data = struct.pack('<i', len(sql) + 1) + command + sql
|
||||
# could probably be more efficient, at least it's correct
|
||||
if not self.socket:
|
||||
self.errorhandler(None, InterfaceError, "(0, '')")
|
||||
|
||||
# If the last query was unbuffered, make sure it finishes before
|
||||
# sending new commands
|
||||
if self._result is not None and self._result.unbuffered_active:
|
||||
self._result._finish_unbuffered_query()
|
||||
|
||||
if isinstance(sql, unicode):
|
||||
sql = sql.encode(self.charset)
|
||||
|
||||
buf = int2byte(command) + sql
|
||||
pckt_no = 0
|
||||
while len(buf) >= MAX_PACKET_LENGTH:
|
||||
header = struct.pack('<i', MAX_PACKET_LENGTH)[:-1]+int2byte(pckt_no)
|
||||
send_data = header + buf[:MAX_PACKET_LENGTH]
|
||||
self.socket.send(send_data)
|
||||
if DEBUG: dump_packet(send_data)
|
||||
buf = buf[MAX_PACKET_LENGTH:]
|
||||
pckt_no += 1
|
||||
header = struct.pack('<i', len(buf))[:-1]+int2byte(pckt_no)
|
||||
self.socket.send(header+buf)
|
||||
|
||||
|
||||
#sock = self.socket
|
||||
#sock.send(send_data)
|
||||
|
||||
#
|
||||
prelude = struct.pack('<i', len(sql)+1) + int2byte(command)
|
||||
self.wfile.write(prelude + sql)
|
||||
self.wfile.flush()
|
||||
if DEBUG: dump_packet(prelude + sql)
|
||||
|
||||
def _execute_command(self, command, sql):
|
||||
self._send_command(command, sql)
|
||||
|
||||
|
||||
def _request_authentication(self):
|
||||
self._send_authentication()
|
||||
|
||||
def _send_authentication(self):
|
||||
sock = self.socket
|
||||
self.client_flag |= CAPABILITIES
|
||||
if self.server_version.startswith('5'):
|
||||
self.client_flag |= MULTI_RESULTS
|
||||
@@ -742,12 +817,15 @@ class Connection(object):
|
||||
|
||||
if DEBUG: dump_packet(data)
|
||||
|
||||
sock.send(data)
|
||||
sock = self.socket = ssl.wrap_socket(sock, keyfile=self.key,
|
||||
certfile=self.cert,
|
||||
ssl_version=ssl.PROTOCOL_TLSv1,
|
||||
cert_reqs=ssl.CERT_REQUIRED,
|
||||
ca_certs=self.ca)
|
||||
self.wfile.write(data)
|
||||
self.wfile.flush()
|
||||
self.socket = ssl.wrap_self.socketet(self.socket, keyfile=self.key,
|
||||
certfile=self.cert,
|
||||
ssl_version=ssl.PROTOCOL_TLSv1,
|
||||
cert_reqs=ssl.CERT_REQUIRED,
|
||||
ca_certs=self.ca)
|
||||
self.rfile = self.socket.makefile("rb")
|
||||
self.wfile = self.socket.makefile("wb")
|
||||
|
||||
data = data_init + self.user+int2byte(0) + _scramble(self.password.encode(self.charset), self.salt)
|
||||
|
||||
@@ -760,9 +838,10 @@ class Connection(object):
|
||||
|
||||
if DEBUG: dump_packet(data)
|
||||
|
||||
sock.send(data)
|
||||
self.wfile.write(data)
|
||||
self.wfile.flush()
|
||||
|
||||
auth_packet = MysqlPacket(sock)
|
||||
auth_packet = MysqlPacket(self)
|
||||
auth_packet.check_error()
|
||||
if DEBUG: auth_packet.dump()
|
||||
|
||||
@@ -776,8 +855,9 @@ class Connection(object):
|
||||
data = _scramble_323(self.password.encode(self.charset), self.salt.encode(self.charset)) + int2byte(0)
|
||||
data = pack_int24(len(data)) + int2byte(next_packet) + data
|
||||
|
||||
sock.send(data)
|
||||
auth_packet = MysqlPacket(sock)
|
||||
self.wfile.write(data)
|
||||
self.wfile.flush()
|
||||
auth_packet = MysqlPacket(self)
|
||||
auth_packet.check_error()
|
||||
if DEBUG: auth_packet.dump()
|
||||
|
||||
@@ -796,9 +876,8 @@ class Connection(object):
|
||||
return self.protocol_version
|
||||
|
||||
def _get_server_information(self):
|
||||
sock = self.socket
|
||||
i = 0
|
||||
packet = MysqlPacket(sock)
|
||||
packet = MysqlPacket(self)
|
||||
data = packet.get_all_data()
|
||||
|
||||
if DEBUG: dump_packet(data)
|
||||
@@ -862,6 +941,11 @@ class MySQLResult(object):
|
||||
self.description = None
|
||||
self.rows = None
|
||||
self.has_next = None
|
||||
self.unbuffered_active = False
|
||||
|
||||
def __del__(self):
|
||||
if self.unbuffered_active:
|
||||
self._finish_unbuffered_query()
|
||||
|
||||
def read(self):
|
||||
self.first_packet = self.connection.read_packet()
|
||||
@@ -872,19 +956,78 @@ class MySQLResult(object):
|
||||
else:
|
||||
self._read_result_packet()
|
||||
|
||||
def init_unbuffered_query(self):
|
||||
self.unbuffered_active = True
|
||||
self.first_packet = self.connection.read_packet()
|
||||
|
||||
if self.first_packet.is_ok_packet():
|
||||
self._read_ok_packet()
|
||||
self.unbuffered_active = False
|
||||
else:
|
||||
self.field_count = byte2int(self.first_packet.read(1))
|
||||
self._get_descriptions()
|
||||
|
||||
# Apparently, MySQLdb picks this number because it's the maximum
|
||||
# value of a 64bit unsigned integer. Since we're emulating MySQLdb,
|
||||
# we set it to this instead of None, which would be preferred.
|
||||
self.affected_rows = 18446744073709551615
|
||||
|
||||
def _read_ok_packet(self):
|
||||
self.first_packet.advance(1) # field_count (always '0')
|
||||
self.affected_rows = self.first_packet.read_length_coded_binary()
|
||||
self.insert_id = self.first_packet.read_length_coded_binary()
|
||||
self.server_status = struct.unpack('<H', self.first_packet.read(2))[0]
|
||||
self.warning_count = struct.unpack('<H', self.first_packet.read(2))[0]
|
||||
self.message = self.first_packet.read_all()
|
||||
ok_packet = OKPacketWrapper(self.first_packet)
|
||||
self.affected_rows = ok_packet.affected_rows
|
||||
self.insert_id = ok_packet.insert_id
|
||||
self.server_status = ok_packet.server_status
|
||||
self.warning_count = ok_packet.warning_count
|
||||
self.message = ok_packet.message
|
||||
|
||||
def _check_packet_is_eof(self, packet):
|
||||
if packet.is_eof_packet():
|
||||
eof_packet = EOFPacketWrapper(packet)
|
||||
self.warning_count = eof_packet.warning_count
|
||||
self.has_next = eof_packet.has_next
|
||||
return True
|
||||
return False
|
||||
|
||||
def _read_result_packet(self):
|
||||
self.field_count = byte2int(self.first_packet.read(1))
|
||||
self._get_descriptions()
|
||||
self._read_rowdata_packet()
|
||||
|
||||
def _read_rowdata_packet_unbuffered(self):
|
||||
# Check if in an active query
|
||||
if self.unbuffered_active == False: return
|
||||
|
||||
# EOF
|
||||
packet = self.connection.read_packet()
|
||||
if self._check_packet_is_eof(packet):
|
||||
self.unbuffered_active = False
|
||||
self.rows = None
|
||||
return
|
||||
|
||||
row = []
|
||||
for field in self.fields:
|
||||
data = packet.read_length_coded_string()
|
||||
converted = None
|
||||
if field.type_code in self.connection.decoders:
|
||||
converter = self.connection.decoders[field.type_code]
|
||||
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
|
||||
if data != None:
|
||||
converted = converter(self.connection, field, data)
|
||||
row.append(converted)
|
||||
|
||||
self.affected_rows = 1
|
||||
self.rows = tuple((row))
|
||||
if DEBUG: self.rows
|
||||
|
||||
def _finish_unbuffered_query(self):
|
||||
# After much reading on the MySQL protocol, it appears that there is,
|
||||
# in fact, no way to stop MySQL from sending all the data after
|
||||
# executing a query, so we just spin, and wait for an EOF packet.
|
||||
while self.unbuffered_active:
|
||||
packet = self.connection.read_packet()
|
||||
if self._check_packet_is_eof(packet):
|
||||
self.unbuffered_active = False
|
||||
|
||||
# TODO: implement this as an iteratable so that it is more
|
||||
# memory efficient and lower-latency to client...
|
||||
def _read_rowdata_packet(self):
|
||||
@@ -892,24 +1035,18 @@ class MySQLResult(object):
|
||||
rows = []
|
||||
while True:
|
||||
packet = self.connection.read_packet()
|
||||
if packet.is_eof_packet():
|
||||
self.warning_count = packet.read(2)
|
||||
server_status = struct.unpack('<h', packet.read(2))[0]
|
||||
self.has_next = (server_status
|
||||
& SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS)
|
||||
if self._check_packet_is_eof(packet):
|
||||
break
|
||||
|
||||
row = []
|
||||
for field in self.fields:
|
||||
data = packet.read_length_coded_string()
|
||||
converted = None
|
||||
if field.type_code in self.connection.decoders:
|
||||
converter = self.connection.decoders[field.type_code]
|
||||
|
||||
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
|
||||
data = packet.read_length_coded_string()
|
||||
converted = None
|
||||
if data != None:
|
||||
converted = converter(self.connection, field, data)
|
||||
|
||||
row.append(converted)
|
||||
|
||||
rows.append(tuple(row))
|
||||
@@ -930,4 +1067,3 @@ class MySQLResult(object):
|
||||
eof_packet = self.connection.read_packet()
|
||||
assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'
|
||||
self.description = tuple(description)
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import re
|
||||
import datetime
|
||||
import time
|
||||
import sys
|
||||
|
||||
from constants import FIELD_TYPE, FLAG
|
||||
from charset import charset_by_id
|
||||
|
||||
PYTHON3 = sys.version_info[0] > 2
|
||||
|
||||
try:
|
||||
set
|
||||
except NameError:
|
||||
@@ -22,12 +25,12 @@ def escape_item(val, charset):
|
||||
return escape_sequence(val, charset)
|
||||
if type(val) is dict:
|
||||
return escape_dict(val, charset)
|
||||
if hasattr(val, "decode") and not isinstance(val, unicode):
|
||||
if PYTHON3 and hasattr(val, "decode") and not isinstance(val, unicode):
|
||||
# deal with py3k bytes
|
||||
val = val.decode(charset)
|
||||
encoder = encoders[type(val)]
|
||||
val = encoder(val)
|
||||
if type(val) is str:
|
||||
if type(val) in [str, int]:
|
||||
return val
|
||||
val = val.encode(charset)
|
||||
return val
|
||||
@@ -44,7 +47,7 @@ def escape_sequence(val, charset):
|
||||
for item in val:
|
||||
quoted = escape_item(item, charset)
|
||||
n.append(quoted)
|
||||
return tuple(n)
|
||||
return "(" + ",".join(n) + ")"
|
||||
|
||||
def escape_set(val, charset):
|
||||
val = map(lambda x: escape_item(x, charset), val)
|
||||
@@ -56,7 +59,10 @@ def escape_bool(value):
|
||||
def escape_object(value):
|
||||
return str(value)
|
||||
|
||||
escape_int = escape_long = escape_object
|
||||
def escape_int(value):
|
||||
return value
|
||||
|
||||
escape_long = escape_object
|
||||
|
||||
def escape_float(value):
|
||||
return ('%.15g' % value)
|
||||
@@ -142,16 +148,19 @@ def convert_timedelta(connection, field, obj):
|
||||
can accept values as (+|-)DD HH:MM:SS. The latter format will not
|
||||
be parsed correctly by this function.
|
||||
"""
|
||||
from math import modf
|
||||
try:
|
||||
microseconds = 0
|
||||
if not isinstance(obj, unicode):
|
||||
obj = obj.decode(connection.charset)
|
||||
hours, minutes, seconds = tuple([int(x) for x in obj.split(':')])
|
||||
if "." in obj:
|
||||
(obj, tail) = obj.split('.')
|
||||
microseconds = int(tail)
|
||||
hours, minutes, seconds = obj.split(':')
|
||||
tdelta = datetime.timedelta(
|
||||
hours = int(hours),
|
||||
minutes = int(minutes),
|
||||
seconds = int(seconds),
|
||||
microseconds = int(modf(float(seconds))[0]*1000000),
|
||||
microseconds = microseconds
|
||||
)
|
||||
return tdelta
|
||||
except ValueError:
|
||||
@@ -179,12 +188,14 @@ def convert_time(connection, field, obj):
|
||||
to be treated as time-of-day and not a time offset, then you can
|
||||
use set this function as the converter for FIELD_TYPE.TIME.
|
||||
"""
|
||||
from math import modf
|
||||
try:
|
||||
hour, minute, second = obj.split(':')
|
||||
return datetime.time(hour=int(hour), minute=int(minute),
|
||||
second=int(second),
|
||||
microsecond=int(modf(float(second))[0]*1000000))
|
||||
microseconds = 0
|
||||
if "." in obj:
|
||||
(obj, tail) = obj.split('.')
|
||||
microseconds = int(tail)
|
||||
hours, minutes, seconds = obj.split(':')
|
||||
return datetime.time(hour=int(hours), minute=int(minutes),
|
||||
second=int(seconds), microsecond=microseconds)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@@ -267,8 +278,6 @@ def convert_characters(connection, field, data):
|
||||
elif connection.charset != field_charset:
|
||||
data = data.decode(field_charset)
|
||||
data = data.encode(connection.charset)
|
||||
else:
|
||||
data = data.decode(connection.charset)
|
||||
return data
|
||||
|
||||
def convert_int(connection, field, data):
|
||||
@@ -334,6 +343,7 @@ try:
|
||||
# python version > 2.3
|
||||
from decimal import Decimal
|
||||
def convert_decimal(connection, field, data):
|
||||
data = data.decode(connection.charset)
|
||||
return Decimal(data)
|
||||
decoders[FIELD_TYPE.DECIMAL] = convert_decimal
|
||||
decoders[FIELD_TYPE.NEWDECIMAL] = convert_decimal
|
||||
@@ -344,4 +354,3 @@ try:
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -92,12 +92,21 @@ class Cursor(object):
|
||||
|
||||
# TODO: make sure that conn.escape is correct
|
||||
|
||||
if args is not None:
|
||||
query = query % conn.escape(args)
|
||||
|
||||
if isinstance(query, unicode):
|
||||
query = query.encode(charset)
|
||||
|
||||
if args is not None:
|
||||
if isinstance(args, tuple) or isinstance(args, list):
|
||||
escaped_args = tuple(conn.escape(arg) for arg in args)
|
||||
elif isinstance(args, dict):
|
||||
escaped_args = dict((key, conn.escape(val)) for (key, val) in args.items())
|
||||
else:
|
||||
#If it's not a dictionary let's try escaping it anyways.
|
||||
#Worst case it will throw a Value error
|
||||
escaped_args = conn.escape(args)
|
||||
|
||||
query = query % escaped_args
|
||||
|
||||
result = 0
|
||||
try:
|
||||
result = self._query(query)
|
||||
@@ -113,12 +122,12 @@ class Cursor(object):
|
||||
def executemany(self, query, args):
|
||||
''' Run several data against one query '''
|
||||
del self.messages[:]
|
||||
conn = self._get_db()
|
||||
#conn = self._get_db()
|
||||
if not args:
|
||||
return
|
||||
charset = conn.charset
|
||||
if isinstance(query, unicode):
|
||||
query = query.encode(charset)
|
||||
#charset = conn.charset
|
||||
#if isinstance(query, unicode):
|
||||
# query = query.encode(charset)
|
||||
|
||||
self.rowcount = sum([ self.execute(query, arg) for arg in args ])
|
||||
return self.rowcount
|
||||
@@ -231,12 +240,9 @@ class Cursor(object):
|
||||
self.lastrowid = conn._result.insert_id
|
||||
self._rows = conn._result.rows
|
||||
self._has_next = conn._result.has_next
|
||||
conn._result = None
|
||||
|
||||
def __iter__(self):
|
||||
self._check_executed()
|
||||
result = self.rownumber and self._rows[self.rownumber:] or self._rows
|
||||
return iter(result)
|
||||
return iter(self.fetchone, None)
|
||||
|
||||
Warning = Warning
|
||||
Error = Error
|
||||
@@ -249,3 +255,156 @@ class Cursor(object):
|
||||
ProgrammingError = ProgrammingError
|
||||
NotSupportedError = NotSupportedError
|
||||
|
||||
class DictCursor(Cursor):
|
||||
"""A cursor which returns results as a dictionary"""
|
||||
|
||||
def execute(self, query, args=None):
|
||||
result = super(DictCursor, self).execute(query, args)
|
||||
if self.description:
|
||||
self._fields = [ field[0] for field in self.description ]
|
||||
return result
|
||||
|
||||
def fetchone(self):
|
||||
''' Fetch the next row '''
|
||||
self._check_executed()
|
||||
if self._rows is None or self.rownumber >= len(self._rows):
|
||||
return None
|
||||
result = dict(zip(self._fields, self._rows[self.rownumber]))
|
||||
self.rownumber += 1
|
||||
return result
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
''' Fetch several rows '''
|
||||
self._check_executed()
|
||||
if self._rows is None:
|
||||
return None
|
||||
end = self.rownumber + (size or self.arraysize)
|
||||
result = [ dict(zip(self._fields, r)) for r in self._rows[self.rownumber:end] ]
|
||||
self.rownumber = min(end, len(self._rows))
|
||||
return tuple(result)
|
||||
|
||||
def fetchall(self):
|
||||
''' Fetch all the rows '''
|
||||
self._check_executed()
|
||||
if self._rows is None:
|
||||
return None
|
||||
if self.rownumber:
|
||||
result = [ dict(zip(self._fields, r)) for r in self._rows[self.rownumber:] ]
|
||||
else:
|
||||
result = [ dict(zip(self._fields, r)) for r in self._rows ]
|
||||
self.rownumber = len(self._rows)
|
||||
return tuple(result)
|
||||
|
||||
class SSCursor(Cursor):
|
||||
"""
|
||||
Unbuffered Cursor, mainly useful for queries that return a lot of data,
|
||||
or for connections to remote servers over a slow network.
|
||||
|
||||
Instead of copying every row of data into a buffer, this will fetch
|
||||
rows as needed. The upside of this, is the client uses much less memory,
|
||||
and rows are returned much faster when traveling over a slow network,
|
||||
or if the result set is very big.
|
||||
|
||||
There are limitations, though. The MySQL protocol doesn't support
|
||||
returning the total number of rows, so the only way to tell how many rows
|
||||
there are is to iterate over every row returned. Also, it currently isn't
|
||||
possible to scroll backwards, as only the current row is held in memory.
|
||||
"""
|
||||
|
||||
def close(self):
|
||||
conn = self._get_db()
|
||||
conn._result._finish_unbuffered_query()
|
||||
|
||||
try:
|
||||
if self._has_next:
|
||||
while self.nextset(): pass
|
||||
except: pass
|
||||
|
||||
def _query(self, q):
|
||||
conn = self._get_db()
|
||||
self._last_executed = q
|
||||
conn.query(q, unbuffered=True)
|
||||
self._do_get_result()
|
||||
return self.rowcount
|
||||
|
||||
def read_next(self):
|
||||
""" Read next row """
|
||||
|
||||
conn = self._get_db()
|
||||
conn._result._read_rowdata_packet_unbuffered()
|
||||
return conn._result.rows
|
||||
|
||||
def fetchone(self):
|
||||
""" Fetch next row """
|
||||
|
||||
self._check_executed()
|
||||
row = self.read_next()
|
||||
if row is None:
|
||||
return None
|
||||
self.rownumber += 1
|
||||
return row
|
||||
|
||||
def fetchall(self):
|
||||
"""
|
||||
Fetch all, as per MySQLdb. Pretty useless for large queries, as
|
||||
it is buffered. See fetchall_unbuffered(), if you want an unbuffered
|
||||
generator version of this method.
|
||||
"""
|
||||
|
||||
rows = []
|
||||
while True:
|
||||
row = self.fetchone()
|
||||
if row is None:
|
||||
break
|
||||
rows.append(row)
|
||||
return tuple(rows)
|
||||
|
||||
def fetchall_unbuffered(self):
|
||||
"""
|
||||
Fetch all, implemented as a generator, which isn't to standard,
|
||||
however, it doesn't make sense to return everything in a list, as that
|
||||
would use ridiculous memory for large result sets.
|
||||
"""
|
||||
|
||||
row = self.fetchone()
|
||||
while row is not None:
|
||||
yield row
|
||||
row = self.fetchone()
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
""" Fetch many """
|
||||
|
||||
self._check_executed()
|
||||
if size is None:
|
||||
size = self.arraysize
|
||||
|
||||
rows = []
|
||||
for i in range(0, size):
|
||||
row = self.read_next()
|
||||
if row is None:
|
||||
break
|
||||
rows.append(row)
|
||||
self.rownumber += 1
|
||||
return tuple(rows)
|
||||
|
||||
def scroll(self, value, mode='relative'):
|
||||
self._check_executed()
|
||||
if not mode == 'relative' and not mode == 'absolute':
|
||||
self.errorhandler(self, ProgrammingError,
|
||||
"unknown scroll mode %s" % mode)
|
||||
|
||||
if mode == 'relative':
|
||||
if value < 0:
|
||||
self.errorhandler(self, NotSupportedError,
|
||||
"Backwards scrolling not supported by this cursor")
|
||||
|
||||
for i in range(0, value): self.read_next()
|
||||
self.rownumber += value
|
||||
else:
|
||||
if value < self.rownumber:
|
||||
self.errorhandler(self, NotSupportedError,
|
||||
"Backwards scrolling not supported by this cursor")
|
||||
|
||||
end = value - self.rownumber
|
||||
for i in range(0, end): self.read_next()
|
||||
self.rownumber = value
|
||||
|
||||
@@ -2,20 +2,21 @@ import struct
|
||||
|
||||
|
||||
try:
|
||||
Exception, Warning
|
||||
StandardError, Warning
|
||||
except ImportError:
|
||||
try:
|
||||
from exceptions import Exception, Warning
|
||||
from exceptions import StandardError, Warning
|
||||
except ImportError:
|
||||
import sys
|
||||
e = sys.modules['exceptions']
|
||||
Exception = e.Exception
|
||||
StandardError = e.StandardError
|
||||
Warning = e.Warning
|
||||
|
||||
|
||||
from constants import ER
|
||||
import sys
|
||||
|
||||
class MySQLError(Exception):
|
||||
|
||||
class MySQLError(StandardError):
|
||||
|
||||
"""Exception related to operation with MySQL."""
|
||||
|
||||
|
||||
@@ -106,13 +107,19 @@ _map_error(IntegrityError, ER.DUP_ENTRY, ER.NO_REFERENCED_ROW,
|
||||
ER.CANNOT_ADD_FOREIGN)
|
||||
_map_error(NotSupportedError, ER.WARNING_NOT_COMPLETE_ROLLBACK,
|
||||
ER.NOT_SUPPORTED_YET, ER.FEATURE_DISABLED, ER.UNKNOWN_STORAGE_ENGINE)
|
||||
_map_error(OperationalError, ER.DBACCESS_DENIED_ERROR, ER.ACCESS_DENIED_ERROR,
|
||||
ER.TABLEACCESS_DENIED_ERROR, ER.COLUMNACCESS_DENIED_ERROR)
|
||||
|
||||
del _map_error, ER
|
||||
|
||||
|
||||
|
||||
def _get_error_info(data):
|
||||
errno = struct.unpack('<h', data[1:3])[0]
|
||||
if data[3] == "#":
|
||||
if sys.version_info[0] == 3:
|
||||
is_41 = data[3] == ord("#")
|
||||
else:
|
||||
is_41 = data[3] == "#"
|
||||
if is_41:
|
||||
# version 4.1
|
||||
sqlstate = data[4:9].decode("utf8")
|
||||
errorvalue = data[9:].decode("utf8")
|
||||
@@ -122,7 +129,7 @@ def _get_error_info(data):
|
||||
return (errno, None, data[3:].decode("utf8"))
|
||||
|
||||
def _check_mysql_exception(errinfo):
|
||||
errno, sqlstate, errorvalue = errinfo
|
||||
errno, sqlstate, errorvalue = errinfo
|
||||
errorclass = error_map.get(errno, None)
|
||||
if errorclass:
|
||||
raise errorclass, (errno,errorvalue)
|
||||
@@ -133,8 +140,7 @@ def _check_mysql_exception(errinfo):
|
||||
def raise_mysql_exception(data):
|
||||
errinfo = _get_error_info(data)
|
||||
_check_mysql_exception(errinfo)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
from gluon.contrib.pymysql.tests.test_issues import *
|
||||
from gluon.contrib.pymysql.tests.test_example import *
|
||||
from gluon.contrib.pymysql.tests.test_basic import *
|
||||
from pymysql.tests.test_issues import *
|
||||
from pymysql.tests.test_example import *
|
||||
from pymysql.tests.test_basic import *
|
||||
from pymysql.tests.test_DictCursor import *
|
||||
|
||||
import sys
|
||||
if sys.version_info[0] == 2:
|
||||
# MySQLdb tests were designed for Python 3
|
||||
from pymysql.tests.thirdparty import *
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
|
||||
@@ -1,20 +1,18 @@
|
||||
from gluon.contrib import pymysql
|
||||
import pymysql
|
||||
import unittest
|
||||
|
||||
class PyMySQLTestCase(unittest.TestCase):
|
||||
# Edit this to suit your test environment.
|
||||
databases = [
|
||||
{"host":"localhost","user":"root",
|
||||
"passwd":"","db":"test_pymysql", "use_unicode": True},
|
||||
{"host":"localhost","user":"root","passwd":"","db":"test_pymysql2"}]
|
||||
|
||||
def setUp(self):
|
||||
try:
|
||||
self.connections = []
|
||||
self.connections = []
|
||||
|
||||
for params in self.databases:
|
||||
self.connections.append(pymysql.connect(**params))
|
||||
except pymysql.err.OperationalError as e:
|
||||
self.skipTest('Cannot connect to MySQL - skipping pymysql tests because of (%s) %s' % (type(e), e))
|
||||
for params in self.databases:
|
||||
self.connections.append(pymysql.connect(**params))
|
||||
|
||||
def tearDown(self):
|
||||
for connection in self.connections:
|
||||
|
||||
+56
@@ -0,0 +1,56 @@
|
||||
from pymysql.tests import base
|
||||
import pymysql.cursors
|
||||
|
||||
import datetime
|
||||
|
||||
class TestDictCursor(base.PyMySQLTestCase):
|
||||
|
||||
def test_DictCursor(self):
|
||||
#all assert test compare to the structure as would come out from MySQLdb
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor(pymysql.cursors.DictCursor)
|
||||
# create a table ane some data to query
|
||||
c.execute("""CREATE TABLE dictcursor (name char(20), age int , DOB datetime)""")
|
||||
data = (("bob",21,"1990-02-06 23:04:56"),
|
||||
("jim",56,"1955-05-09 13:12:45"),
|
||||
("fred",100,"1911-09-12 01:01:01"))
|
||||
bob = {'name':'bob','age':21,'DOB':datetime.datetime(1990, 02, 6, 23, 04, 56)}
|
||||
jim = {'name':'jim','age':56,'DOB':datetime.datetime(1955, 05, 9, 13, 12, 45)}
|
||||
fred = {'name':'fred','age':100,'DOB':datetime.datetime(1911, 9, 12, 1, 1, 1)}
|
||||
try:
|
||||
c.executemany("insert into dictcursor values (%s,%s,%s)", data)
|
||||
# try an update which should return no rows
|
||||
c.execute("update dictcursor set age=20 where name='bob'")
|
||||
bob['age'] = 20
|
||||
# pull back the single row dict for bob and check
|
||||
c.execute("SELECT * from dictcursor where name='bob'")
|
||||
r = c.fetchone()
|
||||
self.assertEqual(bob,r,"fetchone via DictCursor failed")
|
||||
# same again, but via fetchall => tuple)
|
||||
c.execute("SELECT * from dictcursor where name='bob'")
|
||||
r = c.fetchall()
|
||||
self.assertEqual((bob,),r,"fetch a 1 row result via fetchall failed via DictCursor")
|
||||
# same test again but iterate over the
|
||||
c.execute("SELECT * from dictcursor where name='bob'")
|
||||
for r in c:
|
||||
self.assertEqual(bob, r,"fetch a 1 row result via iteration failed via DictCursor")
|
||||
# get all 3 row via fetchall
|
||||
c.execute("SELECT * from dictcursor")
|
||||
r = c.fetchall()
|
||||
self.assertEqual((bob,jim,fred), r, "fetchall failed via DictCursor")
|
||||
#same test again but do a list comprehension
|
||||
c.execute("SELECT * from dictcursor")
|
||||
r = [x for x in c]
|
||||
self.assertEqual([bob,jim,fred], r, "list comprehension failed via DictCursor")
|
||||
# get all 2 row via fetchmany
|
||||
c.execute("SELECT * from dictcursor")
|
||||
r = c.fetchmany(2)
|
||||
self.assertEqual((bob,jim), r, "fetchmany failed via DictCursor")
|
||||
finally:
|
||||
c.execute("drop table dictcursor")
|
||||
|
||||
__all__ = ["TestDictCursor"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
unittest.main()
|
||||
Executable
+100
@@ -0,0 +1,100 @@
|
||||
import sys
|
||||
|
||||
try:
|
||||
from pymysql.tests import base
|
||||
import pymysql.cursors
|
||||
except:
|
||||
# For local testing from top-level directory, without installing
|
||||
sys.path.append('../pymysql')
|
||||
from pymysql.tests import base
|
||||
import pymysql.cursors
|
||||
|
||||
class TestSSCursor(base.PyMySQLTestCase):
|
||||
def test_SSCursor(self):
|
||||
affected_rows = 18446744073709551615
|
||||
|
||||
conn = self.connections[0]
|
||||
data = [
|
||||
('America', '', 'America/Jamaica'),
|
||||
('America', '', 'America/Los_Angeles'),
|
||||
('America', '', 'America/Lima'),
|
||||
('America', '', 'America/New_York'),
|
||||
('America', '', 'America/Menominee'),
|
||||
('America', '', 'America/Havana'),
|
||||
('America', '', 'America/El_Salvador'),
|
||||
('America', '', 'America/Costa_Rica'),
|
||||
('America', '', 'America/Denver'),
|
||||
('America', '', 'America/Detroit'),]
|
||||
|
||||
try:
|
||||
cursor = conn.cursor(pymysql.cursors.SSCursor)
|
||||
|
||||
# Create table
|
||||
cursor.execute(('CREATE TABLE tz_data ('
|
||||
'region VARCHAR(64),'
|
||||
'zone VARCHAR(64),'
|
||||
'name VARCHAR(64))'))
|
||||
|
||||
# Test INSERT
|
||||
for i in data:
|
||||
cursor.execute('INSERT INTO tz_data VALUES (%s, %s, %s)', i)
|
||||
self.assertEqual(conn.affected_rows(), 1, 'affected_rows does not match')
|
||||
conn.commit()
|
||||
|
||||
# Test fetchone()
|
||||
iter = 0
|
||||
cursor.execute('SELECT * FROM tz_data')
|
||||
while True:
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
break
|
||||
iter += 1
|
||||
|
||||
# Test cursor.rowcount
|
||||
self.assertEqual(cursor.rowcount, affected_rows,
|
||||
'cursor.rowcount != %s' % (str(affected_rows)))
|
||||
|
||||
# Test cursor.rownumber
|
||||
self.assertEqual(cursor.rownumber, iter,
|
||||
'cursor.rowcount != %s' % (str(iter)))
|
||||
|
||||
# Test row came out the same as it went in
|
||||
self.assertEqual((row in data), True,
|
||||
'Row not found in source data')
|
||||
|
||||
# Test fetchall
|
||||
cursor.execute('SELECT * FROM tz_data')
|
||||
self.assertEqual(len(cursor.fetchall()), len(data),
|
||||
'fetchall failed. Number of rows does not match')
|
||||
|
||||
# Test fetchmany
|
||||
cursor.execute('SELECT * FROM tz_data')
|
||||
self.assertEqual(len(cursor.fetchmany(2)), 2,
|
||||
'fetchmany failed. Number of rows does not match')
|
||||
|
||||
# So MySQLdb won't throw "Commands out of sync"
|
||||
while True:
|
||||
res = cursor.fetchone()
|
||||
if res is None:
|
||||
break
|
||||
|
||||
# Test update, affected_rows()
|
||||
cursor.execute('UPDATE tz_data SET zone = %s', ['Foo'])
|
||||
conn.commit()
|
||||
self.assertEqual(cursor.rowcount, len(data),
|
||||
'Update failed. affected_rows != %s' % (str(len(data))))
|
||||
|
||||
# Test executemany
|
||||
cursor.executemany('INSERT INTO tz_data VALUES (%s, %s, %s)', data)
|
||||
self.assertEqual(cursor.rowcount, len(data),
|
||||
'executemany failed. cursor.rowcount != %s' % (str(len(data))))
|
||||
|
||||
finally:
|
||||
cursor.execute('DROP TABLE tz_data')
|
||||
cursor.close()
|
||||
|
||||
__all__ = ["TestSSCursor"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
unittest.main()
|
||||
@@ -1,5 +1,5 @@
|
||||
from gluon.contrib.pymysql.tests import base
|
||||
from gluon.contrib.pymysql import util
|
||||
from pymysql.tests import base
|
||||
from pymysql import util
|
||||
|
||||
import time
|
||||
import datetime
|
||||
@@ -55,6 +55,31 @@ class TestConversion(base.PyMySQLTestCase):
|
||||
finally:
|
||||
c.execute("drop table test_dict")
|
||||
|
||||
def test_string(self):
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor()
|
||||
c.execute("create table test_dict (a text)")
|
||||
test_value = "I am a test string"
|
||||
try:
|
||||
c.execute("insert into test_dict (a) values (%s)", test_value)
|
||||
c.execute("select a from test_dict")
|
||||
self.assertEqual((test_value,), c.fetchone())
|
||||
finally:
|
||||
c.execute("drop table test_dict")
|
||||
|
||||
def test_integer(self):
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor()
|
||||
c.execute("create table test_dict (a integer)")
|
||||
test_value = 12345
|
||||
try:
|
||||
c.execute("insert into test_dict (a) values (%s)", test_value)
|
||||
c.execute("select a from test_dict")
|
||||
self.assertEqual((test_value,), c.fetchone())
|
||||
finally:
|
||||
c.execute("drop table test_dict")
|
||||
|
||||
|
||||
def test_big_blob(self):
|
||||
""" test tons of data """
|
||||
conn = self.connections[0]
|
||||
@@ -67,6 +92,26 @@ class TestConversion(base.PyMySQLTestCase):
|
||||
self.assertEqual(data.encode(conn.charset), c.fetchone()[0])
|
||||
finally:
|
||||
c.execute("drop table test_big_blob")
|
||||
|
||||
def test_untyped(self):
|
||||
""" test conversion of null, empty string """
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor()
|
||||
c.execute("select null,''")
|
||||
self.assertEqual((None,u''), c.fetchone())
|
||||
c.execute("select '',null")
|
||||
self.assertEqual((u'',None), c.fetchone())
|
||||
|
||||
def test_datetime(self):
|
||||
""" test conversion of null, empty string """
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor()
|
||||
c.execute("select time('12:30'), time('23:12:59'), time('23:12:59.05100')")
|
||||
self.assertEqual((datetime.timedelta(0, 45000),
|
||||
datetime.timedelta(0, 83579),
|
||||
datetime.timedelta(0, 83579, 51000)),
|
||||
c.fetchone())
|
||||
|
||||
|
||||
class TestCursor(base.PyMySQLTestCase):
|
||||
# this test case does not work quite right yet, however,
|
||||
@@ -134,6 +179,33 @@ class TestCursor(base.PyMySQLTestCase):
|
||||
finally:
|
||||
c.execute("drop table test_nr")
|
||||
|
||||
def test_aggregates(self):
|
||||
""" test aggregate functions """
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor()
|
||||
try:
|
||||
c.execute('create table test_aggregates (i integer)')
|
||||
for i in xrange(0, 10):
|
||||
c.execute('insert into test_aggregates (i) values (%s)', (i,))
|
||||
c.execute('select sum(i) from test_aggregates')
|
||||
r, = c.fetchone()
|
||||
self.assertEqual(sum(range(0,10)), r)
|
||||
finally:
|
||||
c.execute('drop table test_aggregates')
|
||||
|
||||
def test_single_tuple(self):
|
||||
""" test a single tuple """
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor()
|
||||
try:
|
||||
c.execute("create table mystuff (id integer primary key)")
|
||||
c.execute("insert into mystuff (id) values (1)")
|
||||
c.execute("insert into mystuff (id) values (2)")
|
||||
c.execute("select id from mystuff where id in %s", ((1,),))
|
||||
self.assertEqual([(1,)], list(c.fetchall()))
|
||||
finally:
|
||||
c.execute("drop table mystuff")
|
||||
|
||||
__all__ = ["TestConversion","TestCursor"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from gluon.contrib import pymysql
|
||||
from gluon.contrib.pymysql.tests import base
|
||||
import pymysql
|
||||
from pymysql.tests import base
|
||||
|
||||
class TestExample(base.PyMySQLTestCase):
|
||||
def test_example(self):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from gluon.contrib import pymysql
|
||||
from gluon.contrib.pymysql.tests import base
|
||||
import pymysql
|
||||
from pymysql.tests import base
|
||||
import unittest
|
||||
|
||||
import sys
|
||||
|
||||
@@ -11,6 +12,10 @@ except AttributeError:
|
||||
|
||||
import datetime
|
||||
|
||||
# backwards compatibility:
|
||||
if not hasattr(unittest, "skip"):
|
||||
unittest.skip = lambda message: lambda f: f
|
||||
|
||||
class TestOldIssues(base.PyMySQLTestCase):
|
||||
def test_issue_3(self):
|
||||
""" undefined methods datetime_or_None, date_or_None """
|
||||
@@ -89,15 +94,15 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
|
||||
""" can't handle large result fields """
|
||||
conn = self.connections[0]
|
||||
cur = conn.cursor()
|
||||
cur.execute("create table issue13 (t text)")
|
||||
try:
|
||||
cur.execute("create table issue13 (t text)")
|
||||
# ticket says 18k
|
||||
size = 18*1024
|
||||
cur.execute("insert into issue13 (t) values (%s)", ("x" * size,))
|
||||
cur.execute("select t from issue13")
|
||||
# use assert_ so that obscenely huge error messages don't print
|
||||
# use assertTrue so that obscenely huge error messages don't print
|
||||
r = cur.fetchone()[0]
|
||||
self.assert_("x" * size == r)
|
||||
self.assertTrue("x" * size == r)
|
||||
finally:
|
||||
cur.execute("drop table issue13")
|
||||
|
||||
@@ -115,7 +120,7 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
|
||||
c = conn.cursor()
|
||||
c.execute("create table issue15 (t varchar(32))")
|
||||
try:
|
||||
c.execute("insert into issue15 (t) values (%s)", (u'\xe4\xf6\xfc'))
|
||||
c.execute("insert into issue15 (t) values (%s)", (u'\xe4\xf6\xfc',))
|
||||
c.execute("select t from issue15")
|
||||
self.assertEqual(u'\xe4\xf6\xfc', c.fetchone()[0])
|
||||
finally:
|
||||
@@ -133,6 +138,7 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
|
||||
finally:
|
||||
c.execute("drop table issue16")
|
||||
|
||||
@unittest.skip("test_issue_17() requires a custom, legacy MySQL configuration and will not be run.")
|
||||
def test_issue_17(self):
|
||||
""" could not connect mysql use passwod """
|
||||
conn = self.connections[0]
|
||||
@@ -181,23 +187,22 @@ class TestNewIssues(base.PyMySQLTestCase):
|
||||
finally:
|
||||
c.execute(_uni("drop table hei\xc3\x9fe", "utf8"))
|
||||
|
||||
# Will fail without manual intervention:
|
||||
#def test_issue_35(self):
|
||||
#
|
||||
# conn = self.connections[0]
|
||||
# c = conn.cursor()
|
||||
# print "sudo killall -9 mysqld within the next 10 seconds"
|
||||
# try:
|
||||
# c.execute("select sleep(10)")
|
||||
# self.fail()
|
||||
# except pymysql.OperationalError, e:
|
||||
# self.assertEqual(2013, e.args[0])
|
||||
@unittest.skip("This test requires manual intervention")
|
||||
def test_issue_35(self):
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor()
|
||||
print "sudo killall -9 mysqld within the next 10 seconds"
|
||||
try:
|
||||
c.execute("select sleep(10)")
|
||||
self.fail()
|
||||
except pymysql.OperationalError, e:
|
||||
self.assertEqual(2013, e.args[0])
|
||||
|
||||
def test_issue_36(self):
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor()
|
||||
# kill connections[0]
|
||||
original_count = c.execute("show processlist")
|
||||
c.execute("show processlist")
|
||||
kill_id = None
|
||||
for id,user,host,db,command,time,state,info in c.fetchall():
|
||||
if info == "show processlist":
|
||||
@@ -212,8 +217,13 @@ class TestNewIssues(base.PyMySQLTestCase):
|
||||
except:
|
||||
pass
|
||||
# check the process list from the other connection
|
||||
self.assertEqual(original_count - 1, self.connections[1].cursor().execute("show processlist"))
|
||||
del self.connections[0]
|
||||
try:
|
||||
c = self.connections[1].cursor()
|
||||
c.execute("show processlist")
|
||||
ids = [row[0] for row in c.fetchall()]
|
||||
self.assertFalse(kill_id in ids)
|
||||
finally:
|
||||
del self.connections[0]
|
||||
|
||||
def test_issue_37(self):
|
||||
conn = self.connections[0]
|
||||
@@ -230,10 +240,38 @@ class TestNewIssues(base.PyMySQLTestCase):
|
||||
|
||||
try:
|
||||
c.execute("create table issue38 (id integer, data mediumblob)")
|
||||
c.execute("insert into issue38 values (1, %s)", datum)
|
||||
c.execute("insert into issue38 values (1, %s)", (datum,))
|
||||
finally:
|
||||
c.execute("drop table issue38")
|
||||
__all__ = ["TestOldIssues", "TestNewIssues"]
|
||||
|
||||
def disabled_test_issue_54(self):
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor()
|
||||
big_sql = "select * from issue54 where "
|
||||
big_sql += " and ".join("%d=%d" % (i,i) for i in xrange(0, 100000))
|
||||
|
||||
try:
|
||||
c.execute("create table issue54 (id integer primary key)")
|
||||
c.execute("insert into issue54 (id) values (7)")
|
||||
c.execute(big_sql)
|
||||
self.assertEqual(7, c.fetchone()[0])
|
||||
finally:
|
||||
c.execute("drop table issue54")
|
||||
|
||||
class TestGitHubIssues(base.PyMySQLTestCase):
|
||||
def test_issue_66(self):
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor()
|
||||
self.assertEqual(0, conn.insert_id())
|
||||
try:
|
||||
c.execute("create table issue66 (id integer primary key auto_increment, x integer)")
|
||||
c.execute("insert into issue66 (x) values (1)")
|
||||
c.execute("insert into issue66 (x) values (1)")
|
||||
self.assertEqual(2, conn.insert_id())
|
||||
finally:
|
||||
c.execute("drop table issue66")
|
||||
|
||||
__all__ = ["TestOldIssues", "TestNewIssues", "TestGitHubIssues"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
|
||||
@@ -14,4 +14,3 @@ def TimeFromTicks(ticks):
|
||||
|
||||
def TimestampFromTicks(ticks):
|
||||
return datetime(*localtime(ticks)[:6])
|
||||
|
||||
|
||||
@@ -17,4 +17,3 @@ def join_bytes(bs):
|
||||
for b in bs[1:]:
|
||||
rv += b
|
||||
return rv
|
||||
|
||||
|
||||
Reference in New Issue
Block a user