Remove sqlalchemy and elixir

This commit is contained in:
Ruud
2014-02-10 23:41:54 +01:00
parent 4b356aba3e
commit 8724076601
218 changed files with 14 additions and 94000 deletions

View File

@@ -181,7 +181,7 @@ class MediaPlugin(MediaBase):
'media': media,
}
def list(self, types = None, status = None, release_status = None, limit_offset = None, starts_with = None, search = None, order = None):
def list(self, types = None, status = None, release_status = None, limit_offset = None, starts_with = None, search = None):
db = get_db()
@@ -270,7 +270,6 @@ class MediaPlugin(MediaBase):
limit_offset = kwargs.get('limit_offset')
starts_with = kwargs.get('starts_with')
search = kwargs.get('search')
order = kwargs.get('order')
total_movies, movies = self.list(
types = types,
@@ -278,8 +277,7 @@ class MediaPlugin(MediaBase):
release_status = release_status,
limit_offset = limit_offset,
starts_with = starts_with,
search = search,
order = order
search = search
)
return {

View File

@@ -1,4 +0,0 @@
[db_settings]
repository_id = CouchPotato
version_table = migrate_version
required_dbs = ['sqlite']

View File

@@ -1,25 +0,0 @@
from migrate.changeset.schema import create_column
from sqlalchemy.schema import MetaData, Column, Table, Index
from sqlalchemy.types import Integer
meta = MetaData()
def upgrade(migrate_engine):
meta.bind = migrate_engine
# Change release, add last_edit and index
last_edit_column = Column('last_edit', Integer)
release = Table('release', meta, last_edit_column)
create_column(last_edit_column, release)
Index('ix_release_last_edit', release.c.last_edit).create()
# Change movie last_edit
last_edit_column = Column('last_edit', Integer)
movie = Table('movie', meta, last_edit_column)
Index('ix_movie_last_edit', movie.c.last_edit).create()
def downgrade(migrate_engine):
pass

View File

@@ -1,18 +0,0 @@
from migrate.changeset.schema import create_column
from sqlalchemy.schema import MetaData, Column, Table, Index
from sqlalchemy.types import Integer
meta = MetaData()
def upgrade(migrate_engine):
meta.bind = migrate_engine
category_column = Column('category_id', Integer)
movie = Table('movie', meta, category_column)
create_column(category_column, movie)
Index('ix_movie_category_id', movie.c.category_id).create()
def downgrade(migrate_engine):
pass

View File

@@ -1,30 +0,0 @@
"""
Examples
Adding a column:
from migrate import *
from migrate.changeset.schema import create_column
from sqlalchemy import *
meta = MetaData()
def upgrade(migrate_engine):
meta.bind = migrate_engine
#print changeset.schema
path_column = Column('path', String)
resource = Table('resource', meta, path_column)
create_column(path_column, resource)
Adding Relation table: http://www.mail-archive.com/sqlelixir@googlegroups.com/msg02061.html
person = Table('person', metadata, Column('id', Integer))
person_column = Column('person_id', Integer, ForeignKey('person.id'), nullable=False)
movie = Table('movie', metadata, person_column)
person_constraint = ForeignKeyConstraint(['person_id'], ['person.id'], ondelete="restrict", table=movie)
"""

View File

@@ -86,13 +86,13 @@ class XBMC(Notification):
# send the text message
resp = self.notifyXBMCnoJSON(host, {'title':self.default_title, 'message':message})
for result in resp:
if result.get('result') and result['result'] == 'OK':
for r in resp:
if r.get('result') and r['result'] == 'OK':
log.debug('Message delivered successfully!')
success = True
break
elif result.get('error'):
log.error('XBMC error; %s: %s (%s)', (result['id'], result['error']['message'], result['error']['code']))
elif r.get('error'):
log.error('XBMC error; %s: %s (%s)', (r['id'], r['error']['message'], r['error']['code']))
break
elif result.get('result') and type(result['result']['version']).__name__ == 'dict':
@@ -109,13 +109,13 @@ class XBMC(Notification):
# send the text message
resp = self.request(host, [('GUI.ShowNotification', {'title':self.default_title, 'message':message, 'image': self.getNotificationImage('small')})])
for result in resp:
if result.get('result') and result['result'] == 'OK':
for r in resp:
if r.get('result') and r['result'] == 'OK':
log.debug('Message delivered successfully!')
success = True
break
elif result.get('error'):
log.error('XBMC error; %s: %s (%s)', (result['id'], result['error']['message'], result['error']['code']))
elif r.get('error'):
log.error('XBMC error; %s: %s (%s)', (r['id'], r['error']['message'], r['error']['code']))
break
# error getting version info (we do have contact with XBMC though)

View File

@@ -110,7 +110,7 @@ class Plugin(object):
f.write(content)
f.close()
os.chmod(path, Env.getPermission('file'))
except Exception as e:
except:
log.error('Unable writing to file "%s": %s', (path, traceback.format_exc()))
if os.path.isfile(path):
os.remove(path)

View File

@@ -202,10 +202,9 @@ class QualityPlugin(Plugin):
# Try again with loose testing
for quality in qualities:
loose_score = self.guessLooseScore(quality, files = files, extra = extra)
loose_score = self.guessLooseScore(quality, extra = extra)
self.calcScore(score, quality, loose_score)
# Return nothing if all scores are 0
has_non_zero = 0
for s in score:
@@ -262,7 +261,7 @@ class QualityPlugin(Plugin):
return score
def guessLooseScore(self, quality, files = None, extra = None):
def guessLooseScore(self, quality, extra = None):
score = 0

View File

@@ -8,7 +8,6 @@ from couchpotato.core.plugins.base import Plugin
from .index import ReleaseIndex, ReleaseStatusIndex, ReleaseIDIndex, ReleaseDownloadIndex
from couchpotato.environment import Env
from inspect import ismethod, isfunction
from sqlalchemy.exc import InterfaceError
import os
import time
import traceback
@@ -395,7 +394,7 @@ class Release(Plugin):
continue
rls['info'][info] = toUnicode(rel[info])
except InterfaceError:
except:
log.debug('Couldn\'t add %s to ReleaseInfo: %s', (info, traceback.format_exc()))
db.update(rls)

View File

@@ -1,505 +0,0 @@
import uuid
import datetime
from sqlalchemy import Column, ForeignKey, Table, Index
from sqlalchemy.ext.associationproxy import AssociationProxy
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import relationship, object_mapper, ColumnProperty, class_mapper
from sqlalchemy.orm.exc import UnmappedInstanceError
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from couchpotato.core.helpers.encoding import toUnicode
from sqlalchemy.ext.mutable import Mutable
from sqlalchemy.types import Integer, Unicode, UnicodeText, Boolean, String, \
TypeDecorator
import json
import time
class SetEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj)
return json.JSONEncoder.default(self, obj)
class JsonType(TypeDecorator):
impl = UnicodeText
def process_bind_param(self, value, dialect):
try:
return toUnicode(json.dumps(value, cls = SetEncoder))
except:
try:
return toUnicode(json.dumps(value, cls = SetEncoder, encoding = 'latin-1'))
except:
raise
def process_result_value(self, value, dialect):
return json.loads(value if value else '{}')
class MutableDict(Mutable, dict):
@classmethod
def coerce(cls, key, value):
if not isinstance(value, MutableDict):
if isinstance(value, dict):
return MutableDict(value)
return Mutable.coerce(key, value)
else:
return value
def __delitem(self, key):
dict.__delitem__(self, key)
self.changed()
def __setitem__(self, key, value):
dict.__setitem__(self, key, value)
self.changed()
def __getstate__(self):
return dict(self)
def __setstate__(self, state):
self.update(self)
def update(self, *args, **kwargs):
super(MutableDict, self).update(*args, **kwargs)
self.changed()
MutableDict.associate_with(JsonType)
Base = declarative_base()
COLUMN_BLACKLIST = ('_sa_polymorphic_on', )
def is_mapped_class(cls):
try:
class_mapper(cls)
return True
except:
return False
def is_like_list(instance, relation):
"""Returns ``True`` if and only if the relation of `instance` whose name is
`relation` is list-like.
A relation may be like a list if, for example, it is a non-lazy one-to-many
relation, or it is a dynamically loaded one-to-many.
"""
if relation in instance._sa_class_manager:
return instance._sa_class_manager[relation].property.uselist
related_value = getattr(type(instance), relation, None)
return isinstance(related_value, AssociationProxy)
class TableHelper():
def to_dict(self, deep = None, exclude = None, include = None,
exclude_relations = None, include_relations = None,
include_methods = None):
instance = self
if (exclude is not None or exclude_relations is not None) and \
(include is not None or include_relations is not None):
raise ValueError('Cannot specify both include and exclude.')
# create a list of names of columns, including hybrid properties
try:
columns = [p.key for p in object_mapper(instance).iterate_properties
if isinstance(p, ColumnProperty)]
except UnmappedInstanceError:
return instance
for parent in type(instance).mro():
columns += [key for key, value in parent.__dict__.items()
if isinstance(value, hybrid_property)]
# filter the columns based on exclude and include values
if exclude is not None:
columns = (c for c in columns if c not in exclude)
elif include is not None:
columns = (c for c in columns if c in include)
# create a dictionary mapping column name to value
result = dict((col, getattr(instance, col)) for col in columns
if not (col.startswith('__') or col in COLUMN_BLACKLIST))
# add any included methods
if include_methods is not None:
result.update(dict((method, getattr(instance, method)()) for method in include_methods if not '.' in method))
# Check for objects in the dictionary that may not be serializable by
# default. Specifically, convert datetime and date objects to ISO 8601
# format, and convert UUID objects to hexadecimal strings.
for key, value in result.items():
# TODO We can get rid of this when issue #33 is resolved.
if isinstance(value, datetime.date):
result[key] = value.isoformat()
elif isinstance(value, uuid.UUID):
result[key] = str(value)
elif is_mapped_class(type(value)):
result[key] = value.to_dict()
# recursively call _to_dict on each of the `deep` relations
deep = deep or {}
for relation, rdeep in deep.items():
# Get the related value so we can see if it is None, a list, a query
# (as specified by a dynamic relationship loader), or an actual
# instance of a model.
relatedvalue = getattr(instance, relation)
if relatedvalue is None:
result[relation] = None
continue
# Determine the included and excluded fields for the related model.
newexclude = None
newinclude = None
if exclude_relations is not None and relation in exclude_relations:
newexclude = exclude_relations[relation]
elif (include_relations is not None and
relation in include_relations):
newinclude = include_relations[relation]
# Determine the included methods for the related model.
newmethods = None
if include_methods is not None:
newmethods = [method.split('.', 1)[1] for method in include_methods
if method.split('.', 1)[0] == relation]
if is_like_list(instance, relation):
result[relation] = [inst.to_dict(rdeep, exclude = newexclude,
include = newinclude,
include_methods = newmethods)
for inst in relatedvalue]
continue
# If the related value is dynamically loaded, resolve the query to get
# the single instance.
if isinstance(relatedvalue, Query):
relatedvalue = relatedvalue.one()
result[relation] = relatedvalue.to_dict(rdeep, exclude = newexclude,
include = newinclude,
include_methods = newmethods)
return result
movie_files = Table('movie_files__file_movie', Base.metadata,
Column('movie_id', Integer, ForeignKey('movie.id'), nullable = False),
Column('file_id', Integer, ForeignKey('file.id'), nullable = False),
Index('movie_files_idx', 'movie_id', 'file_id', unique = True)
)
release_files = Table('release_files__file_release', Base.metadata,
Column('release_id', Integer, ForeignKey('release.id'), nullable = False),
Column('file_id', Integer, ForeignKey('file.id'), nullable = False),
Index('release_files_idx', 'release_id', 'file_id', unique = True)
)
library_files = Table('library_files__file_library', Base.metadata,
Column('library_id', Integer, ForeignKey('library.id'), nullable = False),
Column('file_id', Integer, ForeignKey('file.id'), nullable = False),
Index('library_files_idx', 'library_id', 'file_id', unique = True)
)
class Movie(Base, TableHelper):
__tablename__ = 'movie'
id = Column(Integer, primary_key = True)
"""Movie Resource a movie could have multiple releases
The files belonging to the movie object are global for the whole movie
such as trailers, nfo, thumbnails"""
last_edit = Column(Integer, default = lambda: int(time.time()), index = True)
type = 'movie' # Compat tv branch
library_id = Column(Integer, ForeignKey('library.id'), index = True)
status_id = Column(Integer, ForeignKey('status.id'), index = True)
profile_id = Column(Integer, ForeignKey('profile.id'), index = True)
category_id = Column(Integer, ForeignKey('category.id'), index = True)
library = relationship('Library') #cascade = 'delete, delete-orphan', single_parent = True)
status = relationship('Status')
profile = relationship('Profile')
category = relationship('Category')
releases = relationship('Release') #, cascade = 'all, delete-orphan')
files = relationship('File', secondary = movie_files) #, cascade = 'all, delete-orphan', single_parent = True)
Media = Movie # Compat tv branch
class Library(Base, TableHelper):
__tablename__ = 'library'
id = Column(Integer, primary_key = True)
""""""
year = Column(Integer)
identifier = Column(String(20), index = True)
plot = Column(UnicodeText)
tagline = Column(UnicodeText(255))
info = Column(JsonType)
status_id = Column(Integer, ForeignKey('status.id'), index = True)
status = relationship('Status')
movies = relationship('Movie') #, cascade = 'all, delete-orphan')
titles = relationship('LibraryTitle', order_by="desc(LibraryTitle.default)") #, cascade = 'all, delete-orphan')
files = relationship('File', secondary = library_files) #, cascade = 'all, delete-orphan', single_parent = True)
class LibraryTitle(Base, TableHelper):
__tablename__ = 'librarytitle'
id = Column(Integer, primary_key = True)
""""""
#using_options(order_by = '-default')
title = Column(Unicode)
simple_title = Column(Unicode, index = True)
default = Column(Boolean, default = False, index = True)
language = relationship('Language')
libraries_id = Column(Integer, ForeignKey('library.id'), index = True)
libraries = relationship('Library')
class Language(Base, TableHelper):
__tablename__ = 'language'
id = Column(Integer, primary_key = True)
""""""
identifier = Column(String(20), index = True)
label = Column(Unicode)
titles_id = Column(Integer, ForeignKey('librarytitle.id'), index = True)
titles = relationship('LibraryTitle')
class Release(Base, TableHelper):
__tablename__ = 'release'
id = Column(Integer, primary_key = True)
"""Logically groups all files that belong to a certain release, such as
parts of a movie, subtitles."""
last_edit = Column(Integer, default = lambda: int(time.time()), index = True)
identifier = Column(String(100), index = True)
movie_id = Column(Integer, ForeignKey('movie.id'), index = True)
movie = relationship('Movie')
status_id = Column(Integer, ForeignKey('status.id'), index = True)
status = relationship('Status')
quality_id = Column(Integer, ForeignKey('quality.id'), index = True)
quality = relationship('Quality')
files = relationship('File', secondary = release_files)
info = relationship('ReleaseInfo') #, cascade = 'all, delete-orphan')
def to_dict(self, deep = None, exclude = None, **kwargs):
if not exclude: exclude = []
if not deep: deep = {}
orig_dict = super(Release, self).to_dict(deep = deep, exclude = exclude)
new_info = {}
for info in orig_dict.get('info', []):
value = info['value']
try: value = int(info['value'])
except: pass
new_info[info['identifier']] = value
orig_dict['info'] = new_info
return orig_dict
class ReleaseInfo(Base, TableHelper):
__tablename__ = 'releaseinfo'
id = Column(Integer, primary_key = True)
"""Properties that can be bound to a file for off-line usage"""
identifier = Column(String(50), index = True)
value = Column(Unicode(255), nullable = False)
release_id = Column(Integer, ForeignKey('release.id'), index = True)
release = relationship('Release')
class Status(Base, TableHelper):
__tablename__ = 'status'
id = Column(Integer, primary_key = True)
"""The status of a release, such as Downloaded, Deleted, Wanted etc"""
identifier = Column(String(20), unique = True)
label = Column(Unicode(20))
releases = relationship('Release')
movies = relationship('Movie')
class Quality(Base, TableHelper):
__tablename__ = 'quality'
id = Column(Integer, primary_key = True)
"""Quality name of a release, DVD, 720p, DVD-Rip etc"""
#using_options(order_by = 'order')
identifier = Column(String(20), unique = True)
label = Column(Unicode(20))
order = Column(Integer, default = 0, index = True)
size_min = Column(Integer)
size_max = Column(Integer)
releases = relationship('Release')
profile_types = relationship('ProfileType', order_by="asc(ProfileType.order)")
class Profile(Base, TableHelper):
__tablename__ = 'profile'
id = Column(Integer, primary_key = True)
""""""
#using_options(order_by = 'order')
label = Column(Unicode(50))
order = Column(Integer, default = 0, index = True)
core = Column(Boolean, default = False)
hide = Column(Boolean, default = False)
movie = relationship('Movie')
types = relationship('ProfileType', order_by="asc(ProfileType.order)") #, cascade = 'all, delete-orphan')
def to_dict(self, deep = None, exclude = None, **kwargs):
if not exclude: exclude = []
if not deep: deep = {}
orig_dict = super(Profile, self).to_dict(deep = deep, exclude = exclude)
orig_dict['core'] = orig_dict.get('core') or False
orig_dict['hide'] = orig_dict.get('hide') or False
return orig_dict
class Category(Base, TableHelper):
__tablename__ = 'category'
id = Column(Integer, primary_key = True)
""""""
#using_options(order_by = 'order')
label = Column(Unicode(50))
order = Column(Integer, default = 0, index = True)
required = Column(Unicode(255))
preferred = Column(Unicode(255))
ignored = Column(Unicode(255))
destination = Column(Unicode(255))
movie = relationship('Movie')
class ProfileType(Base, TableHelper):
__tablename__ = 'profiletype'
id = Column(Integer, primary_key = True)
""""""
#using_options(order_by = 'order')
order = Column(Integer, default = 0, index = True)
finish = Column(Boolean, default = True)
wait_for = Column(Integer, default = 0)
quality_id = Column(Integer, ForeignKey('quality.id'), index = True)
quality = relationship('Quality')
profile_id = Column(Integer, ForeignKey('profile.id'), index = True)
profile = relationship('Profile')
class File(Base, TableHelper):
__tablename__ = 'file'
id = Column(Integer, primary_key = True)
"""File that belongs to a release."""
path = Column(Unicode(255), nullable = False, unique = True)
part = Column(Integer, default = 1)
available = Column(Boolean, default = True)
type_id = Column(Integer, ForeignKey('filetype.id'), index = True)
type = relationship('FileType')
properties = relationship('FileProperty')
movie = relationship('Movie', secondary = movie_files)
release = relationship('Release', secondary = release_files)
library = relationship('Library', secondary = library_files)
class FileType(Base, TableHelper):
__tablename__ = 'filetype'
id = Column(Integer, primary_key = True)
"""Types could be trailer, subtitle, movie, partial movie etc."""
identifier = Column(String(20), unique = True)
type = Column(Unicode(20))
name = Column(Unicode(50), nullable = False)
files = relationship('File')
class FileProperty(Base, TableHelper):
__tablename__ = 'fileproperty'
id = Column(Integer, primary_key = True)
"""Properties that can be bound to a file for off-line usage"""
identifier = Column(String(20), index = True)
value = Column(Unicode(255), nullable = False)
file_id = Column(Integer, ForeignKey('file.id'), index = True)
file = relationship('File')
class Notification(Base, TableHelper):
__tablename__ = 'notification'
id = Column(Integer, primary_key = True)
""""""
#using_options(order_by = 'added')
added = Column(Integer, default = lambda: int(time.time()), index = True)
read = Column(Boolean, default = False, index = True)
message = Column(Unicode(255))
data = Column(JsonType)
class Properties(Base, TableHelper):
__tablename__ = 'properties'
id = Column(Integer, primary_key = True)
""""""
identifier = Column(String(50), index = True)
value = Column(Unicode(255), nullable = False)
def setup():
"""Setup the database and create the tables that don't exists yet"""
from couchpotato.environment import Env
engine = Env.getEngine()
Base.metadata.create_all(engine)
try:
engine.execute("PRAGMA journal_mode = WAL")
engine.execute("PRAGMA temp_store = MEMORY")
except:
pass

View File

@@ -1,9 +1,6 @@
from couchpotato.core.event import fireEvent, addEvent
from couchpotato.core.loader import Loader
from couchpotato.core.settings import Settings
from sqlalchemy.engine import create_engine
from sqlalchemy.orm import scoped_session
from sqlalchemy.orm.session import sessionmaker
import os

View File

@@ -1,210 +0,0 @@
########################## LICENCE ###############################
##
## Copyright (c) 2005-2011, Michele Simionato
## All rights reserved.
##
## Redistributions of source code must retain the above copyright
## notice, this list of conditions and the following disclaimer.
## Redistributions in bytecode form must reproduce the above copyright
## notice, this list of conditions and the following disclaimer in
## the documentation and/or other materials provided with the
## distribution.
## THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
## "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
## LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
## A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
## HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
## INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
## BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
## OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
## ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
## TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
## USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
## DAMAGE.
"""
Decorator module, see http://pypi.python.org/pypi/decorator
for the documentation.
"""
__version__ = '3.3.2'
__all__ = ["decorator", "FunctionMaker", "partial"]
import sys, re, inspect
try:
from functools import partial
except ImportError: # for Python version < 2.5
class partial(object):
"A simple replacement of functools.partial"
def __init__(self, func, *args, **kw):
self.func = func
self.args = args
self.keywords = kw
def __call__(self, *otherargs, **otherkw):
kw = self.keywords.copy()
kw.update(otherkw)
return self.func(*(self.args + otherargs), **kw)
if sys.version >= '3':
from inspect import getfullargspec
else:
class getfullargspec(object):
"A quick and dirty replacement for getfullargspec for Python 2.X"
def __init__(self, f):
self.args, self.varargs, self.varkw, self.defaults = \
inspect.getargspec(f)
self.kwonlyargs = []
self.kwonlydefaults = None
self.annotations = getattr(f, '__annotations__', {})
def __iter__(self):
yield self.args
yield self.varargs
yield self.varkw
yield self.defaults
DEF = re.compile('\s*def\s*([_\w][_\w\d]*)\s*\(')
# basic functionality
class FunctionMaker(object):
"""
An object with the ability to create functions with a given signature.
It has attributes name, doc, module, signature, defaults, dict and
methods update and make.
"""
def __init__(self, func=None, name=None, signature=None,
defaults=None, doc=None, module=None, funcdict=None):
self.shortsignature = signature
if func:
# func can be a class or a callable, but not an instance method
self.name = func.__name__
if self.name == '<lambda>': # small hack for lambda functions
self.name = '_lambda_'
self.doc = func.__doc__
self.module = func.__module__
if inspect.isfunction(func):
argspec = getfullargspec(func)
for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs',
'kwonlydefaults', 'annotations'):
setattr(self, a, getattr(argspec, a))
for i, arg in enumerate(self.args):
setattr(self, 'arg%d' % i, arg)
self.signature = inspect.formatargspec(
formatvalue=lambda val: "", *argspec)[1:-1]
allargs = list(self.args)
if self.varargs:
allargs.append('*' + self.varargs)
if self.varkw:
allargs.append('**' + self.varkw)
try:
self.shortsignature = ', '.join(allargs)
except TypeError: # exotic signature, valid only in Python 2.X
self.shortsignature = self.signature
self.dict = func.__dict__.copy()
# func=None happens when decorating a caller
if name:
self.name = name
if signature is not None:
self.signature = signature
if defaults:
self.defaults = defaults
if doc:
self.doc = doc
if module:
self.module = module
if funcdict:
self.dict = funcdict
# check existence required attributes
assert hasattr(self, 'name')
if not hasattr(self, 'signature'):
raise TypeError('You are decorating a non function: %s' % func)
def update(self, func, **kw):
"Update the signature of func with the data in self"
func.__name__ = self.name
func.__doc__ = getattr(self, 'doc', None)
func.__dict__ = getattr(self, 'dict', {})
func.func_defaults = getattr(self, 'defaults', ())
func.__kwdefaults__ = getattr(self, 'kwonlydefaults', None)
callermodule = sys._getframe(3).f_globals.get('__name__', '?')
func.__module__ = getattr(self, 'module', callermodule)
func.__dict__.update(kw)
def make(self, src_templ, evaldict=None, addsource=False, **attrs):
"Make a new function from a given template and update the signature"
src = src_templ % vars(self) # expand name and signature
evaldict = evaldict or {}
mo = DEF.match(src)
if mo is None:
raise SyntaxError('not a valid function template\n%s' % src)
name = mo.group(1) # extract the function name
names = set([name] + [arg.strip(' *') for arg in
self.shortsignature.split(',')])
for n in names:
if n in ('_func_', '_call_'):
raise NameError('%s is overridden in\n%s' % (n, src))
if not src.endswith('\n'): # add a newline just for safety
src += '\n' # this is needed in old versions of Python
try:
code = compile(src, '<string>', 'single')
# print >> sys.stderr, 'Compiling %s' % src
exec code in evaldict
except:
print >> sys.stderr, 'Error in generated code:'
print >> sys.stderr, src
raise
func = evaldict[name]
if addsource:
attrs['__source__'] = src
self.update(func, **attrs)
return func
@classmethod
def create(cls, obj, body, evaldict, defaults=None,
doc=None, module=None, addsource=True,**attrs):
"""
Create a function from the strings name, signature and body.
evaldict is the evaluation dictionary. If addsource is true an attribute
__source__ is added to the result. The attributes attrs are added,
if any.
"""
if isinstance(obj, str): # "name(signature)"
name, rest = obj.strip().split('(', 1)
signature = rest[:-1] #strip a right parens
func = None
else: # a function
name = None
signature = None
func = obj
self = cls(func, name, signature, defaults, doc, module)
ibody = '\n'.join(' ' + line for line in body.splitlines())
return self.make('def %(name)s(%(signature)s):\n' + ibody,
evaldict, addsource, **attrs)
def decorator(caller, func=None):
"""
decorator(caller) converts a caller function into a decorator;
decorator(caller, func) decorates a function using a caller.
"""
if func is not None: # returns a decorated function
evaldict = func.func_globals.copy()
evaldict['_call_'] = caller
evaldict['_func_'] = func
return FunctionMaker.create(
func, "return _call_(_func_, %(shortsignature)s)",
evaldict, undecorated=func, __wrapped__=func)
else: # returns a decorator
if isinstance(caller, partial):
return partial(decorator, caller)
# otherwise assume caller is a function
first = inspect.getargspec(caller)[0][0] # first arg
evaldict = caller.func_globals.copy()
evaldict['_call_'] = caller
evaldict['decorator'] = decorator
return FunctionMaker.create(
'%s(%s)' % (caller.__name__, first),
'return decorator(_call_, %s)' % first,
evaldict, undecorated=caller, __wrapped__=caller,
doc=caller.__doc__, module=caller.__module__)

View File

@@ -1,11 +0,0 @@
"""
SQLAlchemy migrate provides two APIs :mod:`migrate.versioning` for
database schema version and repository management and
:mod:`migrate.changeset` that allows to define database schema changes
using Python.
"""
from migrate.versioning import *
from migrate.changeset import *
__version__ = '0.7.2'

View File

@@ -1,28 +0,0 @@
"""
This module extends SQLAlchemy and provides additional DDL [#]_
support.
.. [#] SQL Data Definition Language
"""
import re
import warnings
import sqlalchemy
from sqlalchemy import __version__ as _sa_version
warnings.simplefilter('always', DeprecationWarning)
_sa_version = tuple(int(re.match("\d+", x).group(0)) for x in _sa_version.split("."))
SQLA_07 = _sa_version >= (0, 7)
del re
del _sa_version
from migrate.changeset.schema import *
from migrate.changeset.constraint import *
sqlalchemy.schema.Table.__bases__ += (ChangesetTable, )
sqlalchemy.schema.Column.__bases__ += (ChangesetColumn, )
sqlalchemy.schema.Index.__bases__ += (ChangesetIndex, )
sqlalchemy.schema.DefaultClause.__bases__ += (ChangesetDefaultClause, )

View File

@@ -1,292 +0,0 @@
"""
Extensions to SQLAlchemy for altering existing tables.
At the moment, this isn't so much based off of ANSI as much as
things that just happen to work with multiple databases.
"""
import StringIO
import sqlalchemy as sa
from sqlalchemy.schema import SchemaVisitor
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.sql import ClauseElement
from sqlalchemy.schema import (ForeignKeyConstraint,
PrimaryKeyConstraint,
CheckConstraint,
UniqueConstraint,
Index)
from migrate import exceptions
from migrate.changeset import constraint
from sqlalchemy.schema import AddConstraint, DropConstraint
from sqlalchemy.sql.compiler import DDLCompiler
SchemaGenerator = SchemaDropper = DDLCompiler
class AlterTableVisitor(SchemaVisitor):
"""Common operations for ``ALTER TABLE`` statements."""
# engine.Compiler looks for .statement
# when it spawns off a new compiler
statement = ClauseElement()
def append(self, s):
"""Append content to the SchemaIterator's query buffer."""
self.buffer.write(s)
def execute(self):
"""Execute the contents of the SchemaIterator's buffer."""
try:
return self.connection.execute(self.buffer.getvalue())
finally:
self.buffer.truncate(0)
def __init__(self, dialect, connection, **kw):
self.connection = connection
self.buffer = StringIO.StringIO()
self.preparer = dialect.identifier_preparer
self.dialect = dialect
def traverse_single(self, elem):
ret = super(AlterTableVisitor, self).traverse_single(elem)
if ret:
# adapt to 0.6 which uses a string-returning
# object
self.append(" %s" % ret)
def _to_table(self, param):
"""Returns the table object for the given param object."""
if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)):
ret = param.table
else:
ret = param
return ret
def start_alter_table(self, param):
"""Returns the start of an ``ALTER TABLE`` SQL-Statement.
Use the param object to determine the table name and use it
for building the SQL statement.
:param param: object to determine the table from
:type param: :class:`sqlalchemy.Column`, :class:`sqlalchemy.Index`,
:class:`sqlalchemy.schema.Constraint`, :class:`sqlalchemy.Table`,
or string (table name)
"""
table = self._to_table(param)
self.append('\nALTER TABLE %s ' % self.preparer.format_table(table))
return table
class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator):
"""Extends ansisql generator for column creation (alter table add col)"""
def visit_column(self, column):
"""Create a column (table already exists).
:param column: column object
:type column: :class:`sqlalchemy.Column` instance
"""
if column.default is not None:
self.traverse_single(column.default)
table = self.start_alter_table(column)
self.append("ADD ")
self.append(self.get_column_specification(column))
for cons in column.constraints:
self.traverse_single(cons)
self.execute()
# ALTER TABLE STATEMENTS
# add indexes and unique constraints
if column.index_name:
Index(column.index_name,column).create()
elif column.unique_name:
constraint.UniqueConstraint(column,
name=column.unique_name).create()
# SA bounds FK constraints to table, add manually
for fk in column.foreign_keys:
self.add_foreignkey(fk.constraint)
# add primary key constraint if needed
if column.primary_key_name:
cons = constraint.PrimaryKeyConstraint(column,
name=column.primary_key_name)
cons.create()
def add_foreignkey(self, fk):
self.connection.execute(AddConstraint(fk))
class ANSIColumnDropper(AlterTableVisitor, SchemaDropper):
"""Extends ANSI SQL dropper for column dropping (``ALTER TABLE
DROP COLUMN``).
"""
def visit_column(self, column):
"""Drop a column from its table.
:param column: the column object
:type column: :class:`sqlalchemy.Column`
"""
table = self.start_alter_table(column)
self.append('DROP COLUMN %s' % self.preparer.format_column(column))
self.execute()
class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator):
"""Manages changes to existing schema elements.
Note that columns are schema elements; ``ALTER TABLE ADD COLUMN``
is in SchemaGenerator.
All items may be renamed. Columns can also have many of their properties -
type, for example - changed.
Each function is passed a tuple, containing (object, name); where
object is a type of object you'd expect for that function
(ie. table for visit_table) and name is the object's new
name. NONE means the name is unchanged.
"""
def visit_table(self, table):
"""Rename a table. Other ops aren't supported."""
self.start_alter_table(table)
self.append("RENAME TO %s" % self.preparer.quote(table.new_name,
table.quote))
self.execute()
def visit_index(self, index):
"""Rename an index"""
if hasattr(self, '_validate_identifier'):
# SA <= 0.6.3
self.append("ALTER INDEX %s RENAME TO %s" % (
self.preparer.quote(
self._validate_identifier(
index.name, True), index.quote),
self.preparer.quote(
self._validate_identifier(
index.new_name, True), index.quote)))
else:
# SA >= 0.6.5
self.append("ALTER INDEX %s RENAME TO %s" % (
self.preparer.quote(
self._index_identifier(
index.name), index.quote),
self.preparer.quote(
self._index_identifier(
index.new_name), index.quote)))
self.execute()
def visit_column(self, delta):
"""Rename/change a column."""
# ALTER COLUMN is implemented as several ALTER statements
keys = delta.keys()
if 'type' in keys:
self._run_subvisit(delta, self._visit_column_type)
if 'nullable' in keys:
self._run_subvisit(delta, self._visit_column_nullable)
if 'server_default' in keys:
# Skip 'default': only handle server-side defaults, others
# are managed by the app, not the db.
self._run_subvisit(delta, self._visit_column_default)
if 'name' in keys:
self._run_subvisit(delta, self._visit_column_name, start_alter=False)
def _run_subvisit(self, delta, func, start_alter=True):
"""Runs visit method based on what needs to be changed on column"""
table = self._to_table(delta.table)
col_name = delta.current_name
if start_alter:
self.start_alter_column(table, col_name)
ret = func(table, delta.result_column, delta)
self.execute()
def start_alter_column(self, table, col_name):
"""Starts ALTER COLUMN"""
self.start_alter_table(table)
self.append("ALTER COLUMN %s " % self.preparer.quote(col_name, table.quote))
def _visit_column_nullable(self, table, column, delta):
nullable = delta['nullable']
if nullable:
self.append("DROP NOT NULL")
else:
self.append("SET NOT NULL")
def _visit_column_default(self, table, column, delta):
default_text = self.get_column_default_string(column)
if default_text is not None:
self.append("SET DEFAULT %s" % default_text)
else:
self.append("DROP DEFAULT")
def _visit_column_type(self, table, column, delta):
type_ = delta['type']
type_text = str(type_.compile(dialect=self.dialect))
self.append("TYPE %s" % type_text)
def _visit_column_name(self, table, column, delta):
self.start_alter_table(table)
col_name = self.preparer.quote(delta.current_name, table.quote)
new_name = self.preparer.format_column(delta.result_column)
self.append('RENAME COLUMN %s TO %s' % (col_name, new_name))
class ANSIConstraintCommon(AlterTableVisitor):
"""
Migrate's constraints require a separate creation function from
SA's: Migrate's constraints are created independently of a table;
SA's are created at the same time as the table.
"""
def get_constraint_name(self, cons):
"""Gets a name for the given constraint.
If the name is already set it will be used otherwise the
constraint's :meth:`autoname <migrate.changeset.constraint.ConstraintChangeset.autoname>`
method is used.
:param cons: constraint object
"""
if cons.name is not None:
ret = cons.name
else:
ret = cons.name = cons.autoname()
return self.preparer.quote(ret, cons.quote)
def visit_migrate_primary_key_constraint(self, *p, **k):
self._visit_constraint(*p, **k)
def visit_migrate_foreign_key_constraint(self, *p, **k):
self._visit_constraint(*p, **k)
def visit_migrate_check_constraint(self, *p, **k):
self._visit_constraint(*p, **k)
def visit_migrate_unique_constraint(self, *p, **k):
self._visit_constraint(*p, **k)
class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
def _visit_constraint(self, constraint):
constraint.name = self.get_constraint_name(constraint)
self.append(self.process(AddConstraint(constraint)))
self.execute()
class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
def _visit_constraint(self, constraint):
constraint.name = self.get_constraint_name(constraint)
self.append(self.process(DropConstraint(constraint, cascade=constraint.cascade)))
self.execute()
class ANSIDialect(DefaultDialect):
columngenerator = ANSIColumnGenerator
columndropper = ANSIColumnDropper
schemachanger = ANSISchemaChanger
constraintgenerator = ANSIConstraintGenerator
constraintdropper = ANSIConstraintDropper

View File

@@ -1,199 +0,0 @@
"""
This module defines standalone schema constraint classes.
"""
from sqlalchemy import schema
from migrate.exceptions import *
class ConstraintChangeset(object):
"""Base class for Constraint classes."""
def _normalize_columns(self, cols, table_name=False):
"""Given: column objects or names; return col names and
(maybe) a table"""
colnames = []
table = None
for col in cols:
if isinstance(col, schema.Column):
if col.table is not None and table is None:
table = col.table
if table_name:
col = '.'.join((col.table.name, col.name))
else:
col = col.name
colnames.append(col)
return colnames, table
def __do_imports(self, visitor_name, *a, **kw):
engine = kw.pop('engine', self.table.bind)
from migrate.changeset.databases.visitor import (get_engine_visitor,
run_single_visitor)
visitorcallable = get_engine_visitor(engine, visitor_name)
run_single_visitor(engine, visitorcallable, self, *a, **kw)
def create(self, *a, **kw):
"""Create the constraint in the database.
:param engine: the database engine to use. If this is \
:keyword:`None` the instance's engine will be used
:type engine: :class:`sqlalchemy.engine.base.Engine`
:param connection: reuse connection istead of creating new one.
:type connection: :class:`sqlalchemy.engine.base.Connection` instance
"""
# TODO: set the parent here instead of in __init__
self.__do_imports('constraintgenerator', *a, **kw)
def drop(self, *a, **kw):
"""Drop the constraint from the database.
:param engine: the database engine to use. If this is
:keyword:`None` the instance's engine will be used
:param cascade: Issue CASCADE drop if database supports it
:type engine: :class:`sqlalchemy.engine.base.Engine`
:type cascade: bool
:param connection: reuse connection istead of creating new one.
:type connection: :class:`sqlalchemy.engine.base.Connection` instance
:returns: Instance with cleared columns
"""
self.cascade = kw.pop('cascade', False)
self.__do_imports('constraintdropper', *a, **kw)
# the spirit of Constraint objects is that they
# are immutable (just like in a DB. they're only ADDed
# or DROPped).
#self.columns.clear()
return self
class PrimaryKeyConstraint(ConstraintChangeset, schema.PrimaryKeyConstraint):
"""Construct PrimaryKeyConstraint
Migrate's additional parameters:
:param cols: Columns in constraint.
:param table: If columns are passed as strings, this kw is required
:type table: Table instance
:type cols: strings or Column instances
"""
__migrate_visit_name__ = 'migrate_primary_key_constraint'
def __init__(self, *cols, **kwargs):
colnames, table = self._normalize_columns(cols)
table = kwargs.pop('table', table)
super(PrimaryKeyConstraint, self).__init__(*colnames, **kwargs)
if table is not None:
self._set_parent(table)
def autoname(self):
"""Mimic the database's automatic constraint names"""
return "%s_pkey" % self.table.name
class ForeignKeyConstraint(ConstraintChangeset, schema.ForeignKeyConstraint):
"""Construct ForeignKeyConstraint
Migrate's additional parameters:
:param columns: Columns in constraint
:param refcolumns: Columns that this FK reffers to in another table.
:param table: If columns are passed as strings, this kw is required
:type table: Table instance
:type columns: list of strings or Column instances
:type refcolumns: list of strings or Column instances
"""
__migrate_visit_name__ = 'migrate_foreign_key_constraint'
def __init__(self, columns, refcolumns, *args, **kwargs):
colnames, table = self._normalize_columns(columns)
table = kwargs.pop('table', table)
refcolnames, reftable = self._normalize_columns(refcolumns,
table_name=True)
super(ForeignKeyConstraint, self).__init__(colnames, refcolnames, *args,
**kwargs)
if table is not None:
self._set_parent(table)
@property
def referenced(self):
return [e.column for e in self.elements]
@property
def reftable(self):
return self.referenced[0].table
def autoname(self):
"""Mimic the database's automatic constraint names"""
if hasattr(self.columns, 'keys'):
# SA <= 0.5
firstcol = self.columns[self.columns.keys()[0]]
ret = "%(table)s_%(firstcolumn)s_fkey" % dict(
table=firstcol.table.name,
firstcolumn=firstcol.name,)
else:
# SA >= 0.6
ret = "%(table)s_%(firstcolumn)s_fkey" % dict(
table=self.table.name,
firstcolumn=self.columns[0],)
return ret
class CheckConstraint(ConstraintChangeset, schema.CheckConstraint):
"""Construct CheckConstraint
Migrate's additional parameters:
:param sqltext: Plain SQL text to check condition
:param columns: If not name is applied, you must supply this kw\
to autoname constraint
:param table: If columns are passed as strings, this kw is required
:type table: Table instance
:type columns: list of Columns instances
:type sqltext: string
"""
__migrate_visit_name__ = 'migrate_check_constraint'
def __init__(self, sqltext, *args, **kwargs):
cols = kwargs.pop('columns', [])
if not cols and not kwargs.get('name', False):
raise InvalidConstraintError('You must either set "name"'
'parameter or "columns" to autogenarate it.')
colnames, table = self._normalize_columns(cols)
table = kwargs.pop('table', table)
schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs)
if table is not None:
self._set_parent(table)
self.colnames = colnames
def autoname(self):
return "%(table)s_%(cols)s_check" % \
dict(table=self.table.name, cols="_".join(self.colnames))
class UniqueConstraint(ConstraintChangeset, schema.UniqueConstraint):
"""Construct UniqueConstraint
Migrate's additional parameters:
:param cols: Columns in constraint.
:param table: If columns are passed as strings, this kw is required
:type table: Table instance
:type cols: strings or Column instances
.. versionadded:: 0.6.0
"""
__migrate_visit_name__ = 'migrate_unique_constraint'
def __init__(self, *cols, **kwargs):
self.colnames, table = self._normalize_columns(cols)
table = kwargs.pop('table', table)
super(UniqueConstraint, self).__init__(*self.colnames, **kwargs)
if table is not None:
self._set_parent(table)
def autoname(self):
"""Mimic the database's automatic constraint names"""
return "%s_%s_key" % (self.table.name, self.colnames[0])

View File

@@ -1,10 +0,0 @@
"""
This module contains database dialect specific changeset
implementations.
"""
__all__ = [
'postgres',
'sqlite',
'mysql',
'oracle',
]

View File

@@ -1,93 +0,0 @@
"""
Firebird database specific implementations of changeset classes.
"""
from sqlalchemy.databases import firebird as sa_base
from sqlalchemy.schema import PrimaryKeyConstraint
from migrate import exceptions
from migrate.changeset import ansisql
FBSchemaGenerator = sa_base.FBDDLCompiler
class FBColumnGenerator(FBSchemaGenerator, ansisql.ANSIColumnGenerator):
"""Firebird column generator implementation."""
class FBColumnDropper(ansisql.ANSIColumnDropper):
"""Firebird column dropper implementation."""
def visit_column(self, column):
"""Firebird supports 'DROP col' instead of 'DROP COLUMN col' syntax
Drop primary key and unique constraints if dropped column is referencing it."""
if column.primary_key:
if column.table.primary_key.columns.contains_column(column):
column.table.primary_key.drop()
# TODO: recreate primary key if it references more than this column
for index in column.table.indexes:
# "column in index.columns" causes problems as all
# column objects compare equal and return a SQL expression
if column.name in [col.name for col in index.columns]:
index.drop()
# TODO: recreate index if it references more than this column
for cons in column.table.constraints:
if isinstance(cons,PrimaryKeyConstraint):
# will be deleted only when the column its on
# is deleted!
continue
should_drop = column.name in cons.columns
if should_drop:
self.start_alter_table(column)
self.append("DROP CONSTRAINT ")
self.append(self.preparer.format_constraint(cons))
self.execute()
# TODO: recreate unique constraint if it refenrences more than this column
self.start_alter_table(column)
self.append('DROP %s' % self.preparer.format_column(column))
self.execute()
class FBSchemaChanger(ansisql.ANSISchemaChanger):
"""Firebird schema changer implementation."""
def visit_table(self, table):
"""Rename table not supported"""
raise exceptions.NotSupportedError(
"Firebird does not support renaming tables.")
def _visit_column_name(self, table, column, delta):
self.start_alter_table(table)
col_name = self.preparer.quote(delta.current_name, table.quote)
new_name = self.preparer.format_column(delta.result_column)
self.append('ALTER COLUMN %s TO %s' % (col_name, new_name))
def _visit_column_nullable(self, table, column, delta):
"""Changing NULL is not supported"""
# TODO: http://www.firebirdfaq.org/faq103/
raise exceptions.NotSupportedError(
"Firebird does not support altering NULL bevahior.")
class FBConstraintGenerator(ansisql.ANSIConstraintGenerator):
"""Firebird constraint generator implementation."""
class FBConstraintDropper(ansisql.ANSIConstraintDropper):
"""Firebird constaint dropper implementation."""
def cascade_constraint(self, constraint):
"""Cascading constraints is not supported"""
raise exceptions.NotSupportedError(
"Firebird does not support cascading constraints")
class FBDialect(ansisql.ANSIDialect):
columngenerator = FBColumnGenerator
columndropper = FBColumnDropper
schemachanger = FBSchemaChanger
constraintgenerator = FBConstraintGenerator
constraintdropper = FBConstraintDropper

View File

@@ -1,65 +0,0 @@
"""
MySQL database specific implementations of changeset classes.
"""
from sqlalchemy.databases import mysql as sa_base
from sqlalchemy import types as sqltypes
from migrate import exceptions
from migrate.changeset import ansisql
MySQLSchemaGenerator = sa_base.MySQLDDLCompiler
class MySQLColumnGenerator(MySQLSchemaGenerator, ansisql.ANSIColumnGenerator):
pass
class MySQLColumnDropper(ansisql.ANSIColumnDropper):
pass
class MySQLSchemaChanger(MySQLSchemaGenerator, ansisql.ANSISchemaChanger):
def visit_column(self, delta):
table = delta.table
colspec = self.get_column_specification(delta.result_column)
if delta.result_column.autoincrement:
primary_keys = [c for c in table.primary_key.columns
if (c.autoincrement and
isinstance(c.type, sqltypes.Integer) and
not c.foreign_keys)]
if primary_keys:
first = primary_keys.pop(0)
if first.name == delta.current_name:
colspec += " AUTO_INCREMENT"
old_col_name = self.preparer.quote(delta.current_name, table.quote)
self.start_alter_table(table)
self.append("CHANGE COLUMN %s " % old_col_name)
self.append(colspec)
self.execute()
def visit_index(self, param):
# If MySQL can do this, I can't find how
raise exceptions.NotSupportedError("MySQL cannot rename indexes")
class MySQLConstraintGenerator(ansisql.ANSIConstraintGenerator):
pass
class MySQLConstraintDropper(MySQLSchemaGenerator, ansisql.ANSIConstraintDropper):
def visit_migrate_check_constraint(self, *p, **k):
raise exceptions.NotSupportedError("MySQL does not support CHECK"
" constraints, use triggers instead.")
class MySQLDialect(ansisql.ANSIDialect):
columngenerator = MySQLColumnGenerator
columndropper = MySQLColumnDropper
schemachanger = MySQLSchemaChanger
constraintgenerator = MySQLConstraintGenerator
constraintdropper = MySQLConstraintDropper

View File

@@ -1,108 +0,0 @@
"""
Oracle database specific implementations of changeset classes.
"""
import sqlalchemy as sa
from sqlalchemy.databases import oracle as sa_base
from migrate import exceptions
from migrate.changeset import ansisql
OracleSchemaGenerator = sa_base.OracleDDLCompiler
class OracleColumnGenerator(OracleSchemaGenerator, ansisql.ANSIColumnGenerator):
pass
class OracleColumnDropper(ansisql.ANSIColumnDropper):
pass
class OracleSchemaChanger(OracleSchemaGenerator, ansisql.ANSISchemaChanger):
def get_column_specification(self, column, **kwargs):
# Ignore the NOT NULL generated
override_nullable = kwargs.pop('override_nullable', None)
if override_nullable:
orig = column.nullable
column.nullable = True
ret = super(OracleSchemaChanger, self).get_column_specification(
column, **kwargs)
if override_nullable:
column.nullable = orig
return ret
def visit_column(self, delta):
keys = delta.keys()
if 'name' in keys:
self._run_subvisit(delta,
self._visit_column_name,
start_alter=False)
if len(set(('type', 'nullable', 'server_default')).intersection(keys)):
self._run_subvisit(delta,
self._visit_column_change,
start_alter=False)
def _visit_column_change(self, table, column, delta):
# Oracle cannot drop a default once created, but it can set it
# to null. We'll do that if default=None
# http://forums.oracle.com/forums/message.jspa?messageID=1273234#1273234
dropdefault_hack = (column.server_default is None \
and 'server_default' in delta.keys())
# Oracle apparently doesn't like it when we say "not null" if
# the column's already not null. Fudge it, so we don't need a
# new function
notnull_hack = ((not column.nullable) \
and ('nullable' not in delta.keys()))
# We need to specify NULL if we're removing a NOT NULL
# constraint
null_hack = (column.nullable and ('nullable' in delta.keys()))
if dropdefault_hack:
column.server_default = sa.PassiveDefault(sa.sql.null())
if notnull_hack:
column.nullable = True
colspec = self.get_column_specification(column,
override_nullable=null_hack)
if null_hack:
colspec += ' NULL'
if notnull_hack:
column.nullable = False
if dropdefault_hack:
column.server_default = None
self.start_alter_table(table)
self.append("MODIFY (")
self.append(colspec)
self.append(")")
class OracleConstraintCommon(object):
def get_constraint_name(self, cons):
# Oracle constraints can't guess their name like other DBs
if not cons.name:
raise exceptions.NotSupportedError(
"Oracle constraint names must be explicitly stated")
return cons.name
class OracleConstraintGenerator(OracleConstraintCommon,
ansisql.ANSIConstraintGenerator):
pass
class OracleConstraintDropper(OracleConstraintCommon,
ansisql.ANSIConstraintDropper):
pass
class OracleDialect(ansisql.ANSIDialect):
columngenerator = OracleColumnGenerator
columndropper = OracleColumnDropper
schemachanger = OracleSchemaChanger
constraintgenerator = OracleConstraintGenerator
constraintdropper = OracleConstraintDropper

View File

@@ -1,42 +0,0 @@
"""
`PostgreSQL`_ database specific implementations of changeset classes.
.. _`PostgreSQL`: http://www.postgresql.org/
"""
from migrate.changeset import ansisql
from sqlalchemy.databases import postgresql as sa_base
PGSchemaGenerator = sa_base.PGDDLCompiler
class PGColumnGenerator(PGSchemaGenerator, ansisql.ANSIColumnGenerator):
"""PostgreSQL column generator implementation."""
pass
class PGColumnDropper(ansisql.ANSIColumnDropper):
"""PostgreSQL column dropper implementation."""
pass
class PGSchemaChanger(ansisql.ANSISchemaChanger):
"""PostgreSQL schema changer implementation."""
pass
class PGConstraintGenerator(ansisql.ANSIConstraintGenerator):
"""PostgreSQL constraint generator implementation."""
pass
class PGConstraintDropper(ansisql.ANSIConstraintDropper):
"""PostgreSQL constaint dropper implementation."""
pass
class PGDialect(ansisql.ANSIDialect):
columngenerator = PGColumnGenerator
columndropper = PGColumnDropper
schemachanger = PGSchemaChanger
constraintgenerator = PGConstraintGenerator
constraintdropper = PGConstraintDropper

View File

@@ -1,153 +0,0 @@
"""
`SQLite`_ database specific implementations of changeset classes.
.. _`SQLite`: http://www.sqlite.org/
"""
from UserDict import DictMixin
from copy import copy
from sqlalchemy.databases import sqlite as sa_base
from migrate import exceptions
from migrate.changeset import ansisql
SQLiteSchemaGenerator = sa_base.SQLiteDDLCompiler
class SQLiteCommon(object):
def _not_supported(self, op):
raise exceptions.NotSupportedError("SQLite does not support "
"%s; see http://www.sqlite.org/lang_altertable.html" % op)
class SQLiteHelper(SQLiteCommon):
def recreate_table(self,table,column=None,delta=None):
table_name = self.preparer.format_table(table)
# we remove all indexes so as not to have
# problems during copy and re-create
for index in table.indexes:
index.drop()
self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name)
self.execute()
insertion_string = self._modify_table(table, column, delta)
table.create(bind=self.connection)
self.append(insertion_string % {'table_name': table_name})
self.execute()
self.append('DROP TABLE migration_tmp')
self.execute()
def visit_column(self, delta):
if isinstance(delta, DictMixin):
column = delta.result_column
table = self._to_table(delta.table)
else:
column = delta
table = self._to_table(column.table)
self.recreate_table(table,column,delta)
class SQLiteColumnGenerator(SQLiteSchemaGenerator,
ansisql.ANSIColumnGenerator,
# at the end so we get the normal
# visit_column by default
SQLiteHelper,
SQLiteCommon
):
"""SQLite ColumnGenerator"""
def _modify_table(self, table, column, delta):
columns = ' ,'.join(map(
self.preparer.format_column,
[c for c in table.columns if c.name!=column.name]))
return ('INSERT INTO %%(table_name)s (%(cols)s) '
'SELECT %(cols)s from migration_tmp')%{'cols':columns}
def visit_column(self,column):
if column.foreign_keys:
SQLiteHelper.visit_column(self,column)
else:
super(SQLiteColumnGenerator,self).visit_column(column)
class SQLiteColumnDropper(SQLiteHelper, ansisql.ANSIColumnDropper):
"""SQLite ColumnDropper"""
def _modify_table(self, table, column, delta):
columns = ' ,'.join(map(self.preparer.format_column, table.columns))
return 'INSERT INTO %(table_name)s SELECT ' + columns + \
' from migration_tmp'
def visit_column(self,column):
# For SQLite, we *have* to remove the column here so the table
# is re-created properly.
column.remove_from_table(column.table,unset_table=False)
super(SQLiteColumnDropper,self).visit_column(column)
class SQLiteSchemaChanger(SQLiteHelper, ansisql.ANSISchemaChanger):
"""SQLite SchemaChanger"""
def _modify_table(self, table, column, delta):
return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'
def visit_index(self, index):
"""Does not support ALTER INDEX"""
self._not_supported('ALTER INDEX')
class SQLiteConstraintGenerator(ansisql.ANSIConstraintGenerator, SQLiteHelper, SQLiteCommon):
def visit_migrate_primary_key_constraint(self, constraint):
tmpl = "CREATE UNIQUE INDEX %s ON %s ( %s )"
cols = ', '.join(map(self.preparer.format_column, constraint.columns))
tname = self.preparer.format_table(constraint.table)
name = self.get_constraint_name(constraint)
msg = tmpl % (name, tname, cols)
self.append(msg)
self.execute()
def _modify_table(self, table, column, delta):
return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'
def visit_migrate_foreign_key_constraint(self, *p, **k):
self.recreate_table(p[0].table)
def visit_migrate_unique_constraint(self, *p, **k):
self.recreate_table(p[0].table)
class SQLiteConstraintDropper(ansisql.ANSIColumnDropper,
SQLiteCommon,
ansisql.ANSIConstraintCommon):
def visit_migrate_primary_key_constraint(self, constraint):
tmpl = "DROP INDEX %s "
name = self.get_constraint_name(constraint)
msg = tmpl % (name)
self.append(msg)
self.execute()
def visit_migrate_foreign_key_constraint(self, *p, **k):
self._not_supported('ALTER TABLE DROP CONSTRAINT')
def visit_migrate_check_constraint(self, *p, **k):
self._not_supported('ALTER TABLE DROP CONSTRAINT')
def visit_migrate_unique_constraint(self, *p, **k):
self._not_supported('ALTER TABLE DROP CONSTRAINT')
# TODO: technically primary key is a NOT NULL + UNIQUE constraint, should add NOT NULL to index
class SQLiteDialect(ansisql.ANSIDialect):
columngenerator = SQLiteColumnGenerator
columndropper = SQLiteColumnDropper
schemachanger = SQLiteSchemaChanger
constraintgenerator = SQLiteConstraintGenerator
constraintdropper = SQLiteConstraintDropper

View File

@@ -1,78 +0,0 @@
"""
Module for visitor class mapping.
"""
import sqlalchemy as sa
from migrate.changeset import ansisql
from migrate.changeset.databases import (sqlite,
postgres,
mysql,
oracle,
firebird)
# Map SA dialects to the corresponding Migrate extensions
DIALECTS = {
"default": ansisql.ANSIDialect,
"sqlite": sqlite.SQLiteDialect,
"postgres": postgres.PGDialect,
"postgresql": postgres.PGDialect,
"mysql": mysql.MySQLDialect,
"oracle": oracle.OracleDialect,
"firebird": firebird.FBDialect,
}
def get_engine_visitor(engine, name):
"""
Get the visitor implementation for the given database engine.
:param engine: SQLAlchemy Engine
:param name: Name of the visitor
:type name: string
:type engine: Engine
:returns: visitor
"""
# TODO: link to supported visitors
return get_dialect_visitor(engine.dialect, name)
def get_dialect_visitor(sa_dialect, name):
"""
Get the visitor implementation for the given dialect.
Finds the visitor implementation based on the dialect class and
returns and instance initialized with the given name.
Binds dialect specific preparer to visitor.
"""
# map sa dialect to migrate dialect and return visitor
sa_dialect_name = getattr(sa_dialect, 'name', 'default')
migrate_dialect_cls = DIALECTS[sa_dialect_name]
visitor = getattr(migrate_dialect_cls, name)
# bind preparer
visitor.preparer = sa_dialect.preparer(sa_dialect)
return visitor
def run_single_visitor(engine, visitorcallable, element,
connection=None, **kwargs):
"""Taken from :meth:`sqlalchemy.engine.base.Engine._run_single_visitor`
with support for migrate visitors.
"""
if connection is None:
conn = engine.contextual_connect(close_with_result=False)
else:
conn = connection
visitor = visitorcallable(engine.dialect, conn)
try:
if hasattr(element, '__migrate_visit_name__'):
fn = getattr(visitor, 'visit_' + element.__migrate_visit_name__)
else:
fn = getattr(visitor, 'visit_' + element.__visit_name__)
fn(element, **kwargs)
finally:
if connection is None:
conn.close()

View File

@@ -1,655 +0,0 @@
"""
Schema module providing common schema operations.
"""
import warnings
from UserDict import DictMixin
import sqlalchemy
from sqlalchemy.schema import ForeignKeyConstraint
from sqlalchemy.schema import UniqueConstraint
from migrate.exceptions import *
from migrate.changeset import SQLA_07
from migrate.changeset.databases.visitor import (get_engine_visitor,
run_single_visitor)
__all__ = [
'create_column',
'drop_column',
'alter_column',
'rename_table',
'rename_index',
'ChangesetTable',
'ChangesetColumn',
'ChangesetIndex',
'ChangesetDefaultClause',
'ColumnDelta',
]
def create_column(column, table=None, *p, **kw):
"""Create a column, given the table.
API to :meth:`ChangesetColumn.create`.
"""
if table is not None:
return table.create_column(column, *p, **kw)
return column.create(*p, **kw)
def drop_column(column, table=None, *p, **kw):
"""Drop a column, given the table.
API to :meth:`ChangesetColumn.drop`.
"""
if table is not None:
return table.drop_column(column, *p, **kw)
return column.drop(*p, **kw)
def rename_table(table, name, engine=None, **kw):
"""Rename a table.
If Table instance is given, engine is not used.
API to :meth:`ChangesetTable.rename`.
:param table: Table to be renamed.
:param name: New name for Table.
:param engine: Engine instance.
:type table: string or Table instance
:type name: string
:type engine: obj
"""
table = _to_table(table, engine)
table.rename(name, **kw)
def rename_index(index, name, table=None, engine=None, **kw):
"""Rename an index.
If Index instance is given,
table and engine are not used.
API to :meth:`ChangesetIndex.rename`.
:param index: Index to be renamed.
:param name: New name for index.
:param table: Table to which Index is reffered.
:param engine: Engine instance.
:type index: string or Index instance
:type name: string
:type table: string or Table instance
:type engine: obj
"""
index = _to_index(index, table, engine)
index.rename(name, **kw)
def alter_column(*p, **k):
"""Alter a column.
This is a helper function that creates a :class:`ColumnDelta` and
runs it.
:argument column:
The name of the column to be altered or a
:class:`ChangesetColumn` column representing it.
:param table:
A :class:`~sqlalchemy.schema.Table` or table name to
for the table where the column will be changed.
:param engine:
The :class:`~sqlalchemy.engine.base.Engine` to use for table
reflection and schema alterations.
:returns: A :class:`ColumnDelta` instance representing the change.
"""
if 'table' not in k and isinstance(p[0], sqlalchemy.Column):
k['table'] = p[0].table
if 'engine' not in k:
k['engine'] = k['table'].bind
# deprecation
if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
warnings.warn(
"Passing a Column object to alter_column is deprecated."
" Just pass in keyword parameters instead.",
MigrateDeprecationWarning
)
engine = k['engine']
# enough tests seem to break when metadata is always altered
# that this crutch has to be left in until they can be sorted
# out
k['alter_metadata']=True
delta = ColumnDelta(*p, **k)
visitorcallable = get_engine_visitor(engine, 'schemachanger')
engine._run_visitor(visitorcallable, delta)
return delta
def _to_table(table, engine=None):
"""Return if instance of Table, else construct new with metadata"""
if isinstance(table, sqlalchemy.Table):
return table
# Given: table name, maybe an engine
meta = sqlalchemy.MetaData()
if engine is not None:
meta.bind = engine
return sqlalchemy.Table(table, meta)
def _to_index(index, table=None, engine=None):
"""Return if instance of Index, else construct new with metadata"""
if isinstance(index, sqlalchemy.Index):
return index
# Given: index name; table name required
table = _to_table(table, engine)
ret = sqlalchemy.Index(index)
ret.table = table
return ret
class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
"""Extracts the differences between two columns/column-parameters
May receive parameters arranged in several different ways:
* **current_column, new_column, \*p, \*\*kw**
Additional parameters can be specified to override column
differences.
* **current_column, \*p, \*\*kw**
Additional parameters alter current_column. Table name is extracted
from current_column object.
Name is changed to current_column.name from current_name,
if current_name is specified.
* **current_col_name, \*p, \*\*kw**
Table kw must specified.
:param table: Table at which current Column should be bound to.\
If table name is given, reflection will be used.
:type table: string or Table instance
:param metadata: A :class:`MetaData` instance to store
reflected table names
:param engine: When reflecting tables, either engine or metadata must \
be specified to acquire engine object.
:type engine: :class:`Engine` instance
:returns: :class:`ColumnDelta` instance provides interface for altered attributes to \
`result_column` through :func:`dict` alike object.
* :class:`ColumnDelta`.result_column is altered column with new attributes
* :class:`ColumnDelta`.current_name is current name of column in db
"""
# Column attributes that can be altered
diff_keys = ('name', 'type', 'primary_key', 'nullable',
'server_onupdate', 'server_default', 'autoincrement')
diffs = dict()
__visit_name__ = 'column'
def __init__(self, *p, **kw):
# 'alter_metadata' is not a public api. It exists purely
# as a crutch until the tests that fail when 'alter_metadata'
# behaviour always happens can be sorted out
self.alter_metadata = kw.pop("alter_metadata", False)
self.meta = kw.pop("metadata", None)
self.engine = kw.pop("engine", None)
# Things are initialized differently depending on how many column
# parameters are given. Figure out how many and call the appropriate
# method.
if len(p) >= 1 and isinstance(p[0], sqlalchemy.Column):
# At least one column specified
if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
# Two columns specified
diffs = self.compare_2_columns(*p, **kw)
else:
# Exactly one column specified
diffs = self.compare_1_column(*p, **kw)
else:
# Zero columns specified
if not len(p) or not isinstance(p[0], basestring):
raise ValueError("First argument must be column name")
diffs = self.compare_parameters(*p, **kw)
self.apply_diffs(diffs)
def __repr__(self):
return '<ColumnDelta altermetadata=%r, %s>' % (
self.alter_metadata,
super(ColumnDelta, self).__repr__()
)
def __getitem__(self, key):
if key not in self.keys():
raise KeyError("No such diff key, available: %s" % self.diffs )
return getattr(self.result_column, key)
def __setitem__(self, key, value):
if key not in self.keys():
raise KeyError("No such diff key, available: %s" % self.diffs )
setattr(self.result_column, key, value)
def __delitem__(self, key):
raise NotImplementedError
def keys(self):
return self.diffs.keys()
def compare_parameters(self, current_name, *p, **k):
"""Compares Column objects with reflection"""
self.table = k.pop('table')
self.result_column = self._table.c.get(current_name)
if len(p):
k = self._extract_parameters(p, k, self.result_column)
return k
def compare_1_column(self, col, *p, **k):
"""Compares one Column object"""
self.table = k.pop('table', None)
if self.table is None:
self.table = col.table
self.result_column = col
if len(p):
k = self._extract_parameters(p, k, self.result_column)
return k
def compare_2_columns(self, old_col, new_col, *p, **k):
"""Compares two Column objects"""
self.process_column(new_col)
self.table = k.pop('table', None)
# we cannot use bool() on table in SA06
if self.table is None:
self.table = old_col.table
if self.table is None:
new_col.table
self.result_column = old_col
# set differences
# leave out some stuff for later comp
for key in (set(self.diff_keys) - set(('type',))):
val = getattr(new_col, key, None)
if getattr(self.result_column, key, None) != val:
k.setdefault(key, val)
# inspect types
if not self.are_column_types_eq(self.result_column.type, new_col.type):
k.setdefault('type', new_col.type)
if len(p):
k = self._extract_parameters(p, k, self.result_column)
return k
def apply_diffs(self, diffs):
"""Populate dict and column object with new values"""
self.diffs = diffs
for key in self.diff_keys:
if key in diffs:
setattr(self.result_column, key, diffs[key])
self.process_column(self.result_column)
# create an instance of class type if not yet
if 'type' in diffs and callable(self.result_column.type):
self.result_column.type = self.result_column.type()
# add column to the table
if self.table is not None and self.alter_metadata:
self.result_column.add_to_table(self.table)
def are_column_types_eq(self, old_type, new_type):
"""Compares two types to be equal"""
ret = old_type.__class__ == new_type.__class__
# String length is a special case
if ret and isinstance(new_type, sqlalchemy.types.String):
ret = (getattr(old_type, 'length', None) == \
getattr(new_type, 'length', None))
return ret
def _extract_parameters(self, p, k, column):
"""Extracts data from p and modifies diffs"""
p = list(p)
while len(p):
if isinstance(p[0], basestring):
k.setdefault('name', p.pop(0))
elif isinstance(p[0], sqlalchemy.types.AbstractType):
k.setdefault('type', p.pop(0))
elif callable(p[0]):
p[0] = p[0]()
else:
break
if len(p):
new_col = column.copy_fixed()
new_col._init_items(*p)
k = self.compare_2_columns(column, new_col, **k)
return k
def process_column(self, column):
"""Processes default values for column"""
# XXX: this is a snippet from SA processing of positional parameters
toinit = list()
if column.server_default is not None:
if isinstance(column.server_default, sqlalchemy.FetchedValue):
toinit.append(column.server_default)
else:
toinit.append(sqlalchemy.DefaultClause(column.server_default))
if column.server_onupdate is not None:
if isinstance(column.server_onupdate, FetchedValue):
toinit.append(column.server_default)
else:
toinit.append(sqlalchemy.DefaultClause(column.server_onupdate,
for_update=True))
if toinit:
column._init_items(*toinit)
def _get_table(self):
return getattr(self, '_table', None)
def _set_table(self, table):
if isinstance(table, basestring):
if self.alter_metadata:
if not self.meta:
raise ValueError("metadata must be specified for table"
" reflection when using alter_metadata")
meta = self.meta
if self.engine:
meta.bind = self.engine
else:
if not self.engine and not self.meta:
raise ValueError("engine or metadata must be specified"
" to reflect tables")
if not self.engine:
self.engine = self.meta.bind
meta = sqlalchemy.MetaData(bind=self.engine)
self._table = sqlalchemy.Table(table, meta, autoload=True)
elif isinstance(table, sqlalchemy.Table):
self._table = table
if not self.alter_metadata:
self._table.meta = sqlalchemy.MetaData(bind=self._table.bind)
def _get_result_column(self):
return getattr(self, '_result_column', None)
def _set_result_column(self, column):
"""Set Column to Table based on alter_metadata evaluation."""
self.process_column(column)
if not hasattr(self, 'current_name'):
self.current_name = column.name
if self.alter_metadata:
self._result_column = column
else:
self._result_column = column.copy_fixed()
table = property(_get_table, _set_table)
result_column = property(_get_result_column, _set_result_column)
class ChangesetTable(object):
"""Changeset extensions to SQLAlchemy tables."""
def create_column(self, column, *p, **kw):
"""Creates a column.
The column parameter may be a column definition or the name of
a column in this table.
API to :meth:`ChangesetColumn.create`
:param column: Column to be created
:type column: Column instance or string
"""
if not isinstance(column, sqlalchemy.Column):
# It's a column name
column = getattr(self.c, str(column))
column.create(table=self, *p, **kw)
def drop_column(self, column, *p, **kw):
"""Drop a column, given its name or definition.
API to :meth:`ChangesetColumn.drop`
:param column: Column to be droped
:type column: Column instance or string
"""
if not isinstance(column, sqlalchemy.Column):
# It's a column name
try:
column = getattr(self.c, str(column))
except AttributeError:
# That column isn't part of the table. We don't need
# its entire definition to drop the column, just its
# name, so create a dummy column with the same name.
column = sqlalchemy.Column(str(column), sqlalchemy.Integer())
column.drop(table=self, *p, **kw)
def rename(self, name, connection=None, **kwargs):
"""Rename this table.
:param name: New name of the table.
:type name: string
:param connection: reuse connection istead of creating new one.
:type connection: :class:`sqlalchemy.engine.base.Connection` instance
"""
engine = self.bind
self.new_name = name
visitorcallable = get_engine_visitor(engine, 'schemachanger')
run_single_visitor(engine, visitorcallable, self, connection, **kwargs)
# Fix metadata registration
self.name = name
self.deregister()
self._set_parent(self.metadata)
def _meta_key(self):
"""Get the meta key for this table."""
return sqlalchemy.schema._get_table_key(self.name, self.schema)
def deregister(self):
"""Remove this table from its metadata"""
if SQLA_07:
self.metadata._remove_table(self.name, self.schema)
else:
key = self._meta_key()
meta = self.metadata
if key in meta.tables:
del meta.tables[key]
class ChangesetColumn(object):
"""Changeset extensions to SQLAlchemy columns."""
def alter(self, *p, **k):
"""Makes a call to :func:`alter_column` for the column this
method is called on.
"""
if 'table' not in k:
k['table'] = self.table
if 'engine' not in k:
k['engine'] = k['table'].bind
return alter_column(self, *p, **k)
def create(self, table=None, index_name=None, unique_name=None,
primary_key_name=None, populate_default=True, connection=None, **kwargs):
"""Create this column in the database.
Assumes the given table exists. ``ALTER TABLE ADD COLUMN``,
for most databases.
:param table: Table instance to create on.
:param index_name: Creates :class:`ChangesetIndex` on this column.
:param unique_name: Creates :class:\
`~migrate.changeset.constraint.UniqueConstraint` on this column.
:param primary_key_name: Creates :class:\
`~migrate.changeset.constraint.PrimaryKeyConstraint` on this column.
:param populate_default: If True, created column will be \
populated with defaults
:param connection: reuse connection istead of creating new one.
:type table: Table instance
:type index_name: string
:type unique_name: string
:type primary_key_name: string
:type populate_default: bool
:type connection: :class:`sqlalchemy.engine.base.Connection` instance
:returns: self
"""
self.populate_default = populate_default
self.index_name = index_name
self.unique_name = unique_name
self.primary_key_name = primary_key_name
for cons in ('index_name', 'unique_name', 'primary_key_name'):
self._check_sanity_constraints(cons)
self.add_to_table(table)
engine = self.table.bind
visitorcallable = get_engine_visitor(engine, 'columngenerator')
engine._run_visitor(visitorcallable, self, connection, **kwargs)
# TODO: reuse existing connection
if self.populate_default and self.default is not None:
stmt = table.update().values({self: engine._execute_default(self.default)})
engine.execute(stmt)
return self
def drop(self, table=None, connection=None, **kwargs):
"""Drop this column from the database, leaving its table intact.
``ALTER TABLE DROP COLUMN``, for most databases.
:param connection: reuse connection istead of creating new one.
:type connection: :class:`sqlalchemy.engine.base.Connection` instance
"""
if table is not None:
self.table = table
engine = self.table.bind
visitorcallable = get_engine_visitor(engine, 'columndropper')
engine._run_visitor(visitorcallable, self, connection, **kwargs)
self.remove_from_table(self.table, unset_table=False)
self.table = None
return self
def add_to_table(self, table):
if table is not None and self.table is None:
if SQLA_07:
table.append_column(self)
else:
self._set_parent(table)
def _col_name_in_constraint(self,cons,name):
return False
def remove_from_table(self, table, unset_table=True):
# TODO: remove primary keys, constraints, etc
if unset_table:
self.table = None
to_drop = set()
for index in table.indexes:
columns = []
for col in index.columns:
if col.name!=self.name:
columns.append(col)
if columns:
index.columns=columns
else:
to_drop.add(index)
table.indexes = table.indexes - to_drop
to_drop = set()
for cons in table.constraints:
# TODO: deal with other types of constraint
if isinstance(cons,(ForeignKeyConstraint,
UniqueConstraint)):
for col_name in cons.columns:
if not isinstance(col_name,basestring):
col_name = col_name.name
if self.name==col_name:
to_drop.add(cons)
table.constraints = table.constraints - to_drop
if table.c.contains_column(self):
if SQLA_07:
table._columns.remove(self)
else:
table.c.remove(self)
# TODO: this is fixed in 0.6
def copy_fixed(self, **kw):
"""Create a copy of this ``Column``, with all attributes."""
return sqlalchemy.Column(self.name, self.type, self.default,
key=self.key,
primary_key=self.primary_key,
nullable=self.nullable,
quote=self.quote,
index=self.index,
unique=self.unique,
onupdate=self.onupdate,
autoincrement=self.autoincrement,
server_default=self.server_default,
server_onupdate=self.server_onupdate,
*[c.copy(**kw) for c in self.constraints])
def _check_sanity_constraints(self, name):
"""Check if constraints names are correct"""
obj = getattr(self, name)
if (getattr(self, name[:-5]) and not obj):
raise InvalidConstraintError("Column.create() accepts index_name,"
" primary_key_name and unique_name to generate constraints")
if not isinstance(obj, basestring) and obj is not None:
raise InvalidConstraintError(
"%s argument for column must be constraint name" % name)
class ChangesetIndex(object):
"""Changeset extensions to SQLAlchemy Indexes."""
__visit_name__ = 'index'
def rename(self, name, connection=None, **kwargs):
"""Change the name of an index.
:param name: New name of the Index.
:type name: string
:param connection: reuse connection istead of creating new one.
:type connection: :class:`sqlalchemy.engine.base.Connection` instance
"""
engine = self.table.bind
self.new_name = name
visitorcallable = get_engine_visitor(engine, 'schemachanger')
engine._run_visitor(visitorcallable, self, connection, **kwargs)
self.name = name
class ChangesetDefaultClause(object):
"""Implements comparison between :class:`DefaultClause` instances"""
def __eq__(self, other):
if isinstance(other, self.__class__):
if self.arg == other.arg:
return True
def __ne__(self, other):
return not self.__eq__(other)

View File

@@ -1,87 +0,0 @@
"""
Provide exception classes for :mod:`migrate`
"""
class Error(Exception):
"""Error base class."""
class ApiError(Error):
"""Base class for API errors."""
class KnownError(ApiError):
"""A known error condition."""
class UsageError(ApiError):
"""A known error condition where help should be displayed."""
class ControlledSchemaError(Error):
"""Base class for controlled schema errors."""
class InvalidVersionError(ControlledSchemaError):
"""Invalid version number."""
class DatabaseNotControlledError(ControlledSchemaError):
"""Database should be under version control, but it's not."""
class DatabaseAlreadyControlledError(ControlledSchemaError):
"""Database shouldn't be under version control, but it is"""
class WrongRepositoryError(ControlledSchemaError):
"""This database is under version control by another repository."""
class NoSuchTableError(ControlledSchemaError):
"""The table does not exist."""
class PathError(Error):
"""Base class for path errors."""
class PathNotFoundError(PathError):
"""A path with no file was required; found a file."""
class PathFoundError(PathError):
"""A path with a file was required; found no file."""
class RepositoryError(Error):
"""Base class for repository errors."""
class InvalidRepositoryError(RepositoryError):
"""Invalid repository error."""
class ScriptError(Error):
"""Base class for script errors."""
class InvalidScriptError(ScriptError):
"""Invalid script error."""
class InvalidVersionError(Error):
"""Invalid version error."""
# migrate.changeset
class NotSupportedError(Error):
"""Not supported error"""
class InvalidConstraintError(Error):
"""Invalid constraint error"""
class MigrateDeprecationWarning(DeprecationWarning):
"""Warning for deprecated features in Migrate"""

View File

@@ -1,5 +0,0 @@
"""
This package provides functionality to create and manage
repositories of database schema changesets and to apply these
changesets to databases.
"""

View File

@@ -1,384 +0,0 @@
"""
This module provides an external API to the versioning system.
.. versionchanged:: 0.6.0
:func:`migrate.versioning.api.test` and schema diff functions
changed order of positional arguments so all accept `url` and `repository`
as first arguments.
.. versionchanged:: 0.5.4
``--preview_sql`` displays source file when using SQL scripts.
If Python script is used, it runs the action with mocked engine and
returns captured SQL statements.
.. versionchanged:: 0.5.4
Deprecated ``--echo`` parameter in favour of new
:func:`migrate.versioning.util.construct_engine` behavior.
"""
# Dear migrate developers,
#
# please do not comment this module using sphinx syntax because its
# docstrings are presented as user help and most users cannot
# interpret sphinx annotated ReStructuredText.
#
# Thanks,
# Jan Dittberner
import sys
import inspect
import logging
from migrate import exceptions
from migrate.versioning import (repository, schema, version,
script as script_) # command name conflict
from migrate.versioning.util import catch_known_errors, with_engine
log = logging.getLogger(__name__)
command_desc = {
'help': 'displays help on a given command',
'create': 'create an empty repository at the specified path',
'script': 'create an empty change Python script',
'script_sql': 'create empty change SQL scripts for given database',
'version': 'display the latest version available in a repository',
'db_version': 'show the current version of the repository under version control',
'source': 'display the Python code for a particular version in this repository',
'version_control': 'mark a database as under this repository\'s version control',
'upgrade': 'upgrade a database to a later version',
'downgrade': 'downgrade a database to an earlier version',
'drop_version_control': 'removes version control from a database',
'manage': 'creates a Python script that runs Migrate with a set of default values',
'test': 'performs the upgrade and downgrade command on the given database',
'compare_model_to_db': 'compare MetaData against the current database state',
'create_model': 'dump the current database as a Python model to stdout',
'make_update_script_for_model': 'create a script changing the old MetaData to the new (current) MetaData',
'update_db_from_model': 'modify the database to match the structure of the current MetaData',
}
__all__ = command_desc.keys()
Repository = repository.Repository
ControlledSchema = schema.ControlledSchema
VerNum = version.VerNum
PythonScript = script_.PythonScript
SqlScript = script_.SqlScript
# deprecated
def help(cmd=None, **opts):
"""%prog help COMMAND
Displays help on a given command.
"""
if cmd is None:
raise exceptions.UsageError(None)
try:
func = globals()[cmd]
except:
raise exceptions.UsageError(
"'%s' isn't a valid command. Try 'help COMMAND'" % cmd)
ret = func.__doc__
if sys.argv[0]:
ret = ret.replace('%prog', sys.argv[0])
return ret
@catch_known_errors
def create(repository, name, **opts):
"""%prog create REPOSITORY_PATH NAME [--table=TABLE]
Create an empty repository at the specified path.
You can specify the version_table to be used; by default, it is
'migrate_version'. This table is created in all version-controlled
databases.
"""
repo_path = Repository.create(repository, name, **opts)
@catch_known_errors
def script(description, repository, **opts):
"""%prog script DESCRIPTION REPOSITORY_PATH
Create an empty change script using the next unused version number
appended with the given description.
For instance, manage.py script "Add initial tables" creates:
repository/versions/001_Add_initial_tables.py
"""
repo = Repository(repository)
repo.create_script(description, **opts)
@catch_known_errors
def script_sql(database, description, repository, **opts):
"""%prog script_sql DATABASE DESCRIPTION REPOSITORY_PATH
Create empty change SQL scripts for given DATABASE, where DATABASE
is either specific ('postgresql', 'mysql', 'oracle', 'sqlite', etc.)
or generic ('default').
For instance, manage.py script_sql postgresql description creates:
repository/versions/001_description_postgresql_upgrade.sql and
repository/versions/001_description_postgresql_downgrade.sql
"""
repo = Repository(repository)
repo.create_script_sql(database, description, **opts)
def version(repository, **opts):
"""%prog version REPOSITORY_PATH
Display the latest version available in a repository.
"""
repo = Repository(repository)
return repo.latest
@with_engine
def db_version(url, repository, **opts):
"""%prog db_version URL REPOSITORY_PATH
Show the current version of the repository with the given
connection string, under version control of the specified
repository.
The url should be any valid SQLAlchemy connection string.
"""
engine = opts.pop('engine')
schema = ControlledSchema(engine, repository)
return schema.version
def source(version, dest=None, repository=None, **opts):
"""%prog source VERSION [DESTINATION] --repository=REPOSITORY_PATH
Display the Python code for a particular version in this
repository. Save it to the file at DESTINATION or, if omitted,
send to stdout.
"""
if repository is None:
raise exceptions.UsageError("A repository must be specified")
repo = Repository(repository)
ret = repo.version(version).script().source()
if dest is not None:
dest = open(dest, 'w')
dest.write(ret)
dest.close()
ret = None
return ret
def upgrade(url, repository, version=None, **opts):
"""%prog upgrade URL REPOSITORY_PATH [VERSION] [--preview_py|--preview_sql]
Upgrade a database to a later version.
This runs the upgrade() function defined in your change scripts.
By default, the database is updated to the latest available
version. You may specify a version instead, if you wish.
You may preview the Python or SQL code to be executed, rather than
actually executing it, using the appropriate 'preview' option.
"""
err = "Cannot upgrade a database of version %s to version %s. "\
"Try 'downgrade' instead."
return _migrate(url, repository, version, upgrade=True, err=err, **opts)
def downgrade(url, repository, version, **opts):
"""%prog downgrade URL REPOSITORY_PATH VERSION [--preview_py|--preview_sql]
Downgrade a database to an earlier version.
This is the reverse of upgrade; this runs the downgrade() function
defined in your change scripts.
You may preview the Python or SQL code to be executed, rather than
actually executing it, using the appropriate 'preview' option.
"""
err = "Cannot downgrade a database of version %s to version %s. "\
"Try 'upgrade' instead."
return _migrate(url, repository, version, upgrade=False, err=err, **opts)
@with_engine
def test(url, repository, **opts):
"""%prog test URL REPOSITORY_PATH [VERSION]
Performs the upgrade and downgrade option on the given
database. This is not a real test and may leave the database in a
bad state. You should therefore better run the test on a copy of
your database.
"""
engine = opts.pop('engine')
repos = Repository(repository)
# Upgrade
log.info("Upgrading...")
script = repos.version(None).script(engine.name, 'upgrade')
script.run(engine, 1)
log.info("done")
log.info("Downgrading...")
script = repos.version(None).script(engine.name, 'downgrade')
script.run(engine, -1)
log.info("done")
log.info("Success")
@with_engine
def version_control(url, repository, version=None, **opts):
"""%prog version_control URL REPOSITORY_PATH [VERSION]
Mark a database as under this repository's version control.
Once a database is under version control, schema changes should
only be done via change scripts in this repository.
This creates the table version_table in the database.
The url should be any valid SQLAlchemy connection string.
By default, the database begins at version 0 and is assumed to be
empty. If the database is not empty, you may specify a version at
which to begin instead. No attempt is made to verify this
version's correctness - the database schema is expected to be
identical to what it would be if the database were created from
scratch.
"""
engine = opts.pop('engine')
ControlledSchema.create(engine, repository, version)
@with_engine
def drop_version_control(url, repository, **opts):
"""%prog drop_version_control URL REPOSITORY_PATH
Removes version control from a database.
"""
engine = opts.pop('engine')
schema = ControlledSchema(engine, repository)
schema.drop()
def manage(file, **opts):
"""%prog manage FILENAME [VARIABLES...]
Creates a script that runs Migrate with a set of default values.
For example::
%prog manage manage.py --repository=/path/to/repository \
--url=sqlite:///project.db
would create the script manage.py. The following two commands
would then have exactly the same results::
python manage.py version
%prog version --repository=/path/to/repository
"""
Repository.create_manage_file(file, **opts)
@with_engine
def compare_model_to_db(url, repository, model, **opts):
"""%prog compare_model_to_db URL REPOSITORY_PATH MODEL
Compare the current model (assumed to be a module level variable
of type sqlalchemy.MetaData) against the current database.
NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label
engine = opts.pop('engine')
return ControlledSchema.compare_model_to_db(engine, model, repository)
@with_engine
def create_model(url, repository, **opts):
"""%prog create_model URL REPOSITORY_PATH [DECLERATIVE=True]
Dump the current database as a Python model to stdout.
NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label
engine = opts.pop('engine')
declarative = opts.get('declarative', False)
return ControlledSchema.create_model(engine, repository, declarative)
@catch_known_errors
@with_engine
def make_update_script_for_model(url, repository, oldmodel, model, **opts):
"""%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH
Create a script changing the old Python model to the new (current)
Python model, sending to stdout.
NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label
engine = opts.pop('engine')
return PythonScript.make_update_script_for_model(
engine, oldmodel, model, repository, **opts)
@with_engine
def update_db_from_model(url, repository, model, **opts):
"""%prog update_db_from_model URL REPOSITORY_PATH MODEL
Modify the database to match the structure of the current Python
model. This also sets the db_version number to the latest in the
repository.
NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label
engine = opts.pop('engine')
schema = ControlledSchema(engine, repository)
schema.update_db_from_model(model)
@with_engine
def _migrate(url, repository, version, upgrade, err, **opts):
engine = opts.pop('engine')
url = str(engine.url)
schema = ControlledSchema(engine, repository)
version = _migrate_version(schema, version, upgrade, err)
changeset = schema.changeset(version)
for ver, change in changeset:
nextver = ver + changeset.step
log.info('%s -> %s... ', ver, nextver)
if opts.get('preview_sql'):
if isinstance(change, PythonScript):
log.info(change.preview_sql(url, changeset.step, **opts))
elif isinstance(change, SqlScript):
log.info(change.source())
elif opts.get('preview_py'):
if not isinstance(change, PythonScript):
raise exceptions.UsageError("Python source can be only displayed"
" for python migration files")
source_ver = max(ver, nextver)
module = schema.repository.version(source_ver).script().module
funcname = upgrade and "upgrade" or "downgrade"
func = getattr(module, funcname)
log.info(inspect.getsource(func))
else:
schema.runchange(ver, change, changeset.step)
log.info('done')
def _migrate_version(schema, version, upgrade, err):
if version is None:
return version
# Version is specified: ensure we're upgrading in the right direction
# (current version < target version for upgrading; reverse for down)
version = VerNum(version)
cur = schema.version
if upgrade is not None:
if upgrade:
direction = cur <= version
else:
direction = cur >= version
if not direction:
raise exceptions.KnownError(err % (cur, version))
return version

View File

@@ -1,27 +0,0 @@
"""
Configuration parser module.
"""
from ConfigParser import ConfigParser
from migrate.versioning.config import *
from migrate.versioning import pathed
class Parser(ConfigParser):
"""A project configuration file."""
def to_dict(self, sections=None):
"""It's easier to access config values like dictionaries"""
return self._sections
class Config(pathed.Pathed, Parser):
"""Configuration class."""
def __init__(self, path, *p, **k):
"""Confirm the config file exists; read it."""
self.require_found(path)
pathed.Pathed.__init__(self, path)
Parser.__init__(self, *p, **k)
self.read(path)

View File

@@ -1,14 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
from sqlalchemy.util import OrderedDict
__all__ = ['databases', 'operations']
databases = ('sqlite', 'postgres', 'mysql', 'oracle', 'mssql', 'firebird')
# Map operation names to function names
operations = OrderedDict()
operations['upgrade'] = 'upgrade'
operations['downgrade'] = 'downgrade'

View File

@@ -1,285 +0,0 @@
"""
Code to generate a Python model from a database or differences
between a model and database.
Some of this is borrowed heavily from the AutoCode project at:
http://code.google.com/p/sqlautocode/
"""
import sys
import logging
import sqlalchemy
import migrate
import migrate.changeset
log = logging.getLogger(__name__)
HEADER = """
## File autogenerated by genmodel.py
from sqlalchemy import *
meta = MetaData()
"""
DECLARATIVE_HEADER = """
## File autogenerated by genmodel.py
from sqlalchemy import *
from sqlalchemy.ext import declarative
Base = declarative.declarative_base()
"""
class ModelGenerator(object):
"""Various transformations from an A, B diff.
In the implementation, A tends to be called the model and B
the database (although this is not true of all diffs).
The diff is directionless, but transformations apply the diff
in a particular direction, described in the method name.
"""
def __init__(self, diff, engine, declarative=False):
self.diff = diff
self.engine = engine
self.declarative = declarative
def column_repr(self, col):
kwarg = []
if col.key != col.name:
kwarg.append('key')
if col.primary_key:
col.primary_key = True # otherwise it dumps it as 1
kwarg.append('primary_key')
if not col.nullable:
kwarg.append('nullable')
if col.onupdate:
kwarg.append('onupdate')
if col.default:
if col.primary_key:
# I found that PostgreSQL automatically creates a
# default value for the sequence, but let's not show
# that.
pass
else:
kwarg.append('default')
args = ['%s=%r' % (k, getattr(col, k)) for k in kwarg]
# crs: not sure if this is good idea, but it gets rid of extra
# u''
name = col.name.encode('utf8')
type_ = col.type
for cls in col.type.__class__.__mro__:
if cls.__module__ == 'sqlalchemy.types' and \
not cls.__name__.isupper():
if cls is not type_.__class__:
type_ = cls()
break
type_repr = repr(type_)
if type_repr.endswith('()'):
type_repr = type_repr[:-2]
constraints = [repr(cn) for cn in col.constraints]
data = {
'name': name,
'commonStuff': ', '.join([type_repr] + constraints + args),
}
if self.declarative:
return """%(name)s = Column(%(commonStuff)s)""" % data
else:
return """Column(%(name)r, %(commonStuff)s)""" % data
def _getTableDefn(self, table, metaName='meta'):
out = []
tableName = table.name
if self.declarative:
out.append("class %(table)s(Base):" % {'table': tableName})
out.append(" __tablename__ = '%(table)s'\n" %
{'table': tableName})
for col in table.columns:
out.append(" %s" % self.column_repr(col))
out.append('\n')
else:
out.append("%(table)s = Table('%(table)s', %(meta)s," %
{'table': tableName, 'meta': metaName})
for col in table.columns:
out.append(" %s," % self.column_repr(col))
out.append(")\n")
return out
def _get_tables(self,missingA=False,missingB=False,modified=False):
to_process = []
for bool_,names,metadata in (
(missingA,self.diff.tables_missing_from_A,self.diff.metadataB),
(missingB,self.diff.tables_missing_from_B,self.diff.metadataA),
(modified,self.diff.tables_different,self.diff.metadataA),
):
if bool_:
for name in names:
yield metadata.tables.get(name)
def genBDefinition(self):
"""Generates the source code for a definition of B.
Assumes a diff where A is empty.
Was: toPython. Assume database (B) is current and model (A) is empty.
"""
out = []
if self.declarative:
out.append(DECLARATIVE_HEADER)
else:
out.append(HEADER)
out.append("")
for table in self._get_tables(missingA=True):
out.extend(self._getTableDefn(table))
return '\n'.join(out)
def genB2AMigration(self, indent=' '):
'''Generate a migration from B to A.
Was: toUpgradeDowngradePython
Assume model (A) is most current and database (B) is out-of-date.
'''
decls = ['from migrate.changeset import schema',
'pre_meta = MetaData()',
'post_meta = MetaData()',
]
upgradeCommands = ['pre_meta.bind = migrate_engine',
'post_meta.bind = migrate_engine']
downgradeCommands = list(upgradeCommands)
for tn in self.diff.tables_missing_from_A:
pre_table = self.diff.metadataB.tables[tn]
decls.extend(self._getTableDefn(pre_table, metaName='pre_meta'))
upgradeCommands.append(
"pre_meta.tables[%(table)r].drop()" % {'table': tn})
downgradeCommands.append(
"pre_meta.tables[%(table)r].create()" % {'table': tn})
for tn in self.diff.tables_missing_from_B:
post_table = self.diff.metadataA.tables[tn]
decls.extend(self._getTableDefn(post_table, metaName='post_meta'))
upgradeCommands.append(
"post_meta.tables[%(table)r].create()" % {'table': tn})
downgradeCommands.append(
"post_meta.tables[%(table)r].drop()" % {'table': tn})
for (tn, td) in self.diff.tables_different.iteritems():
if td.columns_missing_from_A or td.columns_different:
pre_table = self.diff.metadataB.tables[tn]
decls.extend(self._getTableDefn(
pre_table, metaName='pre_meta'))
if td.columns_missing_from_B or td.columns_different:
post_table = self.diff.metadataA.tables[tn]
decls.extend(self._getTableDefn(
post_table, metaName='post_meta'))
for col in td.columns_missing_from_A:
upgradeCommands.append(
'pre_meta.tables[%r].columns[%r].drop()' % (tn, col))
downgradeCommands.append(
'pre_meta.tables[%r].columns[%r].create()' % (tn, col))
for col in td.columns_missing_from_B:
upgradeCommands.append(
'post_meta.tables[%r].columns[%r].create()' % (tn, col))
downgradeCommands.append(
'post_meta.tables[%r].columns[%r].drop()' % (tn, col))
for modelCol, databaseCol, modelDecl, databaseDecl in td.columns_different:
upgradeCommands.append(
'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
tn, modelCol.name, databaseCol.name))
downgradeCommands.append(
'assert False, "Can\'t alter columns: %s:%s=>%s"' % (
tn, modelCol.name, databaseCol.name))
return (
'\n'.join(decls),
'\n'.join('%s%s' % (indent, line) for line in upgradeCommands),
'\n'.join('%s%s' % (indent, line) for line in downgradeCommands))
def _db_can_handle_this_change(self,td):
"""Check if the database can handle going from B to A."""
if (td.columns_missing_from_B
and not td.columns_missing_from_A
and not td.columns_different):
# Even sqlite can handle column additions.
return True
else:
return not self.engine.url.drivername.startswith('sqlite')
def runB2A(self):
"""Goes from B to A.
Was: applyModel. Apply model (A) to current database (B).
"""
meta = sqlalchemy.MetaData(self.engine)
for table in self._get_tables(missingA=True):
table = table.tometadata(meta)
table.drop()
for table in self._get_tables(missingB=True):
table = table.tometadata(meta)
table.create()
for modelTable in self._get_tables(modified=True):
tableName = modelTable.name
modelTable = modelTable.tometadata(meta)
dbTable = self.diff.metadataB.tables[tableName]
td = self.diff.tables_different[tableName]
if self._db_can_handle_this_change(td):
for col in td.columns_missing_from_B:
modelTable.columns[col].create()
for col in td.columns_missing_from_A:
dbTable.columns[col].drop()
# XXX handle column changes here.
else:
# Sqlite doesn't support drop column, so you have to
# do more: create temp table, copy data to it, drop
# old table, create new table, copy data back.
#
# I wonder if this is guaranteed to be unique?
tempName = '_temp_%s' % modelTable.name
def getCopyStatement():
preparer = self.engine.dialect.preparer
commonCols = []
for modelCol in modelTable.columns:
if modelCol.name in dbTable.columns:
commonCols.append(modelCol.name)
commonColsStr = ', '.join(commonCols)
return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \
(tableName, commonColsStr, commonColsStr, tempName)
# Move the data in one transaction, so that we don't
# leave the database in a nasty state.
connection = self.engine.connect()
trans = connection.begin()
try:
connection.execute(
'CREATE TEMPORARY TABLE %s as SELECT * from %s' % \
(tempName, modelTable.name))
# make sure the drop takes place inside our
# transaction with the bind parameter
modelTable.drop(bind=connection)
modelTable.create(bind=connection)
connection.execute(getCopyStatement())
connection.execute('DROP TABLE %s' % tempName)
trans.commit()
except:
trans.rollback()
raise

View File

@@ -1,100 +0,0 @@
"""
Script to migrate repository from sqlalchemy <= 0.4.4 to the new
repository schema. This shouldn't use any other migrate modules, so
that it can work in any version.
"""
import os
import sys
import logging
log = logging.getLogger(__name__)
def usage():
"""Gives usage information."""
print """Usage: %(prog)s repository-to-migrate
Upgrade your repository to the new flat format.
NOTE: You should probably make a backup before running this.
""" % {'prog': sys.argv[0]}
sys.exit(1)
def delete_file(filepath):
"""Deletes a file and prints a message."""
log.info('Deleting file: %s' % filepath)
os.remove(filepath)
def move_file(src, tgt):
"""Moves a file and prints a message."""
log.info('Moving file %s to %s' % (src, tgt))
if os.path.exists(tgt):
raise Exception(
'Cannot move file %s because target %s already exists' % \
(src, tgt))
os.rename(src, tgt)
def delete_directory(dirpath):
"""Delete a directory and print a message."""
log.info('Deleting directory: %s' % dirpath)
os.rmdir(dirpath)
def migrate_repository(repos):
"""Does the actual migration to the new repository format."""
log.info('Migrating repository at: %s to new format' % repos)
versions = '%s/versions' % repos
dirs = os.listdir(versions)
# Only use int's in list.
numdirs = [int(dirname) for dirname in dirs if dirname.isdigit()]
numdirs.sort() # Sort list.
for dirname in numdirs:
origdir = '%s/%s' % (versions, dirname)
log.info('Working on directory: %s' % origdir)
files = os.listdir(origdir)
files.sort()
for filename in files:
# Delete compiled Python files.
if filename.endswith('.pyc') or filename.endswith('.pyo'):
delete_file('%s/%s' % (origdir, filename))
# Delete empty __init__.py files.
origfile = '%s/__init__.py' % origdir
if os.path.exists(origfile) and len(open(origfile).read()) == 0:
delete_file(origfile)
# Move sql upgrade scripts.
if filename.endswith('.sql'):
version, dbms, operation = filename.split('.', 3)[0:3]
origfile = '%s/%s' % (origdir, filename)
# For instance: 2.postgres.upgrade.sql ->
# 002_postgres_upgrade.sql
tgtfile = '%s/%03d_%s_%s.sql' % (
versions, int(version), dbms, operation)
move_file(origfile, tgtfile)
# Move Python upgrade script.
pyfile = '%s.py' % dirname
pyfilepath = '%s/%s' % (origdir, pyfile)
if os.path.exists(pyfilepath):
tgtfile = '%s/%03d.py' % (versions, int(dirname))
move_file(pyfilepath, tgtfile)
# Try to remove directory. Will fail if it's not empty.
delete_directory(origdir)
def main():
"""Main function to be called when using this script."""
if len(sys.argv) != 2:
usage()
migrate_repository(sys.argv[1])
if __name__ == '__main__':
main()

View File

@@ -1,75 +0,0 @@
"""
A path/directory class.
"""
import os
import shutil
import logging
from migrate import exceptions
from migrate.versioning.config import *
from migrate.versioning.util import KeyedInstance
log = logging.getLogger(__name__)
class Pathed(KeyedInstance):
"""
A class associated with a path/directory tree.
Only one instance of this class may exist for a particular file;
__new__ will return an existing instance if possible
"""
parent = None
@classmethod
def _key(cls, path):
return str(path)
def __init__(self, path):
self.path = path
if self.__class__.parent is not None:
self._init_parent(path)
def _init_parent(self, path):
"""Try to initialize this object's parent, if it has one"""
parent_path = self.__class__._parent_path(path)
self.parent = self.__class__.parent(parent_path)
log.debug("Getting parent %r:%r" % (self.__class__.parent, parent_path))
self.parent._init_child(path, self)
def _init_child(self, child, path):
"""Run when a child of this object is initialized.
Parameters: the child object; the path to this object (its
parent)
"""
@classmethod
def _parent_path(cls, path):
"""
Fetch the path of this object's parent from this object's path.
"""
# os.path.dirname(), but strip directories like files (like
# unix basename)
#
# Treat directories like files...
if path[-1] == '/':
path = path[:-1]
ret = os.path.dirname(path)
return ret
@classmethod
def require_notfound(cls, path):
"""Ensures a given path does not already exist"""
if os.path.exists(path):
raise exceptions.PathFoundError(path)
@classmethod
def require_found(cls, path):
"""Ensures a given path already exists"""
if not os.path.exists(path):
raise exceptions.PathNotFoundError(path)
def __str__(self):
return self.path

View File

@@ -1,242 +0,0 @@
"""
SQLAlchemy migrate repository management.
"""
import os
import shutil
import string
import logging
from pkg_resources import resource_filename
from tempita import Template as TempitaTemplate
from migrate import exceptions
from migrate.versioning import version, pathed, cfgparse
from migrate.versioning.template import Template
from migrate.versioning.config import *
log = logging.getLogger(__name__)
class Changeset(dict):
"""A collection of changes to be applied to a database.
Changesets are bound to a repository and manage a set of
scripts from that repository.
Behaves like a dict, for the most part. Keys are ordered based on step value.
"""
def __init__(self, start, *changes, **k):
"""
Give a start version; step must be explicitly stated.
"""
self.step = k.pop('step', 1)
self.start = version.VerNum(start)
self.end = self.start
for change in changes:
self.add(change)
def __iter__(self):
return iter(self.items())
def keys(self):
"""
In a series of upgrades x -> y, keys are version x. Sorted.
"""
ret = super(Changeset, self).keys()
# Reverse order if downgrading
ret.sort(reverse=(self.step < 1))
return ret
def values(self):
return [self[k] for k in self.keys()]
def items(self):
return zip(self.keys(), self.values())
def add(self, change):
"""Add new change to changeset"""
key = self.end
self.end += self.step
self[key] = change
def run(self, *p, **k):
"""Run the changeset scripts"""
for version, script in self:
script.run(*p, **k)
class Repository(pathed.Pathed):
"""A project's change script repository"""
_config = 'migrate.cfg'
_versions = 'versions'
def __init__(self, path):
log.debug('Loading repository %s...' % path)
self.verify(path)
super(Repository, self).__init__(path)
self.config = cfgparse.Config(os.path.join(self.path, self._config))
self.versions = version.Collection(os.path.join(self.path,
self._versions))
log.debug('Repository %s loaded successfully' % path)
log.debug('Config: %r' % self.config.to_dict())
@classmethod
def verify(cls, path):
"""
Ensure the target path is a valid repository.
:raises: :exc:`InvalidRepositoryError <migrate.exceptions.InvalidRepositoryError>`
"""
# Ensure the existence of required files
try:
cls.require_found(path)
cls.require_found(os.path.join(path, cls._config))
cls.require_found(os.path.join(path, cls._versions))
except exceptions.PathNotFoundError, e:
raise exceptions.InvalidRepositoryError(path)
@classmethod
def prepare_config(cls, tmpl_dir, name, options=None):
"""
Prepare a project configuration file for a new project.
:param tmpl_dir: Path to Repository template
:param config_file: Name of the config file in Repository template
:param name: Repository name
:type tmpl_dir: string
:type config_file: string
:type name: string
:returns: Populated config file
"""
if options is None:
options = {}
options.setdefault('version_table', 'migrate_version')
options.setdefault('repository_id', name)
options.setdefault('required_dbs', [])
options.setdefault('use_timestamp_numbering', False)
tmpl = open(os.path.join(tmpl_dir, cls._config)).read()
ret = TempitaTemplate(tmpl).substitute(options)
# cleanup
del options['__template_name__']
return ret
@classmethod
def create(cls, path, name, **opts):
"""Create a repository at a specified path"""
cls.require_notfound(path)
theme = opts.pop('templates_theme', None)
t_path = opts.pop('templates_path', None)
# Create repository
tmpl_dir = Template(t_path).get_repository(theme=theme)
shutil.copytree(tmpl_dir, path)
# Edit config defaults
config_text = cls.prepare_config(tmpl_dir, name, options=opts)
fd = open(os.path.join(path, cls._config), 'w')
fd.write(config_text)
fd.close()
opts['repository_name'] = name
# Create a management script
manager = os.path.join(path, 'manage.py')
Repository.create_manage_file(manager, templates_theme=theme,
templates_path=t_path, **opts)
return cls(path)
def create_script(self, description, **k):
"""API to :meth:`migrate.versioning.version.Collection.create_new_python_version`"""
k['use_timestamp_numbering'] = self.use_timestamp_numbering
self.versions.create_new_python_version(description, **k)
def create_script_sql(self, database, description, **k):
"""API to :meth:`migrate.versioning.version.Collection.create_new_sql_version`"""
k['use_timestamp_numbering'] = self.use_timestamp_numbering
self.versions.create_new_sql_version(database, description, **k)
@property
def latest(self):
"""API to :attr:`migrate.versioning.version.Collection.latest`"""
return self.versions.latest
@property
def version_table(self):
"""Returns version_table name specified in config"""
return self.config.get('db_settings', 'version_table')
@property
def id(self):
"""Returns repository id specified in config"""
return self.config.get('db_settings', 'repository_id')
@property
def use_timestamp_numbering(self):
"""Returns use_timestamp_numbering specified in config"""
if self.config.has_option('db_settings', 'use_timestamp_numbering'):
return self.config.getboolean('db_settings', 'use_timestamp_numbering')
return False
def version(self, *p, **k):
"""API to :attr:`migrate.versioning.version.Collection.version`"""
return self.versions.version(*p, **k)
@classmethod
def clear(cls):
# TODO: deletes repo
super(Repository, cls).clear()
version.Collection.clear()
def changeset(self, database, start, end=None):
"""Create a changeset to migrate this database from ver. start to end/latest.
:param database: name of database to generate changeset
:param start: version to start at
:param end: version to end at (latest if None given)
:type database: string
:type start: int
:type end: int
:returns: :class:`Changeset instance <migration.versioning.repository.Changeset>`
"""
start = version.VerNum(start)
if end is None:
end = self.latest
else:
end = version.VerNum(end)
if start <= end:
step = 1
range_mod = 1
op = 'upgrade'
else:
step = -1
range_mod = 0
op = 'downgrade'
versions = range(start + range_mod, end + range_mod, step)
changes = [self.version(v).script(database, op) for v in versions]
ret = Changeset(start, step=step, *changes)
return ret
@classmethod
def create_manage_file(cls, file_, **opts):
"""Create a project management script (manage.py)
:param file_: Destination file to be written
:param opts: Options that are passed to :func:`migrate.versioning.shell.main`
"""
mng_file = Template(opts.pop('templates_path', None))\
.get_manage(theme=opts.pop('templates_theme', None))
tmpl = open(mng_file).read()
fd = open(file_, 'w')
fd.write(TempitaTemplate(tmpl).substitute(opts))
fd.close()

View File

@@ -1,220 +0,0 @@
"""
Database schema version management.
"""
import sys
import logging
from sqlalchemy import (Table, Column, MetaData, String, Text, Integer,
create_engine)
from sqlalchemy.sql import and_
from sqlalchemy import exceptions as sa_exceptions
from sqlalchemy.sql import bindparam
from migrate import exceptions
from migrate.changeset import SQLA_07
from migrate.versioning import genmodel, schemadiff
from migrate.versioning.repository import Repository
from migrate.versioning.util import load_model
from migrate.versioning.version import VerNum
log = logging.getLogger(__name__)
class ControlledSchema(object):
"""A database under version control"""
def __init__(self, engine, repository):
if isinstance(repository, basestring):
repository = Repository(repository)
self.engine = engine
self.repository = repository
self.meta = MetaData(engine)
self.load()
def __eq__(self, other):
"""Compare two schemas by repositories and versions"""
return (self.repository is other.repository \
and self.version == other.version)
def load(self):
"""Load controlled schema version info from DB"""
tname = self.repository.version_table
try:
if not hasattr(self, 'table') or self.table is None:
self.table = Table(tname, self.meta, autoload=True)
result = self.engine.execute(self.table.select(
self.table.c.repository_id == str(self.repository.id)))
data = list(result)[0]
except:
cls, exc, tb = sys.exc_info()
raise exceptions.DatabaseNotControlledError, exc.__str__(), tb
self.version = data['version']
return data
def drop(self):
"""
Remove version control from a database.
"""
if SQLA_07:
try:
self.table.drop()
except sa_exceptions.DatabaseError:
raise exceptions.DatabaseNotControlledError(str(self.table))
else:
try:
self.table.drop()
except (sa_exceptions.SQLError):
raise exceptions.DatabaseNotControlledError(str(self.table))
def changeset(self, version=None):
"""API to Changeset creation.
Uses self.version for start version and engine.name
to get database name.
"""
database = self.engine.name
start_ver = self.version
changeset = self.repository.changeset(database, start_ver, version)
return changeset
def runchange(self, ver, change, step):
startver = ver
endver = ver + step
# Current database version must be correct! Don't run if corrupt!
if self.version != startver:
raise exceptions.InvalidVersionError("%s is not %s" % \
(self.version, startver))
# Run the change
change.run(self.engine, step)
# Update/refresh database version
self.update_repository_table(startver, endver)
self.load()
def update_repository_table(self, startver, endver):
"""Update version_table with new information"""
update = self.table.update(and_(self.table.c.version == int(startver),
self.table.c.repository_id == str(self.repository.id)))
self.engine.execute(update, version=int(endver))
def upgrade(self, version=None):
"""
Upgrade (or downgrade) to a specified version, or latest version.
"""
changeset = self.changeset(version)
for ver, change in changeset:
self.runchange(ver, change, changeset.step)
def update_db_from_model(self, model):
"""
Modify the database to match the structure of the current Python model.
"""
model = load_model(model)
diff = schemadiff.getDiffOfModelAgainstDatabase(
model, self.engine, excludeTables=[self.repository.version_table]
)
genmodel.ModelGenerator(diff,self.engine).runB2A()
self.update_repository_table(self.version, int(self.repository.latest))
self.load()
@classmethod
def create(cls, engine, repository, version=None):
"""
Declare a database to be under a repository's version control.
:raises: :exc:`DatabaseAlreadyControlledError`
:returns: :class:`ControlledSchema`
"""
# Confirm that the version # is valid: positive, integer,
# exists in repos
if isinstance(repository, basestring):
repository = Repository(repository)
version = cls._validate_version(repository, version)
table = cls._create_table_version(engine, repository, version)
# TODO: history table
# Load repository information and return
return cls(engine, repository)
@classmethod
def _validate_version(cls, repository, version):
"""
Ensures this is a valid version number for this repository.
:raises: :exc:`InvalidVersionError` if invalid
:return: valid version number
"""
if version is None:
version = 0
try:
version = VerNum(version) # raises valueerror
if version < 0 or version > repository.latest:
raise ValueError()
except ValueError:
raise exceptions.InvalidVersionError(version)
return version
@classmethod
def _create_table_version(cls, engine, repository, version):
"""
Creates the versioning table in a database.
:raises: :exc:`DatabaseAlreadyControlledError`
"""
# Create tables
tname = repository.version_table
meta = MetaData(engine)
table = Table(
tname, meta,
Column('repository_id', String(250), primary_key=True),
Column('repository_path', Text),
Column('version', Integer), )
# there can be multiple repositories/schemas in the same db
if not table.exists():
table.create()
# test for existing repository_id
s = table.select(table.c.repository_id == bindparam("repository_id"))
result = engine.execute(s, repository_id=repository.id)
if result.fetchone():
raise exceptions.DatabaseAlreadyControlledError
# Insert data
engine.execute(table.insert().values(
repository_id=repository.id,
repository_path=repository.path,
version=int(version)))
return table
@classmethod
def compare_model_to_db(cls, engine, model, repository):
"""
Compare the current model against the current database.
"""
if isinstance(repository, basestring):
repository = Repository(repository)
model = load_model(model)
diff = schemadiff.getDiffOfModelAgainstDatabase(
model, engine, excludeTables=[repository.version_table])
return diff
@classmethod
def create_model(cls, engine, repository, declarative=False):
"""
Dump the current database as a Python model.
"""
if isinstance(repository, basestring):
repository = Repository(repository)
diff = schemadiff.getDiffOfModelAgainstDatabase(
MetaData(), engine, excludeTables=[repository.version_table]
)
return genmodel.ModelGenerator(diff, engine, declarative).genBDefinition()

View File

@@ -1,292 +0,0 @@
"""
Schema differencing support.
"""
import logging
import sqlalchemy
from sqlalchemy.types import Float
log = logging.getLogger(__name__)
def getDiffOfModelAgainstDatabase(metadata, engine, excludeTables=None):
"""
Return differences of model against database.
:return: object which will evaluate to :keyword:`True` if there \
are differences else :keyword:`False`.
"""
db_metadata = sqlalchemy.MetaData(engine, reflect=True)
# sqlite will include a dynamically generated 'sqlite_sequence' table if
# there are autoincrement sequences in the database; this should not be
# compared.
if engine.dialect.name == 'sqlite':
if 'sqlite_sequence' in db_metadata.tables:
db_metadata.remove(db_metadata.tables['sqlite_sequence'])
return SchemaDiff(metadata, db_metadata,
labelA='model',
labelB='database',
excludeTables=excludeTables)
def getDiffOfModelAgainstModel(metadataA, metadataB, excludeTables=None):
"""
Return differences of model against another model.
:return: object which will evaluate to :keyword:`True` if there \
are differences else :keyword:`False`.
"""
return SchemaDiff(metadataA, metadataB, excludeTables)
class ColDiff(object):
"""
Container for differences in one :class:`~sqlalchemy.schema.Column`
between two :class:`~sqlalchemy.schema.Table` instances, ``A``
and ``B``.
.. attribute:: col_A
The :class:`~sqlalchemy.schema.Column` object for A.
.. attribute:: col_B
The :class:`~sqlalchemy.schema.Column` object for B.
.. attribute:: type_A
The most generic type of the :class:`~sqlalchemy.schema.Column`
object in A.
.. attribute:: type_B
The most generic type of the :class:`~sqlalchemy.schema.Column`
object in A.
"""
diff = False
def __init__(self,col_A,col_B):
self.col_A = col_A
self.col_B = col_B
self.type_A = col_A.type
self.type_B = col_B.type
self.affinity_A = self.type_A._type_affinity
self.affinity_B = self.type_B._type_affinity
if self.affinity_A is not self.affinity_B:
self.diff = True
return
if isinstance(self.type_A,Float) or isinstance(self.type_B,Float):
if not (isinstance(self.type_A,Float) and isinstance(self.type_B,Float)):
self.diff=True
return
for attr in ('precision','scale','length'):
A = getattr(self.type_A,attr,None)
B = getattr(self.type_B,attr,None)
if not (A is None or B is None) and A!=B:
self.diff=True
return
def __nonzero__(self):
return self.diff
class TableDiff(object):
"""
Container for differences in one :class:`~sqlalchemy.schema.Table`
between two :class:`~sqlalchemy.schema.MetaData` instances, ``A``
and ``B``.
.. attribute:: columns_missing_from_A
A sequence of column names that were found in B but weren't in
A.
.. attribute:: columns_missing_from_B
A sequence of column names that were found in A but weren't in
B.
.. attribute:: columns_different
A dictionary containing information about columns that were
found to be different.
It maps column names to a :class:`ColDiff` objects describing the
differences found.
"""
__slots__ = (
'columns_missing_from_A',
'columns_missing_from_B',
'columns_different',
)
def __nonzero__(self):
return bool(
self.columns_missing_from_A or
self.columns_missing_from_B or
self.columns_different
)
class SchemaDiff(object):
"""
Compute the difference between two :class:`~sqlalchemy.schema.MetaData`
objects.
The string representation of a :class:`SchemaDiff` will summarise
the changes found between the two
:class:`~sqlalchemy.schema.MetaData` objects.
The length of a :class:`SchemaDiff` will give the number of
changes found, enabling it to be used much like a boolean in
expressions.
:param metadataA:
First :class:`~sqlalchemy.schema.MetaData` to compare.
:param metadataB:
Second :class:`~sqlalchemy.schema.MetaData` to compare.
:param labelA:
The label to use in messages about the first
:class:`~sqlalchemy.schema.MetaData`.
:param labelB:
The label to use in messages about the second
:class:`~sqlalchemy.schema.MetaData`.
:param excludeTables:
A sequence of table names to exclude.
.. attribute:: tables_missing_from_A
A sequence of table names that were found in B but weren't in
A.
.. attribute:: tables_missing_from_B
A sequence of table names that were found in A but weren't in
B.
.. attribute:: tables_different
A dictionary containing information about tables that were found
to be different.
It maps table names to a :class:`TableDiff` objects describing the
differences found.
"""
def __init__(self,
metadataA, metadataB,
labelA='metadataA',
labelB='metadataB',
excludeTables=None):
self.metadataA, self.metadataB = metadataA, metadataB
self.labelA, self.labelB = labelA, labelB
self.label_width = max(len(labelA),len(labelB))
excludeTables = set(excludeTables or [])
A_table_names = set(metadataA.tables.keys())
B_table_names = set(metadataB.tables.keys())
self.tables_missing_from_A = sorted(
B_table_names - A_table_names - excludeTables
)
self.tables_missing_from_B = sorted(
A_table_names - B_table_names - excludeTables
)
self.tables_different = {}
for table_name in A_table_names.intersection(B_table_names):
td = TableDiff()
A_table = metadataA.tables[table_name]
B_table = metadataB.tables[table_name]
A_column_names = set(A_table.columns.keys())
B_column_names = set(B_table.columns.keys())
td.columns_missing_from_A = sorted(
B_column_names - A_column_names
)
td.columns_missing_from_B = sorted(
A_column_names - B_column_names
)
td.columns_different = {}
for col_name in A_column_names.intersection(B_column_names):
cd = ColDiff(
A_table.columns.get(col_name),
B_table.columns.get(col_name)
)
if cd:
td.columns_different[col_name]=cd
# XXX - index and constraint differences should
# be checked for here
if td:
self.tables_different[table_name]=td
def __str__(self):
''' Summarize differences. '''
out = []
column_template =' %%%is: %%r' % self.label_width
for names,label in (
(self.tables_missing_from_A,self.labelA),
(self.tables_missing_from_B,self.labelB),
):
if names:
out.append(
' tables missing from %s: %s' % (
label,', '.join(sorted(names))
)
)
for name,td in sorted(self.tables_different.items()):
out.append(
' table with differences: %s' % name
)
for names,label in (
(td.columns_missing_from_A,self.labelA),
(td.columns_missing_from_B,self.labelB),
):
if names:
out.append(
' %s missing these columns: %s' % (
label,', '.join(sorted(names))
)
)
for name,cd in td.columns_different.items():
out.append(' column with differences: %s' % name)
out.append(column_template % (self.labelA,cd.col_A))
out.append(column_template % (self.labelB,cd.col_B))
if out:
out.insert(0, 'Schema diffs:')
return '\n'.join(out)
else:
return 'No schema diffs'
def __len__(self):
"""
Used in bool evaluation, return of 0 means no diffs.
"""
return (
len(self.tables_missing_from_A) +
len(self.tables_missing_from_B) +
len(self.tables_different)
)

View File

@@ -1,6 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from migrate.versioning.script.base import BaseScript
from migrate.versioning.script.py import PythonScript
from migrate.versioning.script.sql import SqlScript

View File

@@ -1,57 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
from migrate import exceptions
from migrate.versioning.config import operations
from migrate.versioning import pathed
log = logging.getLogger(__name__)
class BaseScript(pathed.Pathed):
"""Base class for other types of scripts.
All scripts have the following properties:
source (script.source())
The source code of the script
version (script.version())
The version number of the script
operations (script.operations())
The operations defined by the script: upgrade(), downgrade() or both.
Returns a tuple of operations.
Can also check for an operation with ex. script.operation(Script.ops.up)
""" # TODO: sphinxfy this and implement it correctly
def __init__(self, path):
log.debug('Loading script %s...' % path)
self.verify(path)
super(BaseScript, self).__init__(path)
log.debug('Script %s loaded successfully' % path)
@classmethod
def verify(cls, path):
"""Ensure this is a valid script
This version simply ensures the script file's existence
:raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>`
"""
try:
cls.require_found(path)
except:
raise exceptions.InvalidScriptError(path)
def source(self):
""":returns: source code of the script.
:rtype: string
"""
fd = open(self.path)
ret = fd.read()
fd.close()
return ret
def run(self, engine):
"""Core of each BaseScript subclass.
This method executes the script.
"""
raise NotImplementedError()

View File

@@ -1,160 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import shutil
import warnings
import logging
import inspect
from StringIO import StringIO
import migrate
from migrate.versioning import genmodel, schemadiff
from migrate.versioning.config import operations
from migrate.versioning.template import Template
from migrate.versioning.script import base
from migrate.versioning.util import import_path, load_model, with_engine
from migrate.exceptions import MigrateDeprecationWarning, InvalidScriptError, ScriptError
log = logging.getLogger(__name__)
__all__ = ['PythonScript']
class PythonScript(base.BaseScript):
"""Base for Python scripts"""
@classmethod
def create(cls, path, **opts):
"""Create an empty migration script at specified path
:returns: :class:`PythonScript instance <migrate.versioning.script.py.PythonScript>`"""
cls.require_notfound(path)
src = Template(opts.pop('templates_path', None)).get_script(theme=opts.pop('templates_theme', None))
shutil.copy(src, path)
return cls(path)
@classmethod
def make_update_script_for_model(cls, engine, oldmodel,
model, repository, **opts):
"""Create a migration script based on difference between two SA models.
:param repository: path to migrate repository
:param oldmodel: dotted.module.name:SAClass or SAClass object
:param model: dotted.module.name:SAClass or SAClass object
:param engine: SQLAlchemy engine
:type repository: string or :class:`Repository instance <migrate.versioning.repository.Repository>`
:type oldmodel: string or Class
:type model: string or Class
:type engine: Engine instance
:returns: Upgrade / Downgrade script
:rtype: string
"""
if isinstance(repository, basestring):
# oh dear, an import cycle!
from migrate.versioning.repository import Repository
repository = Repository(repository)
oldmodel = load_model(oldmodel)
model = load_model(model)
# Compute differences.
diff = schemadiff.getDiffOfModelAgainstModel(
model,
oldmodel,
excludeTables=[repository.version_table])
# TODO: diff can be False (there is no difference?)
decls, upgradeCommands, downgradeCommands = \
genmodel.ModelGenerator(diff,engine).genB2AMigration()
# Store differences into file.
src = Template(opts.pop('templates_path', None)).get_script(opts.pop('templates_theme', None))
f = open(src)
contents = f.read()
f.close()
# generate source
search = 'def upgrade(migrate_engine):'
contents = contents.replace(search, '\n\n'.join((decls, search)), 1)
if upgradeCommands:
contents = contents.replace(' pass', upgradeCommands, 1)
if downgradeCommands:
contents = contents.replace(' pass', downgradeCommands, 1)
return contents
@classmethod
def verify_module(cls, path):
"""Ensure path is a valid script
:param path: Script location
:type path: string
:raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>`
:returns: Python module
"""
# Try to import and get the upgrade() func
module = import_path(path)
try:
assert callable(module.upgrade)
except Exception, e:
raise InvalidScriptError(path + ': %s' % str(e))
return module
def preview_sql(self, url, step, **args):
"""Mocks SQLAlchemy Engine to store all executed calls in a string
and runs :meth:`PythonScript.run <migrate.versioning.script.py.PythonScript.run>`
:returns: SQL file
"""
buf = StringIO()
args['engine_arg_strategy'] = 'mock'
args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p)
@with_engine
def go(url, step, **kw):
engine = kw.pop('engine')
self.run(engine, step)
return buf.getvalue()
return go(url, step, **args)
def run(self, engine, step):
"""Core method of Script file.
Exectues :func:`update` or :func:`downgrade` functions
:param engine: SQLAlchemy Engine
:param step: Operation to run
:type engine: string
:type step: int
"""
if step > 0:
op = 'upgrade'
elif step < 0:
op = 'downgrade'
else:
raise ScriptError("%d is not a valid step" % step)
funcname = base.operations[op]
script_func = self._func(funcname)
# check for old way of using engine
if not inspect.getargspec(script_func)[0]:
raise TypeError("upgrade/downgrade functions must accept engine"
" parameter (since version 0.5.4)")
script_func(engine)
@property
def module(self):
"""Calls :meth:`migrate.versioning.script.py.verify_module`
and returns it.
"""
if not hasattr(self, '_module'):
self._module = self.verify_module(self.path)
return self._module
def _func(self, funcname):
if not hasattr(self.module, funcname):
msg = "Function '%s' is not defined in this script"
raise ScriptError(msg % funcname)
return getattr(self.module, funcname)

View File

@@ -1,49 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import shutil
from migrate.versioning.script import base
from migrate.versioning.template import Template
log = logging.getLogger(__name__)
class SqlScript(base.BaseScript):
"""A file containing plain SQL statements."""
@classmethod
def create(cls, path, **opts):
"""Create an empty migration script at specified path
:returns: :class:`SqlScript instance <migrate.versioning.script.sql.SqlScript>`"""
cls.require_notfound(path)
src = Template(opts.pop('templates_path', None)).get_sql_script(theme=opts.pop('templates_theme', None))
shutil.copy(src, path)
return cls(path)
# TODO: why is step parameter even here?
def run(self, engine, step=None, executemany=True):
"""Runs SQL script through raw dbapi execute call"""
text = self.source()
# Don't rely on SA's autocommit here
# (SA uses .startswith to check if a commit is needed. What if script
# starts with a comment?)
conn = engine.connect()
try:
trans = conn.begin()
try:
# HACK: SQLite doesn't allow multiple statements through
# its execute() method, but it provides executescript() instead
dbapi = conn.engine.raw_connection()
if executemany and getattr(dbapi, 'executescript', None):
dbapi.executescript(text)
else:
conn.execute(text)
trans.commit()
except:
trans.rollback()
raise
finally:
conn.close()

View File

@@ -1,214 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""The migrate command-line tool."""
import sys
import inspect
import logging
from optparse import OptionParser, BadOptionError
from migrate import exceptions
from migrate.versioning import api
from migrate.versioning.config import *
from migrate.versioning.util import asbool
alias = dict(
s=api.script,
vc=api.version_control,
dbv=api.db_version,
v=api.version,
)
def alias_setup():
global alias
for key, val in alias.iteritems():
setattr(api, key, val)
alias_setup()
class PassiveOptionParser(OptionParser):
def _process_args(self, largs, rargs, values):
"""little hack to support all --some_option=value parameters"""
while rargs:
arg = rargs[0]
if arg == "--":
del rargs[0]
return
elif arg[0:2] == "--":
# if parser does not know about the option
# pass it along (make it anonymous)
try:
opt = arg.split('=', 1)[0]
self._match_long_opt(opt)
except BadOptionError:
largs.append(arg)
del rargs[0]
else:
self._process_long_opt(rargs, values)
elif arg[:1] == "-" and len(arg) > 1:
self._process_short_opts(rargs, values)
elif self.allow_interspersed_args:
largs.append(arg)
del rargs[0]
def main(argv=None, **kwargs):
"""Shell interface to :mod:`migrate.versioning.api`.
kwargs are default options that can be overriden with passing
--some_option as command line option
:param disable_logging: Let migrate configure logging
:type disable_logging: bool
"""
if argv is not None:
argv = argv
else:
argv = list(sys.argv[1:])
commands = list(api.__all__)
commands.sort()
usage = """%%prog COMMAND ...
Available commands:
%s
Enter "%%prog help COMMAND" for information on a particular command.
""" % '\n\t'.join(["%s - %s" % (command.ljust(28), api.command_desc.get(command)) for command in commands])
parser = PassiveOptionParser(usage=usage)
parser.add_option("-d", "--debug",
action="store_true",
dest="debug",
default=False,
help="Shortcut to turn on DEBUG mode for logging")
parser.add_option("-q", "--disable_logging",
action="store_true",
dest="disable_logging",
default=False,
help="Use this option to disable logging configuration")
help_commands = ['help', '-h', '--help']
HELP = False
try:
command = argv.pop(0)
if command in help_commands:
HELP = True
command = argv.pop(0)
except IndexError:
parser.print_help()
return
command_func = getattr(api, command, None)
if command_func is None or command.startswith('_'):
parser.error("Invalid command %s" % command)
parser.set_usage(inspect.getdoc(command_func))
f_args, f_varargs, f_kwargs, f_defaults = inspect.getargspec(command_func)
for arg in f_args:
parser.add_option(
"--%s" % arg,
dest=arg,
action='store',
type="string")
# display help of the current command
if HELP:
parser.print_help()
return
options, args = parser.parse_args(argv)
# override kwargs with anonymous parameters
override_kwargs = dict()
for arg in list(args):
if arg.startswith('--'):
args.remove(arg)
if '=' in arg:
opt, value = arg[2:].split('=', 1)
else:
opt = arg[2:]
value = True
override_kwargs[opt] = value
# override kwargs with options if user is overwriting
for key, value in options.__dict__.iteritems():
if value is not None:
override_kwargs[key] = value
# arguments that function accepts without passed kwargs
f_required = list(f_args)
candidates = dict(kwargs)
candidates.update(override_kwargs)
for key, value in candidates.iteritems():
if key in f_args:
f_required.remove(key)
# map function arguments to parsed arguments
for arg in args:
try:
kw = f_required.pop(0)
except IndexError:
parser.error("Too many arguments for command %s: %s" % (command,
arg))
kwargs[kw] = arg
# apply overrides
kwargs.update(override_kwargs)
# configure options
for key, value in options.__dict__.iteritems():
kwargs.setdefault(key, value)
# configure logging
if not asbool(kwargs.pop('disable_logging', False)):
# filter to log =< INFO into stdout and rest to stderr
class SingleLevelFilter(logging.Filter):
def __init__(self, min=None, max=None):
self.min = min or 0
self.max = max or 100
def filter(self, record):
return self.min <= record.levelno <= self.max
logger = logging.getLogger()
h1 = logging.StreamHandler(sys.stdout)
f1 = SingleLevelFilter(max=logging.INFO)
h1.addFilter(f1)
h2 = logging.StreamHandler(sys.stderr)
f2 = SingleLevelFilter(min=logging.WARN)
h2.addFilter(f2)
logger.addHandler(h1)
logger.addHandler(h2)
if options.debug:
logger.setLevel(logging.DEBUG)
else:
logger.setLevel(logging.INFO)
log = logging.getLogger(__name__)
# check if all args are given
try:
num_defaults = len(f_defaults)
except TypeError:
num_defaults = 0
f_args_default = f_args[len(f_args) - num_defaults:]
required = list(set(f_required) - set(f_args_default))
if required:
parser.error("Not enough arguments for command %s: %s not specified" \
% (command, ', '.join(required)))
# handle command
try:
ret = command_func(**kwargs)
if ret is not None:
log.info(ret)
except (exceptions.UsageError, exceptions.KnownError), e:
parser.error(e.args[0])
if __name__ == "__main__":
main()

View File

@@ -1,93 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import shutil
import sys
from pkg_resources import resource_filename
from migrate.versioning.config import *
from migrate.versioning import pathed
class Collection(pathed.Pathed):
"""A collection of templates of a specific type"""
_mask = None
def get_path(self, file):
return os.path.join(self.path, str(file))
class RepositoryCollection(Collection):
_mask = '%s'
class ScriptCollection(Collection):
_mask = '%s.py_tmpl'
class ManageCollection(Collection):
_mask = '%s.py_tmpl'
class SQLScriptCollection(Collection):
_mask = '%s.py_tmpl'
class Template(pathed.Pathed):
"""Finds the paths/packages of various Migrate templates.
:param path: Templates are loaded from migrate package
if `path` is not provided.
"""
pkg = 'migrate.versioning.templates'
def __new__(cls, path=None):
if path is None:
path = cls._find_path(cls.pkg)
return super(Template, cls).__new__(cls, path)
def __init__(self, path=None):
if path is None:
path = Template._find_path(self.pkg)
super(Template, self).__init__(path)
self.repository = RepositoryCollection(os.path.join(path, 'repository'))
self.script = ScriptCollection(os.path.join(path, 'script'))
self.manage = ManageCollection(os.path.join(path, 'manage'))
self.sql_script = SQLScriptCollection(os.path.join(path, 'sql_script'))
@classmethod
def _find_path(cls, pkg):
"""Returns absolute path to dotted python package."""
tmp_pkg = pkg.rsplit('.', 1)
if len(tmp_pkg) != 1:
return resource_filename(tmp_pkg[0], tmp_pkg[1])
else:
return resource_filename(tmp_pkg[0], '')
def _get_item(self, collection, theme=None):
"""Locates and returns collection.
:param collection: name of collection to locate
:param type_: type of subfolder in collection (defaults to "_default")
:returns: (package, source)
:rtype: str, str
"""
item = getattr(self, collection)
theme_mask = getattr(item, '_mask')
theme = theme_mask % (theme or 'default')
return item.get_path(theme)
def get_repository(self, *a, **kw):
"""Calls self._get_item('repository', *a, **kw)"""
return self._get_item('repository', *a, **kw)
def get_script(self, *a, **kw):
"""Calls self._get_item('script', *a, **kw)"""
return self._get_item('script', *a, **kw)
def get_sql_script(self, *a, **kw):
"""Calls self._get_item('sql_script', *a, **kw)"""
return self._get_item('sql_script', *a, **kw)
def get_manage(self, *a, **kw):
"""Calls self._get_item('manage', *a, **kw)"""
return self._get_item('manage', *a, **kw)

View File

@@ -1,12 +0,0 @@
#!/usr/bin/env python
from migrate.versioning.shell import main
{{py:
_vars = locals().copy()
del _vars['__template_name__']
_vars.pop('repository_name', None)
defaults = ", ".join(["%s='%s'" % var for var in _vars.iteritems()])
}}
if __name__ == '__main__':
main({{ defaults }})

View File

@@ -1,30 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
import sys
from sqlalchemy import engine_from_config
from paste.deploy.loadwsgi import ConfigLoader
from migrate.versioning.shell import main
from {{ locals().pop('repository_name') }}.model import migrations
if '-c' in sys.argv:
pos = sys.argv.index('-c')
conf_path = sys.argv[pos + 1]
del sys.argv[pos:pos + 2]
else:
conf_path = 'development.ini'
{{py:
_vars = locals().copy()
del _vars['__template_name__']
defaults = ", ".join(["%s='%s'" % var for var in _vars.iteritems()])
}}
conf_dict = ConfigLoader(conf_path).parser._sections['app:main']
# migrate supports passing url as an existing Engine instance (since 0.6.0)
# usage: migrate -c path/to/config.ini COMMANDS
if __name__ == '__main__':
main(url=engine_from_config(conf_dict), repository=migrations.__path__[0],{{ defaults }})

View File

@@ -1,4 +0,0 @@
This is a database migration repository.
More information at
http://code.google.com/p/sqlalchemy-migrate/

View File

@@ -1,25 +0,0 @@
[db_settings]
# Used to identify which repository this database is versioned under.
# You can use the name of your project.
repository_id={{ locals().pop('repository_id') }}
# The name of the database table used to track the schema version.
# This name shouldn't already be used by your project.
# If this is changed once a database is under version control, you'll need to
# change the table name in each database too.
version_table={{ locals().pop('version_table') }}
# When committing a change script, Migrate will attempt to generate the
# sql for all supported databases; normally, if one of them fails - probably
# because you don't have that database installed - it is ignored and the
# commit continues, perhaps ending successfully.
# Databases in this list MUST compile successfully during a commit, or the
# entire commit will fail. List the databases your application will actually
# be using to ensure your updates to that database work properly.
# This must be a list; example: ['postgres','sqlite']
required_dbs={{ locals().pop('required_dbs') }}
# When creating new change scripts, Migrate will stamp the new script with
# a version number. By default this is latest_version + 1. You can set this
# to 'true' to tell Migrate to use the UTC timestamp instead.
use_timestamp_numbering={{ locals().pop('use_timestamp_numbering') }}

View File

@@ -1,4 +0,0 @@
This is a database migration repository.
More information at
http://code.google.com/p/sqlalchemy-migrate/

View File

@@ -1,25 +0,0 @@
[db_settings]
# Used to identify which repository this database is versioned under.
# You can use the name of your project.
repository_id={{ locals().pop('repository_id') }}
# The name of the database table used to track the schema version.
# This name shouldn't already be used by your project.
# If this is changed once a database is under version control, you'll need to
# change the table name in each database too.
version_table={{ locals().pop('version_table') }}
# When committing a change script, Migrate will attempt to generate the
# sql for all supported databases; normally, if one of them fails - probably
# because you don't have that database installed - it is ignored and the
# commit continues, perhaps ending successfully.
# Databases in this list MUST compile successfully during a commit, or the
# entire commit will fail. List the databases your application will actually
# be using to ensure your updates to that database work properly.
# This must be a list; example: ['postgres','sqlite']
required_dbs={{ locals().pop('required_dbs') }}
# When creating new change scripts, Migrate will stamp the new script with
# a version number. By default this is latest_version + 1. You can set this
# to 'true' to tell Migrate to use the UTC timestamp instead.
use_timestamp_numbering={{ locals().pop('use_timestamp_numbering') }}

View File

@@ -1,13 +0,0 @@
from sqlalchemy import *
from migrate import *
def upgrade(migrate_engine):
# Upgrade operations go here. Don't create your own engine; bind
# migrate_engine to your metadata
pass
def downgrade(migrate_engine):
# Operations to reverse the above upgrade go here.
pass

View File

@@ -1,13 +0,0 @@
from sqlalchemy import *
from migrate import *
def upgrade(migrate_engine):
# Upgrade operations go here. Don't create your own engine; bind
# migrate_engine to your metadata
pass
def downgrade(migrate_engine):
# Operations to reverse the above upgrade go here.
pass

View File

@@ -1,179 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
""".. currentmodule:: migrate.versioning.util"""
import warnings
import logging
from decorator import decorator
from pkg_resources import EntryPoint
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.pool import StaticPool
from migrate import exceptions
from migrate.versioning.util.keyedinstance import KeyedInstance
from migrate.versioning.util.importpath import import_path
log = logging.getLogger(__name__)
def load_model(dotted_name):
"""Import module and use module-level variable".
:param dotted_name: path to model in form of string: ``some.python.module:Class``
.. versionchanged:: 0.5.4
"""
if isinstance(dotted_name, basestring):
if ':' not in dotted_name:
# backwards compatibility
warnings.warn('model should be in form of module.model:User '
'and not module.model.User', exceptions.MigrateDeprecationWarning)
dotted_name = ':'.join(dotted_name.rsplit('.', 1))
return EntryPoint.parse('x=%s' % dotted_name).load(False)
else:
# Assume it's already loaded.
return dotted_name
def asbool(obj):
"""Do everything to use object as bool"""
if isinstance(obj, basestring):
obj = obj.strip().lower()
if obj in ['true', 'yes', 'on', 'y', 't', '1']:
return True
elif obj in ['false', 'no', 'off', 'n', 'f', '0']:
return False
else:
raise ValueError("String is not true/false: %r" % obj)
if obj in (True, False):
return bool(obj)
else:
raise ValueError("String is not true/false: %r" % obj)
def guess_obj_type(obj):
"""Do everything to guess object type from string
Tries to convert to `int`, `bool` and finally returns if not succeded.
.. versionadded: 0.5.4
"""
result = None
try:
result = int(obj)
except:
pass
if result is None:
try:
result = asbool(obj)
except:
pass
if result is not None:
return result
else:
return obj
@decorator
def catch_known_errors(f, *a, **kw):
"""Decorator that catches known api errors
.. versionadded: 0.5.4
"""
try:
return f(*a, **kw)
except exceptions.PathFoundError, e:
raise exceptions.KnownError("The path %s already exists" % e.args[0])
def construct_engine(engine, **opts):
""".. versionadded:: 0.5.4
Constructs and returns SQLAlchemy engine.
Currently, there are 2 ways to pass create_engine options to :mod:`migrate.versioning.api` functions:
:param engine: connection string or a existing engine
:param engine_dict: python dictionary of options to pass to `create_engine`
:param engine_arg_*: keyword parameters to pass to `create_engine` (evaluated with :func:`migrate.versioning.util.guess_obj_type`)
:type engine_dict: dict
:type engine: string or Engine instance
:type engine_arg_*: string
:returns: SQLAlchemy Engine
.. note::
keyword parameters override ``engine_dict`` values.
"""
if isinstance(engine, Engine):
return engine
elif not isinstance(engine, basestring):
raise ValueError("you need to pass either an existing engine or a database uri")
# get options for create_engine
if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict):
kwargs = opts['engine_dict']
else:
kwargs = dict()
# DEPRECATED: handle echo the old way
echo = asbool(opts.get('echo', False))
if echo:
warnings.warn('echo=True parameter is deprecated, pass '
'engine_arg_echo=True or engine_dict={"echo": True}',
exceptions.MigrateDeprecationWarning)
kwargs['echo'] = echo
# parse keyword arguments
for key, value in opts.iteritems():
if key.startswith('engine_arg_'):
kwargs[key[11:]] = guess_obj_type(value)
log.debug('Constructing engine')
# TODO: return create_engine(engine, poolclass=StaticPool, **kwargs)
# seems like 0.5.x branch does not work with engine.dispose and staticpool
return create_engine(engine, **kwargs)
@decorator
def with_engine(f, *a, **kw):
"""Decorator for :mod:`migrate.versioning.api` functions
to safely close resources after function usage.
Passes engine parameters to :func:`construct_engine` and
resulting parameter is available as kw['engine'].
Engine is disposed after wrapped function is executed.
.. versionadded: 0.6.0
"""
url = a[0]
engine = construct_engine(url, **kw)
try:
kw['engine'] = engine
return f(*a, **kw)
finally:
if isinstance(engine, Engine):
log.debug('Disposing SQLAlchemy engine %s', engine)
engine.dispose()
class Memoize:
"""Memoize(fn) - an instance which acts like fn but memoizes its arguments
Will only work on functions with non-mutable arguments
ActiveState Code 52201
"""
def __init__(self, fn):
self.fn = fn
self.memo = {}
def __call__(self, *args):
if not self.memo.has_key(args):
self.memo[args] = self.fn(*args)
return self.memo[args]

View File

@@ -1,16 +0,0 @@
import os
import sys
def import_path(fullpath):
""" Import a file with full path specification. Allows one to
import from anywhere, something __import__ does not do.
"""
# http://zephyrfalcon.org/weblog/arch_d7_2002_08_31.html
path, filename = os.path.split(fullpath)
filename, ext = os.path.splitext(filename)
sys.path.append(path)
module = __import__(filename)
reload(module) # Might be out of date during tests
del sys.path[-1]
return module

View File

@@ -1,36 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
class KeyedInstance(object):
"""A class whose instances have a unique identifier of some sort
No two instances with the same unique ID should exist - if we try to create
a second instance, the first should be returned.
"""
_instances = dict()
def __new__(cls, *p, **k):
instances = cls._instances
clskey = str(cls)
if clskey not in instances:
instances[clskey] = dict()
instances = instances[clskey]
key = cls._key(*p, **k)
if key not in instances:
instances[key] = super(KeyedInstance, cls).__new__(cls)
return instances[key]
@classmethod
def _key(cls, *p, **k):
"""Given a unique identifier, return a dictionary key
This should be overridden by child classes, to specify which parameters
should determine an object's uniqueness
"""
raise NotImplementedError()
@classmethod
def clear(cls):
# Allow cls.clear() as well as uniqueInstance.clear(cls)
if str(cls) in cls._instances:
del cls._instances[str(cls)]

View File

@@ -1,238 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import re
import shutil
import logging
from migrate import exceptions
from migrate.versioning import pathed, script
from datetime import datetime
log = logging.getLogger(__name__)
class VerNum(object):
"""A version number that behaves like a string and int at the same time"""
_instances = dict()
def __new__(cls, value):
val = str(value)
if val not in cls._instances:
cls._instances[val] = super(VerNum, cls).__new__(cls)
ret = cls._instances[val]
return ret
def __init__(self,value):
self.value = str(int(value))
if self < 0:
raise ValueError("Version number cannot be negative")
def __add__(self, value):
ret = int(self) + int(value)
return VerNum(ret)
def __sub__(self, value):
return self + (int(value) * -1)
def __cmp__(self, value):
return int(self) - int(value)
def __repr__(self):
return "<VerNum(%s)>" % self.value
def __str__(self):
return str(self.value)
def __int__(self):
return int(self.value)
class Collection(pathed.Pathed):
"""A collection of versioning scripts in a repository"""
FILENAME_WITH_VERSION = re.compile(r'^(\d{3,}).*')
def __init__(self, path):
"""Collect current version scripts in repository
and store them in self.versions
"""
super(Collection, self).__init__(path)
# Create temporary list of files, allowing skipped version numbers.
files = os.listdir(path)
if '1' in files:
# deprecation
raise Exception('It looks like you have a repository in the old '
'format (with directories for each version). '
'Please convert repository before proceeding.')
tempVersions = dict()
for filename in files:
match = self.FILENAME_WITH_VERSION.match(filename)
if match:
num = int(match.group(1))
tempVersions.setdefault(num, []).append(filename)
else:
pass # Must be a helper file or something, let's ignore it.
# Create the versions member where the keys
# are VerNum's and the values are Version's.
self.versions = dict()
for num, files in tempVersions.items():
self.versions[VerNum(num)] = Version(num, path, files)
@property
def latest(self):
""":returns: Latest version in Collection"""
return max([VerNum(0)] + self.versions.keys())
def _next_ver_num(self, use_timestamp_numbering):
if use_timestamp_numbering == True:
return VerNum(int(datetime.utcnow().strftime('%Y%m%d%H%M%S')))
else:
return self.latest + 1
def create_new_python_version(self, description, **k):
"""Create Python files for new version"""
ver = self._next_ver_num(k.pop('use_timestamp_numbering', False))
extra = str_to_filename(description)
if extra:
if extra == '_':
extra = ''
elif not extra.startswith('_'):
extra = '_%s' % extra
filename = '%03d%s.py' % (ver, extra)
filepath = self._version_path(filename)
script.PythonScript.create(filepath, **k)
self.versions[ver] = Version(ver, self.path, [filename])
def create_new_sql_version(self, database, description, **k):
"""Create SQL files for new version"""
ver = self._next_ver_num(k.pop('use_timestamp_numbering', False))
self.versions[ver] = Version(ver, self.path, [])
extra = str_to_filename(description)
if extra:
if extra == '_':
extra = ''
elif not extra.startswith('_'):
extra = '_%s' % extra
# Create new files.
for op in ('upgrade', 'downgrade'):
filename = '%03d%s_%s_%s.sql' % (ver, extra, database, op)
filepath = self._version_path(filename)
script.SqlScript.create(filepath, **k)
self.versions[ver].add_script(filepath)
def version(self, vernum=None):
"""Returns latest Version if vernum is not given.
Otherwise, returns wanted version"""
if vernum is None:
vernum = self.latest
return self.versions[VerNum(vernum)]
@classmethod
def clear(cls):
super(Collection, cls).clear()
def _version_path(self, ver):
"""Returns path of file in versions repository"""
return os.path.join(self.path, str(ver))
class Version(object):
"""A single version in a collection
:param vernum: Version Number
:param path: Path to script files
:param filelist: List of scripts
:type vernum: int, VerNum
:type path: string
:type filelist: list
"""
def __init__(self, vernum, path, filelist):
self.version = VerNum(vernum)
# Collect scripts in this folder
self.sql = dict()
self.python = None
for script in filelist:
self.add_script(os.path.join(path, script))
def script(self, database=None, operation=None):
"""Returns SQL or Python Script"""
for db in (database, 'default'):
# Try to return a .sql script first
try:
return self.sql[db][operation]
except KeyError:
continue # No .sql script exists
# TODO: maybe add force Python parameter?
ret = self.python
assert ret is not None, \
"There is no script for %d version" % self.version
return ret
def add_script(self, path):
"""Add script to Collection/Version"""
if path.endswith(Extensions.py):
self._add_script_py(path)
elif path.endswith(Extensions.sql):
self._add_script_sql(path)
SQL_FILENAME = re.compile(r'^.*\.sql')
def _add_script_sql(self, path):
basename = os.path.basename(path)
match = self.SQL_FILENAME.match(basename)
if match:
basename = basename.replace('.sql', '')
parts = basename.split('_')
if len(parts) < 3:
raise exceptions.ScriptError(
"Invalid SQL script name %s " % basename + \
"(needs to be ###_description_database_operation.sql)")
version = parts[0]
op = parts[-1]
dbms = parts[-2]
else:
raise exceptions.ScriptError(
"Invalid SQL script name %s " % basename + \
"(needs to be ###_description_database_operation.sql)")
# File the script into a dictionary
self.sql.setdefault(dbms, {})[op] = script.SqlScript(path)
def _add_script_py(self, path):
if self.python is not None:
raise exceptions.ScriptError('You can only have one Python script '
'per version, but you have: %s and %s' % (self.python, path))
self.python = script.PythonScript(path)
class Extensions:
"""A namespace for file extensions"""
py = 'py'
sql = 'sql'
def str_to_filename(s):
"""Replaces spaces, (double and single) quotes
and double underscores to underscores
"""
s = s.replace(' ', '_').replace('"', '_').replace("'", '_').replace(".", "_")
while '__' in s:
s = s.replace('__', '_')
return s

View File

@@ -1,133 +0,0 @@
# sqlalchemy/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .sql import (
alias,
and_,
asc,
between,
bindparam,
case,
cast,
collate,
delete,
desc,
distinct,
except_,
except_all,
exists,
extract,
false,
func,
insert,
intersect,
intersect_all,
join,
literal,
literal_column,
modifier,
not_,
null,
or_,
outerjoin,
outparam,
over,
select,
subquery,
text,
true,
tuple_,
type_coerce,
union,
union_all,
update,
)
from .types import (
BIGINT,
BINARY,
BLOB,
BOOLEAN,
BigInteger,
Binary,
Boolean,
CHAR,
CLOB,
DATE,
DATETIME,
DECIMAL,
Date,
DateTime,
Enum,
FLOAT,
Float,
INT,
INTEGER,
Integer,
Interval,
LargeBinary,
NCHAR,
NVARCHAR,
NUMERIC,
Numeric,
PickleType,
REAL,
SMALLINT,
SmallInteger,
String,
TEXT,
TIME,
TIMESTAMP,
Text,
Time,
TypeDecorator,
Unicode,
UnicodeText,
VARBINARY,
VARCHAR,
)
from .schema import (
CheckConstraint,
Column,
ColumnDefault,
Constraint,
DefaultClause,
FetchedValue,
ForeignKey,
ForeignKeyConstraint,
Index,
MetaData,
PassiveDefault,
PrimaryKeyConstraint,
Sequence,
Table,
ThreadLocalMetaData,
UniqueConstraint,
DDL,
)
from .inspection import inspect
from .engine import create_engine, engine_from_config
__version__ = '0.9.1'
def __go(lcls):
global __all__
from . import events
from . import util as _sa_util
import inspect as _inspect
__all__ = sorted(name for name, obj in lcls.items()
if not (name.startswith('_') or _inspect.ismodule(obj)))
_sa_util.dependencies.resolve_all("sqlalchemy")
__go(locals())

View File

@@ -1,665 +0,0 @@
/*
processors.c
Copyright (C) 2010-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com
This module is part of SQLAlchemy and is released under
the MIT License: http://www.opensource.org/licenses/mit-license.php
*/
#include <Python.h>
#include <datetime.h>
#define MODULE_NAME "cprocessors"
#define MODULE_DOC "Module containing C versions of data processing functions."
#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
typedef int Py_ssize_t;
#define PY_SSIZE_T_MAX INT_MAX
#define PY_SSIZE_T_MIN INT_MIN
#endif
static PyObject *
int_to_boolean(PyObject *self, PyObject *arg)
{
long l = 0;
PyObject *res;
if (arg == Py_None)
Py_RETURN_NONE;
#if PY_MAJOR_VERSION >= 3
l = PyLong_AsLong(arg);
#else
l = PyInt_AsLong(arg);
#endif
if (l == 0) {
res = Py_False;
} else if (l == 1) {
res = Py_True;
} else if ((l == -1) && PyErr_Occurred()) {
/* -1 can be either the actual value, or an error flag. */
return NULL;
} else {
PyErr_SetString(PyExc_ValueError,
"int_to_boolean only accepts None, 0 or 1");
return NULL;
}
Py_INCREF(res);
return res;
}
static PyObject *
to_str(PyObject *self, PyObject *arg)
{
if (arg == Py_None)
Py_RETURN_NONE;
return PyObject_Str(arg);
}
static PyObject *
to_float(PyObject *self, PyObject *arg)
{
if (arg == Py_None)
Py_RETURN_NONE;
return PyNumber_Float(arg);
}
static PyObject *
str_to_datetime(PyObject *self, PyObject *arg)
{
#if PY_MAJOR_VERSION >= 3
PyObject *bytes;
PyObject *err_bytes;
#endif
const char *str;
int numparsed;
unsigned int year, month, day, hour, minute, second, microsecond = 0;
PyObject *err_repr;
if (arg == Py_None)
Py_RETURN_NONE;
#if PY_MAJOR_VERSION >= 3
bytes = PyUnicode_AsASCIIString(arg);
if (bytes == NULL)
str = NULL;
else
str = PyBytes_AS_STRING(bytes);
#else
str = PyString_AsString(arg);
#endif
if (str == NULL) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse datetime string '%.200s' "
"- value is not a string.",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse datetime string '%.200s' "
"- value is not a string.",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
/* microseconds are optional */
/*
TODO: this is slightly less picky than the Python version which would
not accept "2000-01-01 00:00:00.". I don't know which is better, but they
should be coherent.
*/
numparsed = sscanf(str, "%4u-%2u-%2u %2u:%2u:%2u.%6u", &year, &month, &day,
&hour, &minute, &second, &microsecond);
#if PY_MAJOR_VERSION >= 3
Py_DECREF(bytes);
#endif
if (numparsed < 6) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse datetime string: %.200s",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse datetime string: %.200s",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
return PyDateTime_FromDateAndTime(year, month, day,
hour, minute, second, microsecond);
}
static PyObject *
str_to_time(PyObject *self, PyObject *arg)
{
#if PY_MAJOR_VERSION >= 3
PyObject *bytes;
PyObject *err_bytes;
#endif
const char *str;
int numparsed;
unsigned int hour, minute, second, microsecond = 0;
PyObject *err_repr;
if (arg == Py_None)
Py_RETURN_NONE;
#if PY_MAJOR_VERSION >= 3
bytes = PyUnicode_AsASCIIString(arg);
if (bytes == NULL)
str = NULL;
else
str = PyBytes_AS_STRING(bytes);
#else
str = PyString_AsString(arg);
#endif
if (str == NULL) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse time string '%.200s' - value is not a string.",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse time string '%.200s' - value is not a string.",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
/* microseconds are optional */
/*
TODO: this is slightly less picky than the Python version which would
not accept "00:00:00.". I don't know which is better, but they should be
coherent.
*/
numparsed = sscanf(str, "%2u:%2u:%2u.%6u", &hour, &minute, &second,
&microsecond);
#if PY_MAJOR_VERSION >= 3
Py_DECREF(bytes);
#endif
if (numparsed < 3) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse time string: %.200s",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse time string: %.200s",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
return PyTime_FromTime(hour, minute, second, microsecond);
}
static PyObject *
str_to_date(PyObject *self, PyObject *arg)
{
#if PY_MAJOR_VERSION >= 3
PyObject *bytes;
PyObject *err_bytes;
#endif
const char *str;
int numparsed;
unsigned int year, month, day;
PyObject *err_repr;
if (arg == Py_None)
Py_RETURN_NONE;
#if PY_MAJOR_VERSION >= 3
bytes = PyUnicode_AsASCIIString(arg);
if (bytes == NULL)
str = NULL;
else
str = PyBytes_AS_STRING(bytes);
#else
str = PyString_AsString(arg);
#endif
if (str == NULL) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse date string '%.200s' - value is not a string.",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse date string '%.200s' - value is not a string.",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
numparsed = sscanf(str, "%4u-%2u-%2u", &year, &month, &day);
#if PY_MAJOR_VERSION >= 3
Py_DECREF(bytes);
#endif
if (numparsed != 3) {
err_repr = PyObject_Repr(arg);
if (err_repr == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(err_repr);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_ValueError,
"Couldn't parse date string: %.200s",
PyBytes_AS_STRING(err_bytes));
Py_DECREF(err_bytes);
#else
PyErr_Format(
PyExc_ValueError,
"Couldn't parse date string: %.200s",
PyString_AsString(err_repr));
#endif
Py_DECREF(err_repr);
return NULL;
}
return PyDate_FromDate(year, month, day);
}
/***********
* Structs *
***********/
typedef struct {
PyObject_HEAD
PyObject *encoding;
PyObject *errors;
} UnicodeResultProcessor;
typedef struct {
PyObject_HEAD
PyObject *type;
PyObject *format;
} DecimalResultProcessor;
/**************************
* UnicodeResultProcessor *
**************************/
static int
UnicodeResultProcessor_init(UnicodeResultProcessor *self, PyObject *args,
PyObject *kwds)
{
PyObject *encoding, *errors = NULL;
static char *kwlist[] = {"encoding", "errors", NULL};
#if PY_MAJOR_VERSION >= 3
if (!PyArg_ParseTupleAndKeywords(args, kwds, "U|U:__init__", kwlist,
&encoding, &errors))
return -1;
#else
if (!PyArg_ParseTupleAndKeywords(args, kwds, "S|S:__init__", kwlist,
&encoding, &errors))
return -1;
#endif
#if PY_MAJOR_VERSION >= 3
encoding = PyUnicode_AsASCIIString(encoding);
#else
Py_INCREF(encoding);
#endif
self->encoding = encoding;
if (errors) {
#if PY_MAJOR_VERSION >= 3
errors = PyUnicode_AsASCIIString(errors);
#else
Py_INCREF(errors);
#endif
} else {
#if PY_MAJOR_VERSION >= 3
errors = PyBytes_FromString("strict");
#else
errors = PyString_FromString("strict");
#endif
if (errors == NULL)
return -1;
}
self->errors = errors;
return 0;
}
static PyObject *
UnicodeResultProcessor_process(UnicodeResultProcessor *self, PyObject *value)
{
const char *encoding, *errors;
char *str;
Py_ssize_t len;
if (value == Py_None)
Py_RETURN_NONE;
#if PY_MAJOR_VERSION >= 3
if (PyBytes_AsStringAndSize(value, &str, &len))
return NULL;
encoding = PyBytes_AS_STRING(self->encoding);
errors = PyBytes_AS_STRING(self->errors);
#else
if (PyString_AsStringAndSize(value, &str, &len))
return NULL;
encoding = PyString_AS_STRING(self->encoding);
errors = PyString_AS_STRING(self->errors);
#endif
return PyUnicode_Decode(str, len, encoding, errors);
}
static void
UnicodeResultProcessor_dealloc(UnicodeResultProcessor *self)
{
Py_XDECREF(self->encoding);
Py_XDECREF(self->errors);
#if PY_MAJOR_VERSION >= 3
Py_TYPE(self)->tp_free((PyObject*)self);
#else
self->ob_type->tp_free((PyObject*)self);
#endif
}
static PyMethodDef UnicodeResultProcessor_methods[] = {
{"process", (PyCFunction)UnicodeResultProcessor_process, METH_O,
"The value processor itself."},
{NULL} /* Sentinel */
};
static PyTypeObject UnicodeResultProcessorType = {
PyVarObject_HEAD_INIT(NULL, 0)
"sqlalchemy.cprocessors.UnicodeResultProcessor", /* tp_name */
sizeof(UnicodeResultProcessor), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)UnicodeResultProcessor_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
"UnicodeResultProcessor objects", /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
UnicodeResultProcessor_methods, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)UnicodeResultProcessor_init, /* tp_init */
0, /* tp_alloc */
0, /* tp_new */
};
/**************************
* DecimalResultProcessor *
**************************/
static int
DecimalResultProcessor_init(DecimalResultProcessor *self, PyObject *args,
PyObject *kwds)
{
PyObject *type, *format;
#if PY_MAJOR_VERSION >= 3
if (!PyArg_ParseTuple(args, "OU", &type, &format))
#else
if (!PyArg_ParseTuple(args, "OS", &type, &format))
#endif
return -1;
Py_INCREF(type);
self->type = type;
Py_INCREF(format);
self->format = format;
return 0;
}
static PyObject *
DecimalResultProcessor_process(DecimalResultProcessor *self, PyObject *value)
{
PyObject *str, *result, *args;
if (value == Py_None)
Py_RETURN_NONE;
/* Decimal does not accept float values directly */
/* SQLite can also give us an integer here (see [ticket:2432]) */
/* XXX: starting with Python 3.1, we could use Decimal.from_float(f),
but the result wouldn't be the same */
args = PyTuple_Pack(1, value);
if (args == NULL)
return NULL;
#if PY_MAJOR_VERSION >= 3
str = PyUnicode_Format(self->format, args);
#else
str = PyString_Format(self->format, args);
#endif
Py_DECREF(args);
if (str == NULL)
return NULL;
result = PyObject_CallFunctionObjArgs(self->type, str, NULL);
Py_DECREF(str);
return result;
}
static void
DecimalResultProcessor_dealloc(DecimalResultProcessor *self)
{
Py_XDECREF(self->type);
Py_XDECREF(self->format);
#if PY_MAJOR_VERSION >= 3
Py_TYPE(self)->tp_free((PyObject*)self);
#else
self->ob_type->tp_free((PyObject*)self);
#endif
}
static PyMethodDef DecimalResultProcessor_methods[] = {
{"process", (PyCFunction)DecimalResultProcessor_process, METH_O,
"The value processor itself."},
{NULL} /* Sentinel */
};
static PyTypeObject DecimalResultProcessorType = {
PyVarObject_HEAD_INIT(NULL, 0)
"sqlalchemy.DecimalResultProcessor", /* tp_name */
sizeof(DecimalResultProcessor), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)DecimalResultProcessor_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
"DecimalResultProcessor objects", /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
DecimalResultProcessor_methods, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)DecimalResultProcessor_init, /* tp_init */
0, /* tp_alloc */
0, /* tp_new */
};
static PyMethodDef module_methods[] = {
{"int_to_boolean", int_to_boolean, METH_O,
"Convert an integer to a boolean."},
{"to_str", to_str, METH_O,
"Convert any value to its string representation."},
{"to_float", to_float, METH_O,
"Convert any value to its floating point representation."},
{"str_to_datetime", str_to_datetime, METH_O,
"Convert an ISO string to a datetime.datetime object."},
{"str_to_time", str_to_time, METH_O,
"Convert an ISO string to a datetime.time object."},
{"str_to_date", str_to_date, METH_O,
"Convert an ISO string to a datetime.date object."},
{NULL, NULL, 0, NULL} /* Sentinel */
};
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
MODULE_NAME,
MODULE_DOC,
-1,
module_methods
};
#define INITERROR return NULL
PyObject *
PyInit_cprocessors(void)
#else
#define INITERROR return
PyMODINIT_FUNC
initcprocessors(void)
#endif
{
PyObject *m;
UnicodeResultProcessorType.tp_new = PyType_GenericNew;
if (PyType_Ready(&UnicodeResultProcessorType) < 0)
INITERROR;
DecimalResultProcessorType.tp_new = PyType_GenericNew;
if (PyType_Ready(&DecimalResultProcessorType) < 0)
INITERROR;
#if PY_MAJOR_VERSION >= 3
m = PyModule_Create(&module_def);
#else
m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
#endif
if (m == NULL)
INITERROR;
PyDateTime_IMPORT;
Py_INCREF(&UnicodeResultProcessorType);
PyModule_AddObject(m, "UnicodeResultProcessor",
(PyObject *)&UnicodeResultProcessorType);
Py_INCREF(&DecimalResultProcessorType);
PyModule_AddObject(m, "DecimalResultProcessor",
(PyObject *)&DecimalResultProcessorType);
#if PY_MAJOR_VERSION >= 3
return m;
#endif
}

View File

@@ -1,718 +0,0 @@
/*
resultproxy.c
Copyright (C) 2010-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com
This module is part of SQLAlchemy and is released under
the MIT License: http://www.opensource.org/licenses/mit-license.php
*/
#include <Python.h>
#define MODULE_NAME "cresultproxy"
#define MODULE_DOC "Module containing C versions of core ResultProxy classes."
#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
typedef int Py_ssize_t;
#define PY_SSIZE_T_MAX INT_MAX
#define PY_SSIZE_T_MIN INT_MIN
typedef Py_ssize_t (*lenfunc)(PyObject *);
#define PyInt_FromSsize_t(x) PyInt_FromLong(x)
typedef intargfunc ssizeargfunc;
#endif
/***********
* Structs *
***********/
typedef struct {
PyObject_HEAD
PyObject *parent;
PyObject *row;
PyObject *processors;
PyObject *keymap;
} BaseRowProxy;
/****************
* BaseRowProxy *
****************/
static PyObject *
safe_rowproxy_reconstructor(PyObject *self, PyObject *args)
{
PyObject *cls, *state, *tmp;
BaseRowProxy *obj;
if (!PyArg_ParseTuple(args, "OO", &cls, &state))
return NULL;
obj = (BaseRowProxy *)PyObject_CallMethod(cls, "__new__", "O", cls);
if (obj == NULL)
return NULL;
tmp = PyObject_CallMethod((PyObject *)obj, "__setstate__", "O", state);
if (tmp == NULL) {
Py_DECREF(obj);
return NULL;
}
Py_DECREF(tmp);
if (obj->parent == NULL || obj->row == NULL ||
obj->processors == NULL || obj->keymap == NULL) {
PyErr_SetString(PyExc_RuntimeError,
"__setstate__ for BaseRowProxy subclasses must set values "
"for parent, row, processors and keymap");
Py_DECREF(obj);
return NULL;
}
return (PyObject *)obj;
}
static int
BaseRowProxy_init(BaseRowProxy *self, PyObject *args, PyObject *kwds)
{
PyObject *parent, *row, *processors, *keymap;
if (!PyArg_UnpackTuple(args, "BaseRowProxy", 4, 4,
&parent, &row, &processors, &keymap))
return -1;
Py_INCREF(parent);
self->parent = parent;
if (!PySequence_Check(row)) {
PyErr_SetString(PyExc_TypeError, "row must be a sequence");
return -1;
}
Py_INCREF(row);
self->row = row;
if (!PyList_CheckExact(processors)) {
PyErr_SetString(PyExc_TypeError, "processors must be a list");
return -1;
}
Py_INCREF(processors);
self->processors = processors;
if (!PyDict_CheckExact(keymap)) {
PyErr_SetString(PyExc_TypeError, "keymap must be a dict");
return -1;
}
Py_INCREF(keymap);
self->keymap = keymap;
return 0;
}
/* We need the reduce method because otherwise the default implementation
* does very weird stuff for pickle protocol 0 and 1. It calls
* BaseRowProxy.__new__(RowProxy_instance) upon *pickling*.
*/
static PyObject *
BaseRowProxy_reduce(PyObject *self)
{
PyObject *method, *state;
PyObject *module, *reconstructor, *cls;
method = PyObject_GetAttrString(self, "__getstate__");
if (method == NULL)
return NULL;
state = PyObject_CallObject(method, NULL);
Py_DECREF(method);
if (state == NULL)
return NULL;
module = PyImport_ImportModule("sqlalchemy.engine.result");
if (module == NULL)
return NULL;
reconstructor = PyObject_GetAttrString(module, "rowproxy_reconstructor");
Py_DECREF(module);
if (reconstructor == NULL) {
Py_DECREF(state);
return NULL;
}
cls = PyObject_GetAttrString(self, "__class__");
if (cls == NULL) {
Py_DECREF(reconstructor);
Py_DECREF(state);
return NULL;
}
return Py_BuildValue("(N(NN))", reconstructor, cls, state);
}
static void
BaseRowProxy_dealloc(BaseRowProxy *self)
{
Py_XDECREF(self->parent);
Py_XDECREF(self->row);
Py_XDECREF(self->processors);
Py_XDECREF(self->keymap);
#if PY_MAJOR_VERSION >= 3
Py_TYPE(self)->tp_free((PyObject *)self);
#else
self->ob_type->tp_free((PyObject *)self);
#endif
}
static PyObject *
BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple)
{
Py_ssize_t num_values, num_processors;
PyObject **valueptr, **funcptr, **resultptr;
PyObject *func, *result, *processed_value, *values_fastseq;
num_values = PySequence_Length(values);
num_processors = PyList_Size(processors);
if (num_values != num_processors) {
PyErr_Format(PyExc_RuntimeError,
"number of values in row (%d) differ from number of column "
"processors (%d)",
(int)num_values, (int)num_processors);
return NULL;
}
if (astuple) {
result = PyTuple_New(num_values);
} else {
result = PyList_New(num_values);
}
if (result == NULL)
return NULL;
values_fastseq = PySequence_Fast(values, "row must be a sequence");
if (values_fastseq == NULL)
return NULL;
valueptr = PySequence_Fast_ITEMS(values_fastseq);
funcptr = PySequence_Fast_ITEMS(processors);
resultptr = PySequence_Fast_ITEMS(result);
while (--num_values >= 0) {
func = *funcptr;
if (func != Py_None) {
processed_value = PyObject_CallFunctionObjArgs(func, *valueptr,
NULL);
if (processed_value == NULL) {
Py_DECREF(values_fastseq);
Py_DECREF(result);
return NULL;
}
*resultptr = processed_value;
} else {
Py_INCREF(*valueptr);
*resultptr = *valueptr;
}
valueptr++;
funcptr++;
resultptr++;
}
Py_DECREF(values_fastseq);
return result;
}
static PyListObject *
BaseRowProxy_values(BaseRowProxy *self)
{
return (PyListObject *)BaseRowProxy_processvalues(self->row,
self->processors, 0);
}
static PyObject *
BaseRowProxy_iter(BaseRowProxy *self)
{
PyObject *values, *result;
values = BaseRowProxy_processvalues(self->row, self->processors, 1);
if (values == NULL)
return NULL;
result = PyObject_GetIter(values);
Py_DECREF(values);
if (result == NULL)
return NULL;
return result;
}
static Py_ssize_t
BaseRowProxy_length(BaseRowProxy *self)
{
return PySequence_Length(self->row);
}
static PyObject *
BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
{
PyObject *processors, *values;
PyObject *processor, *value, *processed_value;
PyObject *row, *record, *result, *indexobject;
PyObject *exc_module, *exception, *cstr_obj;
#if PY_MAJOR_VERSION >= 3
PyObject *bytes;
#endif
char *cstr_key;
long index;
int key_fallback = 0;
int tuple_check = 0;
#if PY_MAJOR_VERSION < 3
if (PyInt_CheckExact(key)) {
index = PyInt_AS_LONG(key);
}
#endif
if (PyLong_CheckExact(key)) {
index = PyLong_AsLong(key);
if ((index == -1) && PyErr_Occurred())
/* -1 can be either the actual value, or an error flag. */
return NULL;
} else if (PySlice_Check(key)) {
values = PyObject_GetItem(self->row, key);
if (values == NULL)
return NULL;
processors = PyObject_GetItem(self->processors, key);
if (processors == NULL) {
Py_DECREF(values);
return NULL;
}
result = BaseRowProxy_processvalues(values, processors, 1);
Py_DECREF(values);
Py_DECREF(processors);
return result;
} else {
record = PyDict_GetItem((PyObject *)self->keymap, key);
if (record == NULL) {
record = PyObject_CallMethod(self->parent, "_key_fallback",
"O", key);
if (record == NULL)
return NULL;
key_fallback = 1;
}
indexobject = PyTuple_GetItem(record, 2);
if (indexobject == NULL)
return NULL;
if (key_fallback) {
Py_DECREF(record);
}
if (indexobject == Py_None) {
exc_module = PyImport_ImportModule("sqlalchemy.exc");
if (exc_module == NULL)
return NULL;
exception = PyObject_GetAttrString(exc_module,
"InvalidRequestError");
Py_DECREF(exc_module);
if (exception == NULL)
return NULL;
// wow. this seems quite excessive.
cstr_obj = PyObject_Str(key);
if (cstr_obj == NULL)
return NULL;
/*
FIXME: raise encoding error exception (in both versions below)
if the key contains non-ascii chars, instead of an
InvalidRequestError without any message like in the
python version.
*/
#if PY_MAJOR_VERSION >= 3
bytes = PyUnicode_AsASCIIString(cstr_obj);
if (bytes == NULL)
return NULL;
cstr_key = PyBytes_AS_STRING(bytes);
#else
cstr_key = PyString_AsString(cstr_obj);
#endif
if (cstr_key == NULL) {
Py_DECREF(cstr_obj);
return NULL;
}
Py_DECREF(cstr_obj);
PyErr_Format(exception,
"Ambiguous column name '%.200s' in result set! "
"try 'use_labels' option on select statement.", cstr_key);
return NULL;
}
#if PY_MAJOR_VERSION >= 3
index = PyLong_AsLong(indexobject);
#else
index = PyInt_AsLong(indexobject);
#endif
if ((index == -1) && PyErr_Occurred())
/* -1 can be either the actual value, or an error flag. */
return NULL;
}
processor = PyList_GetItem(self->processors, index);
if (processor == NULL)
return NULL;
row = self->row;
if (PyTuple_CheckExact(row)) {
value = PyTuple_GetItem(row, index);
tuple_check = 1;
}
else {
value = PySequence_GetItem(row, index);
tuple_check = 0;
}
if (value == NULL)
return NULL;
if (processor != Py_None) {
processed_value = PyObject_CallFunctionObjArgs(processor, value, NULL);
if (!tuple_check) {
Py_DECREF(value);
}
return processed_value;
} else {
if (tuple_check) {
Py_INCREF(value);
}
return value;
}
}
static PyObject *
BaseRowProxy_getitem(PyObject *self, Py_ssize_t i)
{
PyObject *index;
#if PY_MAJOR_VERSION >= 3
index = PyLong_FromSsize_t(i);
#else
index = PyInt_FromSsize_t(i);
#endif
return BaseRowProxy_subscript((BaseRowProxy*)self, index);
}
static PyObject *
BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name)
{
PyObject *tmp;
#if PY_MAJOR_VERSION >= 3
PyObject *err_bytes;
#endif
if (!(tmp = PyObject_GenericGetAttr((PyObject *)self, name))) {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
}
else
return tmp;
tmp = BaseRowProxy_subscript(self, name);
if (tmp == NULL && PyErr_ExceptionMatches(PyExc_KeyError)) {
#if PY_MAJOR_VERSION >= 3
err_bytes = PyUnicode_AsASCIIString(name);
if (err_bytes == NULL)
return NULL;
PyErr_Format(
PyExc_AttributeError,
"Could not locate column in row for column '%.200s'",
PyBytes_AS_STRING(err_bytes)
);
#else
PyErr_Format(
PyExc_AttributeError,
"Could not locate column in row for column '%.200s'",
PyString_AsString(name)
);
#endif
return NULL;
}
return tmp;
}
/***********************
* getters and setters *
***********************/
static PyObject *
BaseRowProxy_getparent(BaseRowProxy *self, void *closure)
{
Py_INCREF(self->parent);
return self->parent;
}
static int
BaseRowProxy_setparent(BaseRowProxy *self, PyObject *value, void *closure)
{
PyObject *module, *cls;
if (value == NULL) {
PyErr_SetString(PyExc_TypeError,
"Cannot delete the 'parent' attribute");
return -1;
}
module = PyImport_ImportModule("sqlalchemy.engine.result");
if (module == NULL)
return -1;
cls = PyObject_GetAttrString(module, "ResultMetaData");
Py_DECREF(module);
if (cls == NULL)
return -1;
if (PyObject_IsInstance(value, cls) != 1) {
PyErr_SetString(PyExc_TypeError,
"The 'parent' attribute value must be an instance of "
"ResultMetaData");
return -1;
}
Py_DECREF(cls);
Py_XDECREF(self->parent);
Py_INCREF(value);
self->parent = value;
return 0;
}
static PyObject *
BaseRowProxy_getrow(BaseRowProxy *self, void *closure)
{
Py_INCREF(self->row);
return self->row;
}
static int
BaseRowProxy_setrow(BaseRowProxy *self, PyObject *value, void *closure)
{
if (value == NULL) {
PyErr_SetString(PyExc_TypeError,
"Cannot delete the 'row' attribute");
return -1;
}
if (!PySequence_Check(value)) {
PyErr_SetString(PyExc_TypeError,
"The 'row' attribute value must be a sequence");
return -1;
}
Py_XDECREF(self->row);
Py_INCREF(value);
self->row = value;
return 0;
}
static PyObject *
BaseRowProxy_getprocessors(BaseRowProxy *self, void *closure)
{
Py_INCREF(self->processors);
return self->processors;
}
static int
BaseRowProxy_setprocessors(BaseRowProxy *self, PyObject *value, void *closure)
{
if (value == NULL) {
PyErr_SetString(PyExc_TypeError,
"Cannot delete the 'processors' attribute");
return -1;
}
if (!PyList_CheckExact(value)) {
PyErr_SetString(PyExc_TypeError,
"The 'processors' attribute value must be a list");
return -1;
}
Py_XDECREF(self->processors);
Py_INCREF(value);
self->processors = value;
return 0;
}
static PyObject *
BaseRowProxy_getkeymap(BaseRowProxy *self, void *closure)
{
Py_INCREF(self->keymap);
return self->keymap;
}
static int
BaseRowProxy_setkeymap(BaseRowProxy *self, PyObject *value, void *closure)
{
if (value == NULL) {
PyErr_SetString(PyExc_TypeError,
"Cannot delete the 'keymap' attribute");
return -1;
}
if (!PyDict_CheckExact(value)) {
PyErr_SetString(PyExc_TypeError,
"The 'keymap' attribute value must be a dict");
return -1;
}
Py_XDECREF(self->keymap);
Py_INCREF(value);
self->keymap = value;
return 0;
}
static PyGetSetDef BaseRowProxy_getseters[] = {
{"_parent",
(getter)BaseRowProxy_getparent, (setter)BaseRowProxy_setparent,
"ResultMetaData",
NULL},
{"_row",
(getter)BaseRowProxy_getrow, (setter)BaseRowProxy_setrow,
"Original row tuple",
NULL},
{"_processors",
(getter)BaseRowProxy_getprocessors, (setter)BaseRowProxy_setprocessors,
"list of type processors",
NULL},
{"_keymap",
(getter)BaseRowProxy_getkeymap, (setter)BaseRowProxy_setkeymap,
"Key to (processor, index) dict",
NULL},
{NULL}
};
static PyMethodDef BaseRowProxy_methods[] = {
{"values", (PyCFunction)BaseRowProxy_values, METH_NOARGS,
"Return the values represented by this BaseRowProxy as a list."},
{"__reduce__", (PyCFunction)BaseRowProxy_reduce, METH_NOARGS,
"Pickle support method."},
{NULL} /* Sentinel */
};
static PySequenceMethods BaseRowProxy_as_sequence = {
(lenfunc)BaseRowProxy_length, /* sq_length */
0, /* sq_concat */
0, /* sq_repeat */
(ssizeargfunc)BaseRowProxy_getitem, /* sq_item */
0, /* sq_slice */
0, /* sq_ass_item */
0, /* sq_ass_slice */
0, /* sq_contains */
0, /* sq_inplace_concat */
0, /* sq_inplace_repeat */
};
static PyMappingMethods BaseRowProxy_as_mapping = {
(lenfunc)BaseRowProxy_length, /* mp_length */
(binaryfunc)BaseRowProxy_subscript, /* mp_subscript */
0 /* mp_ass_subscript */
};
static PyTypeObject BaseRowProxyType = {
PyVarObject_HEAD_INIT(NULL, 0)
"sqlalchemy.cresultproxy.BaseRowProxy", /* tp_name */
sizeof(BaseRowProxy), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)BaseRowProxy_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
&BaseRowProxy_as_sequence, /* tp_as_sequence */
&BaseRowProxy_as_mapping, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
(getattrofunc)BaseRowProxy_getattro,/* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
"BaseRowProxy is a abstract base class for RowProxy", /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
(getiterfunc)BaseRowProxy_iter, /* tp_iter */
0, /* tp_iternext */
BaseRowProxy_methods, /* tp_methods */
0, /* tp_members */
BaseRowProxy_getseters, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)BaseRowProxy_init, /* tp_init */
0, /* tp_alloc */
0 /* tp_new */
};
static PyMethodDef module_methods[] = {
{"safe_rowproxy_reconstructor", safe_rowproxy_reconstructor, METH_VARARGS,
"reconstruct a RowProxy instance from its pickled form."},
{NULL, NULL, 0, NULL} /* Sentinel */
};
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
MODULE_NAME,
MODULE_DOC,
-1,
module_methods
};
#define INITERROR return NULL
PyObject *
PyInit_cresultproxy(void)
#else
#define INITERROR return
PyMODINIT_FUNC
initcresultproxy(void)
#endif
{
PyObject *m;
BaseRowProxyType.tp_new = PyType_GenericNew;
if (PyType_Ready(&BaseRowProxyType) < 0)
INITERROR;
#if PY_MAJOR_VERSION >= 3
m = PyModule_Create(&module_def);
#else
m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
#endif
if (m == NULL)
INITERROR;
Py_INCREF(&BaseRowProxyType);
PyModule_AddObject(m, "BaseRowProxy", (PyObject *)&BaseRowProxyType);
#if PY_MAJOR_VERSION >= 3
return m;
#endif
}

View File

@@ -1,225 +0,0 @@
/*
utils.c
Copyright (C) 2012-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
This module is part of SQLAlchemy and is released under
the MIT License: http://www.opensource.org/licenses/mit-license.php
*/
#include <Python.h>
#define MODULE_NAME "cutils"
#define MODULE_DOC "Module containing C versions of utility functions."
/*
Given arguments from the calling form *multiparams, **params,
return a list of bind parameter structures, usually a list of
dictionaries.
In the case of 'raw' execution which accepts positional parameters,
it may be a list of tuples or lists.
*/
static PyObject *
distill_params(PyObject *self, PyObject *args)
{
PyObject *multiparams, *params;
PyObject *enclosing_list, *double_enclosing_list;
PyObject *zero_element, *zero_element_item;
Py_ssize_t multiparam_size, zero_element_length;
if (!PyArg_UnpackTuple(args, "_distill_params", 2, 2, &multiparams, &params)) {
return NULL;
}
if (multiparams != Py_None) {
multiparam_size = PyTuple_Size(multiparams);
if (multiparam_size < 0) {
return NULL;
}
}
else {
multiparam_size = 0;
}
if (multiparam_size == 0) {
if (params != Py_None && PyDict_Size(params) != 0) {
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
Py_INCREF(params);
if (PyList_SetItem(enclosing_list, 0, params) == -1) {
Py_DECREF(params);
Py_DECREF(enclosing_list);
return NULL;
}
}
else {
enclosing_list = PyList_New(0);
if (enclosing_list == NULL) {
return NULL;
}
}
return enclosing_list;
}
else if (multiparam_size == 1) {
zero_element = PyTuple_GetItem(multiparams, 0);
if (PyTuple_Check(zero_element) || PyList_Check(zero_element)) {
zero_element_length = PySequence_Length(zero_element);
if (zero_element_length != 0) {
zero_element_item = PySequence_GetItem(zero_element, 0);
if (zero_element_item == NULL) {
return NULL;
}
}
else {
zero_element_item = NULL;
}
if (zero_element_length == 0 ||
(
PyObject_HasAttrString(zero_element_item, "__iter__") &&
!PyObject_HasAttrString(zero_element_item, "strip")
)
) {
/*
* execute(stmt, [{}, {}, {}, ...])
* execute(stmt, [(), (), (), ...])
*/
Py_XDECREF(zero_element_item);
Py_INCREF(zero_element);
return zero_element;
}
else {
/*
* execute(stmt, ("value", "value"))
*/
Py_XDECREF(zero_element_item);
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
Py_INCREF(zero_element);
if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
Py_DECREF(zero_element);
Py_DECREF(enclosing_list);
return NULL;
}
return enclosing_list;
}
}
else if (PyObject_HasAttrString(zero_element, "keys")) {
/*
* execute(stmt, {"key":"value"})
*/
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
Py_INCREF(zero_element);
if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
Py_DECREF(zero_element);
Py_DECREF(enclosing_list);
return NULL;
}
return enclosing_list;
} else {
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
double_enclosing_list = PyList_New(1);
if (double_enclosing_list == NULL) {
Py_DECREF(enclosing_list);
return NULL;
}
Py_INCREF(zero_element);
if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
Py_DECREF(zero_element);
Py_DECREF(enclosing_list);
Py_DECREF(double_enclosing_list);
return NULL;
}
if (PyList_SetItem(double_enclosing_list, 0, enclosing_list) == -1) {
Py_DECREF(zero_element);
Py_DECREF(enclosing_list);
Py_DECREF(double_enclosing_list);
return NULL;
}
return double_enclosing_list;
}
}
else {
zero_element = PyTuple_GetItem(multiparams, 0);
if (PyObject_HasAttrString(zero_element, "__iter__") &&
!PyObject_HasAttrString(zero_element, "strip")
) {
Py_INCREF(multiparams);
return multiparams;
}
else {
enclosing_list = PyList_New(1);
if (enclosing_list == NULL) {
return NULL;
}
Py_INCREF(multiparams);
if (PyList_SetItem(enclosing_list, 0, multiparams) == -1) {
Py_DECREF(multiparams);
Py_DECREF(enclosing_list);
return NULL;
}
return enclosing_list;
}
}
}
static PyMethodDef module_methods[] = {
{"_distill_params", distill_params, METH_VARARGS,
"Distill an execute() parameter structure."},
{NULL, NULL, 0, NULL} /* Sentinel */
};
#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
MODULE_NAME,
MODULE_DOC,
-1,
module_methods
};
#endif
#if PY_MAJOR_VERSION >= 3
PyObject *
PyInit_cutils(void)
#else
PyMODINIT_FUNC
initcutils(void)
#endif
{
PyObject *m;
#if PY_MAJOR_VERSION >= 3
m = PyModule_Create(&module_def);
#else
m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
#endif
#if PY_MAJOR_VERSION >= 3
if (m == NULL)
return NULL;
return m;
#else
if (m == NULL)
return;
#endif
}

View File

@@ -1,9 +0,0 @@
# connectors/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
class Connector(object):
pass

View File

@@ -1,149 +0,0 @@
# connectors/mxodbc.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
Provide an SQLALchemy connector for the eGenix mxODBC commercial
Python adapter for ODBC. This is not a free product, but eGenix
provides SQLAlchemy with a license for use in continuous integration
testing.
This has been tested for use with mxODBC 3.1.2 on SQL Server 2005
and 2008, using the SQL Server Native driver. However, it is
possible for this to be used on other database platforms.
For more info on mxODBC, see http://www.egenix.com/
"""
import sys
import re
import warnings
from . import Connector
class MxODBCConnector(Connector):
driver = 'mxodbc'
supports_sane_multi_rowcount = False
supports_unicode_statements = True
supports_unicode_binds = True
supports_native_decimal = True
@classmethod
def dbapi(cls):
# this classmethod will normally be replaced by an instance
# attribute of the same name, so this is normally only called once.
cls._load_mx_exceptions()
platform = sys.platform
if platform == 'win32':
from mx.ODBC import Windows as module
# this can be the string "linux2", and possibly others
elif 'linux' in platform:
from mx.ODBC import unixODBC as module
elif platform == 'darwin':
from mx.ODBC import iODBC as module
else:
raise ImportError("Unrecognized platform for mxODBC import")
return module
@classmethod
def _load_mx_exceptions(cls):
""" Import mxODBC exception classes into the module namespace,
as if they had been imported normally. This is done here
to avoid requiring all SQLAlchemy users to install mxODBC.
"""
global InterfaceError, ProgrammingError
from mx.ODBC import InterfaceError
from mx.ODBC import ProgrammingError
def on_connect(self):
def connect(conn):
conn.stringformat = self.dbapi.MIXED_STRINGFORMAT
conn.datetimeformat = self.dbapi.PYDATETIME_DATETIMEFORMAT
conn.decimalformat = self.dbapi.DECIMAL_DECIMALFORMAT
conn.errorhandler = self._error_handler()
return connect
def _error_handler(self):
""" Return a handler that adjusts mxODBC's raised Warnings to
emit Python standard warnings.
"""
from mx.ODBC.Error import Warning as MxOdbcWarning
def error_handler(connection, cursor, errorclass, errorvalue):
if issubclass(errorclass, MxOdbcWarning):
errorclass.__bases__ = (Warning,)
warnings.warn(message=str(errorvalue),
category=errorclass,
stacklevel=2)
else:
raise errorclass(errorvalue)
return error_handler
def create_connect_args(self, url):
""" Return a tuple of *args,**kwargs for creating a connection.
The mxODBC 3.x connection constructor looks like this:
connect(dsn, user='', password='',
clear_auto_commit=1, errorhandler=None)
This method translates the values in the provided uri
into args and kwargs needed to instantiate an mxODBC Connection.
The arg 'errorhandler' is not used by SQLAlchemy and will
not be populated.
"""
opts = url.translate_connect_args(username='user')
opts.update(url.query)
args = opts.pop('host')
opts.pop('port', None)
opts.pop('database', None)
return (args,), opts
def is_disconnect(self, e, connection, cursor):
# TODO: eGenix recommends checking connection.closed here
# Does that detect dropped connections ?
if isinstance(e, self.dbapi.ProgrammingError):
return "connection already closed" in str(e)
elif isinstance(e, self.dbapi.Error):
return '[08S01]' in str(e)
else:
return False
def _get_server_version_info(self, connection):
# eGenix suggests using conn.dbms_version instead
# of what we're doing here
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
# 18 == pyodbc.SQL_DBMS_VER
for n in r.split(dbapi_con.getinfo(18)[1]):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)
def _get_direct(self, context):
if context:
native_odbc_execute = context.execution_options.\
get('native_odbc_execute', 'auto')
# default to direct=True in all cases, is more generally
# compatible especially with SQL Server
return False if native_odbc_execute is True else True
else:
return True
def do_executemany(self, cursor, statement, parameters, context=None):
cursor.executemany(
statement, parameters, direct=self._get_direct(context))
def do_execute(self, cursor, statement, parameters, context=None):
cursor.execute(statement, parameters, direct=self._get_direct(context))

View File

@@ -1,162 +0,0 @@
# connectors/mysqldb.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Define behaviors common to MySQLdb dialects.
Currently includes MySQL and Drizzle.
"""
from . import Connector
from ..engine import base as engine_base, default
from ..sql import operators as sql_operators
from .. import exc, log, schema, sql, types as sqltypes, util, processors
import re
# the subclassing of Connector by all classes
# here is not strictly necessary
class MySQLDBExecutionContext(Connector):
@property
def rowcount(self):
if hasattr(self, '_rowcount'):
return self._rowcount
else:
return self.cursor.rowcount
class MySQLDBCompiler(Connector):
def visit_mod_binary(self, binary, operator, **kw):
return self.process(binary.left, **kw) + " %% " + \
self.process(binary.right, **kw)
def post_process_text(self, text):
return text.replace('%', '%%')
class MySQLDBIdentifierPreparer(Connector):
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace("%", "%%")
class MySQLDBConnector(Connector):
driver = 'mysqldb'
supports_unicode_statements = False
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
supports_native_decimal = True
default_paramstyle = 'format'
@classmethod
def dbapi(cls):
# is overridden when pymysql is used
return __import__('MySQLdb')
def do_executemany(self, cursor, statement, parameters, context=None):
rowcount = cursor.executemany(statement, parameters)
if context is not None:
context._rowcount = rowcount
def create_connect_args(self, url):
opts = url.translate_connect_args(database='db', username='user',
password='passwd')
opts.update(url.query)
util.coerce_kw_type(opts, 'compress', bool)
util.coerce_kw_type(opts, 'connect_timeout', int)
util.coerce_kw_type(opts, 'read_timeout', int)
util.coerce_kw_type(opts, 'client_flag', int)
util.coerce_kw_type(opts, 'local_infile', int)
# Note: using either of the below will cause all strings to be returned
# as Unicode, both in raw SQL operations and with column types like
# String and MSString.
util.coerce_kw_type(opts, 'use_unicode', bool)
util.coerce_kw_type(opts, 'charset', str)
# Rich values 'cursorclass' and 'conv' are not supported via
# query string.
ssl = {}
keys = ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']
for key in keys:
if key in opts:
ssl[key[4:]] = opts[key]
util.coerce_kw_type(ssl, key[4:], str)
del opts[key]
if ssl:
opts['ssl'] = ssl
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
# supports_sane_rowcount.
client_flag = opts.get('client_flag', 0)
if self.dbapi is not None:
try:
CLIENT_FLAGS = __import__(
self.dbapi.__name__ + '.constants.CLIENT'
).constants.CLIENT
client_flag |= CLIENT_FLAGS.FOUND_ROWS
except (AttributeError, ImportError):
self.supports_sane_rowcount = False
opts['client_flag'] = client_flag
return [[], opts]
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
for n in r.split(dbapi_con.get_server_info()):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)
def _extract_error_code(self, exception):
return exception.args[0]
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
# Note: MySQL-python 1.2.1c7 seems to ignore changes made
# on a connection via set_character_set()
if self.server_version_info < (4, 1, 0):
try:
return connection.connection.character_set_name()
except AttributeError:
# < 1.2.1 final MySQL-python drivers have no charset support.
# a query is needed.
pass
# Prefer 'character_set_results' for the current connection over the
# value in the driver. SET NAMES or individual variable SETs will
# change the charset without updating the driver's view of the world.
#
# If it's decided that issuing that sort of SQL leaves you SOL, then
# this can prefer the driver value.
rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])
if 'character_set_results' in opts:
return opts['character_set_results']
try:
return connection.connection.character_set_name()
except AttributeError:
# Still no charset on < 1.2.1 final...
if 'character_set' in opts:
return opts['character_set']
else:
util.warn(
"Could not detect the connection character set with this "
"combination of MySQL server and MySQL-python. "
"MySQL-python >= 1.2.2 is recommended. Assuming latin1.")
return 'latin1'

View File

@@ -1,170 +0,0 @@
# connectors/pyodbc.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from . import Connector
from .. import util
import sys
import re
class PyODBCConnector(Connector):
driver = 'pyodbc'
supports_sane_multi_rowcount = False
if util.py2k:
# PyODBC unicode is broken on UCS-4 builds
supports_unicode = sys.maxunicode == 65535
supports_unicode_statements = supports_unicode
supports_native_decimal = True
default_paramstyle = 'named'
# for non-DSN connections, this should
# hold the desired driver name
pyodbc_driver_name = None
# will be set to True after initialize()
# if the freetds.so is detected
freetds = False
# will be set to the string version of
# the FreeTDS driver if freetds is detected
freetds_driver_version = None
# will be set to True after initialize()
# if the libessqlsrv.so is detected
easysoft = False
def __init__(self, supports_unicode_binds=None, **kw):
super(PyODBCConnector, self).__init__(**kw)
self._user_supports_unicode_binds = supports_unicode_binds
@classmethod
def dbapi(cls):
return __import__('pyodbc')
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
opts.update(url.query)
keys = opts
query = url.query
connect_args = {}
for param in ('ansi', 'unicode_results', 'autocommit'):
if param in keys:
connect_args[param] = util.asbool(keys.pop(param))
if 'odbc_connect' in keys:
connectors = [util.unquote_plus(keys.pop('odbc_connect'))]
else:
dsn_connection = 'dsn' in keys or \
('host' in keys and 'database' not in keys)
if dsn_connection:
connectors = ['dsn=%s' % (keys.pop('host', '') or \
keys.pop('dsn', ''))]
else:
port = ''
if 'port' in keys and not 'port' in query:
port = ',%d' % int(keys.pop('port'))
connectors = ["DRIVER={%s}" %
keys.pop('driver', self.pyodbc_driver_name),
'Server=%s%s' % (keys.pop('host', ''), port),
'Database=%s' % keys.pop('database', '')]
user = keys.pop("user", None)
if user:
connectors.append("UID=%s" % user)
connectors.append("PWD=%s" % keys.pop('password', ''))
else:
connectors.append("Trusted_Connection=Yes")
# if set to 'Yes', the ODBC layer will try to automagically
# convert textual data from your database encoding to your
# client encoding. This should obviously be set to 'No' if
# you query a cp1253 encoded database from a latin1 client...
if 'odbc_autotranslate' in keys:
connectors.append("AutoTranslate=%s" %
keys.pop("odbc_autotranslate"))
connectors.extend(['%s=%s' % (k, v) for k, v in keys.items()])
return [[";".join(connectors)], connect_args]
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.ProgrammingError):
return "The cursor's connection has been closed." in str(e) or \
'Attempt to use a closed connection.' in str(e)
elif isinstance(e, self.dbapi.Error):
return '[08S01]' in str(e)
else:
return False
def initialize(self, connection):
# determine FreeTDS first. can't issue SQL easily
# without getting unicode_statements/binds set up.
pyodbc = self.dbapi
dbapi_con = connection.connection
_sql_driver_name = dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME)
self.freetds = bool(re.match(r".*libtdsodbc.*\.so", _sql_driver_name
))
self.easysoft = bool(re.match(r".*libessqlsrv.*\.so", _sql_driver_name
))
if self.freetds:
self.freetds_driver_version = dbapi_con.getinfo(
pyodbc.SQL_DRIVER_VER)
self.supports_unicode_statements = (
not util.py2k or
(not self.freetds and not self.easysoft)
)
if self._user_supports_unicode_binds is not None:
self.supports_unicode_binds = self._user_supports_unicode_binds
elif util.py2k:
self.supports_unicode_binds = (
not self.freetds or self.freetds_driver_version >= '0.91'
) and not self.easysoft
else:
self.supports_unicode_binds = True
# run other initialization which asks for user name, etc.
super(PyODBCConnector, self).initialize(connection)
def _dbapi_version(self):
if not self.dbapi:
return ()
return self._parse_dbapi_version(self.dbapi.version)
def _parse_dbapi_version(self, vers):
m = re.match(
r'(?:py.*-)?([\d\.]+)(?:-(\w+))?',
vers
)
if not m:
return ()
vers = tuple([int(x) for x in m.group(1).split(".")])
if m.group(2):
vers += (m.group(2),)
return vers
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)

View File

@@ -1,59 +0,0 @@
# connectors/zxJDBC.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import sys
from . import Connector
class ZxJDBCConnector(Connector):
driver = 'zxjdbc'
supports_sane_rowcount = False
supports_sane_multi_rowcount = False
supports_unicode_binds = True
supports_unicode_statements = sys.version > '2.5.0+'
description_encoding = None
default_paramstyle = 'qmark'
jdbc_db_name = None
jdbc_driver_name = None
@classmethod
def dbapi(cls):
from com.ziclix.python.sql import zxJDBC
return zxJDBC
def _driver_kwargs(self):
"""Return kw arg dict to be sent to connect()."""
return {}
def _create_jdbc_url(self, url):
"""Create a JDBC url from a :class:`~sqlalchemy.engine.url.URL`"""
return 'jdbc:%s://%s%s/%s' % (self.jdbc_db_name, url.host,
url.port is not None
and ':%s' % url.port or '',
url.database)
def create_connect_args(self, url):
opts = self._driver_kwargs()
opts.update(url.query)
return [
[self._create_jdbc_url(url),
url.username, url.password,
self.jdbc_driver_name],
opts]
def is_disconnect(self, e, connection, cursor):
if not isinstance(e, self.dbapi.ProgrammingError):
return False
e = str(e)
return 'connection is closed' in e or 'cursor is closed' in e
def _get_server_version_info(self, connection):
# use connection.connection.dbversion, and parse appropriately
# to get a tuple
raise NotImplementedError()

View File

@@ -1,31 +0,0 @@
# databases/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Include imports from the sqlalchemy.dialects package for backwards
compatibility with pre 0.6 versions.
"""
from ..dialects.sqlite import base as sqlite
from ..dialects.postgresql import base as postgresql
postgres = postgresql
from ..dialects.mysql import base as mysql
from ..dialects.drizzle import base as drizzle
from ..dialects.oracle import base as oracle
from ..dialects.firebird import base as firebird
from ..dialects.mssql import base as mssql
from ..dialects.sybase import base as sybase
__all__ = (
'drizzle',
'firebird',
'mssql',
'mysql',
'postgresql',
'sqlite',
'oracle',
'sybase',
)

View File

@@ -1,44 +0,0 @@
# dialects/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
__all__ = (
'drizzle',
'firebird',
'mssql',
'mysql',
'oracle',
'postgresql',
'sqlite',
'sybase',
)
from .. import util
def _auto_fn(name):
"""default dialect importer.
plugs into the :class:`.PluginLoader`
as a first-hit system.
"""
if "." in name:
dialect, driver = name.split(".")
else:
dialect = name
driver = "base"
try:
module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects
except ImportError:
return None
module = getattr(module, dialect)
if hasattr(module, driver):
module = getattr(module, driver)
return lambda: module.dialect
else:
return None
registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn)

View File

@@ -1,22 +0,0 @@
from sqlalchemy.dialects.drizzle import base, mysqldb
base.dialect = mysqldb.dialect
from sqlalchemy.dialects.drizzle.base import \
BIGINT, BINARY, BLOB, \
BOOLEAN, CHAR, DATE, \
DATETIME, DECIMAL, DOUBLE, \
ENUM, FLOAT, INTEGER, \
NUMERIC, REAL, TEXT, \
TIME, TIMESTAMP, VARBINARY, \
VARCHAR, dialect
__all__ = (
'BIGINT', 'BINARY', 'BLOB',
'BOOLEAN', 'CHAR', 'DATE',
'DATETIME', 'DECIMAL', 'DOUBLE',
'ENUM', 'FLOAT', 'INTEGER',
'NUMERIC', 'REAL', 'TEXT',
'TIME', 'TIMESTAMP', 'VARBINARY',
'VARCHAR', 'dialect'
)

View File

@@ -1,498 +0,0 @@
# drizzle/base.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
# Copyright (C) 2010-2011 Monty Taylor <mordred@inaugust.com>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: drizzle
:name: Drizzle
Drizzle is a variant of MySQL. Unlike MySQL, Drizzle's default storage engine
is InnoDB (transactions, foreign-keys) rather than MyISAM. For more
`Notable Differences <http://docs.drizzle.org/mysql_differences.html>`_, visit
the `Drizzle Documentation <http://docs.drizzle.org/index.html>`_.
The SQLAlchemy Drizzle dialect leans heavily on the MySQL dialect, so much of
the :doc:`SQLAlchemy MySQL <mysql>` documentation is also relevant.
"""
from sqlalchemy import exc
from sqlalchemy import log
from sqlalchemy import types as sqltypes
from sqlalchemy.engine import reflection
from sqlalchemy.dialects.mysql import base as mysql_dialect
from sqlalchemy.types import DATE, DATETIME, BOOLEAN, TIME, \
BLOB, BINARY, VARBINARY
class _NumericType(object):
"""Base for Drizzle numeric types."""
def __init__(self, **kw):
super(_NumericType, self).__init__(**kw)
class _FloatType(_NumericType, sqltypes.Float):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
if isinstance(self, (REAL, DOUBLE)) and \
(
(precision is None and scale is not None) or
(precision is not None and scale is None)
):
raise exc.ArgumentError(
"You must specify both precision and scale or omit "
"both altogether.")
super(_FloatType, self).__init__(precision=precision,
asdecimal=asdecimal, **kw)
self.scale = scale
class _StringType(mysql_dialect._StringType):
"""Base for Drizzle string types."""
def __init__(self, collation=None, binary=False, **kw):
kw['national'] = False
super(_StringType, self).__init__(collation=collation, binary=binary,
**kw)
class NUMERIC(_NumericType, sqltypes.NUMERIC):
"""Drizzle NUMERIC type."""
__visit_name__ = 'NUMERIC'
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a NUMERIC.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
"""
super(NUMERIC, self).__init__(precision=precision, scale=scale,
asdecimal=asdecimal, **kw)
class DECIMAL(_NumericType, sqltypes.DECIMAL):
"""Drizzle DECIMAL type."""
__visit_name__ = 'DECIMAL'
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DECIMAL.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
"""
super(DECIMAL, self).__init__(precision=precision, scale=scale,
asdecimal=asdecimal, **kw)
class DOUBLE(_FloatType):
"""Drizzle DOUBLE type."""
__visit_name__ = 'DOUBLE'
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DOUBLE.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
"""
super(DOUBLE, self).__init__(precision=precision, scale=scale,
asdecimal=asdecimal, **kw)
class REAL(_FloatType, sqltypes.REAL):
"""Drizzle REAL type."""
__visit_name__ = 'REAL'
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a REAL.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
"""
super(REAL, self).__init__(precision=precision, scale=scale,
asdecimal=asdecimal, **kw)
class FLOAT(_FloatType, sqltypes.FLOAT):
"""Drizzle FLOAT type."""
__visit_name__ = 'FLOAT'
def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
"""Construct a FLOAT.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
"""
super(FLOAT, self).__init__(precision=precision, scale=scale,
asdecimal=asdecimal, **kw)
def bind_processor(self, dialect):
return None
class INTEGER(sqltypes.INTEGER):
"""Drizzle INTEGER type."""
__visit_name__ = 'INTEGER'
def __init__(self, **kw):
"""Construct an INTEGER."""
super(INTEGER, self).__init__(**kw)
class BIGINT(sqltypes.BIGINT):
"""Drizzle BIGINTEGER type."""
__visit_name__ = 'BIGINT'
def __init__(self, **kw):
"""Construct a BIGINTEGER."""
super(BIGINT, self).__init__(**kw)
class TIME(mysql_dialect.TIME):
"""Drizzle TIME type."""
class TIMESTAMP(sqltypes.TIMESTAMP):
"""Drizzle TIMESTAMP type."""
__visit_name__ = 'TIMESTAMP'
class TEXT(_StringType, sqltypes.TEXT):
"""Drizzle TEXT type, for text up to 2^16 characters."""
__visit_name__ = 'TEXT'
def __init__(self, length=None, **kw):
"""Construct a TEXT.
:param length: Optional, if provided the server may optimize storage
by substituting the smallest TEXT type sufficient to store
``length`` characters.
:param collation: Optional, a column-level collation for this string
value. Takes precedence to 'binary' short-hand.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super(TEXT, self).__init__(length=length, **kw)
class VARCHAR(_StringType, sqltypes.VARCHAR):
"""Drizzle VARCHAR type, for variable-length character data."""
__visit_name__ = 'VARCHAR'
def __init__(self, length=None, **kwargs):
"""Construct a VARCHAR.
:param collation: Optional, a column-level collation for this string
value. Takes precedence to 'binary' short-hand.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super(VARCHAR, self).__init__(length=length, **kwargs)
class CHAR(_StringType, sqltypes.CHAR):
"""Drizzle CHAR type, for fixed-length character data."""
__visit_name__ = 'CHAR'
def __init__(self, length=None, **kwargs):
"""Construct a CHAR.
:param length: Maximum data length, in characters.
:param binary: Optional, use the default binary collation for the
national character set. This does not affect the type of data
stored, use a BINARY type for binary data.
:param collation: Optional, request a particular collation. Must be
compatible with the national character set.
"""
super(CHAR, self).__init__(length=length, **kwargs)
class ENUM(mysql_dialect.ENUM):
"""Drizzle ENUM type."""
def __init__(self, *enums, **kw):
"""Construct an ENUM.
Example:
Column('myenum', ENUM("foo", "bar", "baz"))
:param enums: The range of valid values for this ENUM. Values will be
quoted when generating the schema according to the quoting flag (see
below).
:param strict: Defaults to False: ensure that a given value is in this
ENUM's range of permissible values when inserting or updating rows.
Note that Drizzle will not raise a fatal error if you attempt to
store an out of range value- an alternate value will be stored
instead.
(See Drizzle ENUM documentation.)
:param collation: Optional, a column-level collation for this string
value. Takes precedence to 'binary' short-hand.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
:param quoting: Defaults to 'auto': automatically determine enum value
quoting. If all enum values are surrounded by the same quoting
character, then use 'quoted' mode. Otherwise, use 'unquoted' mode.
'quoted': values in enums are already quoted, they will be used
directly when generating the schema - this usage is deprecated.
'unquoted': values in enums are not quoted, they will be escaped and
surrounded by single quotes when generating the schema.
Previous versions of this type always required manually quoted
values to be supplied; future versions will always quote the string
literals for you. This is a transitional option.
"""
super(ENUM, self).__init__(*enums, **kw)
class _DrizzleBoolean(sqltypes.Boolean):
def get_dbapi_type(self, dbapi):
return dbapi.NUMERIC
colspecs = {
sqltypes.Numeric: NUMERIC,
sqltypes.Float: FLOAT,
sqltypes.Time: TIME,
sqltypes.Enum: ENUM,
sqltypes.Boolean: _DrizzleBoolean,
}
# All the types we have in Drizzle
ischema_names = {
'BIGINT': BIGINT,
'BINARY': BINARY,
'BLOB': BLOB,
'BOOLEAN': BOOLEAN,
'CHAR': CHAR,
'DATE': DATE,
'DATETIME': DATETIME,
'DECIMAL': DECIMAL,
'DOUBLE': DOUBLE,
'ENUM': ENUM,
'FLOAT': FLOAT,
'INT': INTEGER,
'INTEGER': INTEGER,
'NUMERIC': NUMERIC,
'TEXT': TEXT,
'TIME': TIME,
'TIMESTAMP': TIMESTAMP,
'VARBINARY': VARBINARY,
'VARCHAR': VARCHAR,
}
class DrizzleCompiler(mysql_dialect.MySQLCompiler):
def visit_typeclause(self, typeclause):
type_ = typeclause.type.dialect_impl(self.dialect)
if isinstance(type_, sqltypes.Integer):
return 'INTEGER'
else:
return super(DrizzleCompiler, self).visit_typeclause(typeclause)
def visit_cast(self, cast, **kwargs):
type_ = self.process(cast.typeclause)
if type_ is None:
return self.process(cast.clause)
return 'CAST(%s AS %s)' % (self.process(cast.clause), type_)
class DrizzleDDLCompiler(mysql_dialect.MySQLDDLCompiler):
pass
class DrizzleTypeCompiler(mysql_dialect.MySQLTypeCompiler):
def _extend_numeric(self, type_, spec):
return spec
def _extend_string(self, type_, defaults, spec):
"""Extend a string-type declaration with standard SQL
COLLATE annotations and Drizzle specific extensions.
"""
def attr(name):
return getattr(type_, name, defaults.get(name))
if attr('collation'):
collation = 'COLLATE %s' % type_.collation
elif attr('binary'):
collation = 'BINARY'
else:
collation = None
return ' '.join([c for c in (spec, collation)
if c is not None])
def visit_NCHAR(self, type):
raise NotImplementedError("Drizzle does not support NCHAR")
def visit_NVARCHAR(self, type):
raise NotImplementedError("Drizzle does not support NVARCHAR")
def visit_FLOAT(self, type_):
if type_.scale is not None and type_.precision is not None:
return "FLOAT(%s, %s)" % (type_.precision, type_.scale)
else:
return "FLOAT"
def visit_BOOLEAN(self, type_):
return "BOOLEAN"
def visit_BLOB(self, type_):
return "BLOB"
class DrizzleExecutionContext(mysql_dialect.MySQLExecutionContext):
pass
class DrizzleIdentifierPreparer(mysql_dialect.MySQLIdentifierPreparer):
pass
@log.class_logger
class DrizzleDialect(mysql_dialect.MySQLDialect):
"""Details of the Drizzle dialect.
Not used directly in application code.
"""
name = 'drizzle'
_supports_cast = True
supports_sequences = False
supports_native_boolean = True
supports_views = False
default_paramstyle = 'format'
colspecs = colspecs
statement_compiler = DrizzleCompiler
ddl_compiler = DrizzleDDLCompiler
type_compiler = DrizzleTypeCompiler
ischema_names = ischema_names
preparer = DrizzleIdentifierPreparer
def on_connect(self):
"""Force autocommit - Drizzle Bug#707842 doesn't set this properly"""
def connect(conn):
conn.autocommit(False)
return connect
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
"""Return a Unicode SHOW TABLES from a given schema."""
if schema is not None:
current_schema = schema
else:
current_schema = self.default_schema_name
charset = 'utf8'
rp = connection.execute("SHOW TABLES FROM %s" %
self.identifier_preparer.quote_identifier(current_schema))
return [row[0] for row in self._compat_fetchall(rp, charset=charset)]
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
raise NotImplementedError
def _detect_casing(self, connection):
"""Sniff out identifier case sensitivity.
Cached per-connection. This value can not change without a server
restart.
"""
return 0
def _detect_collations(self, connection):
"""Pull the active COLLATIONS list from the server.
Cached per-connection.
"""
collations = {}
charset = self._connection_charset
rs = connection.execute(
'SELECT CHARACTER_SET_NAME, COLLATION_NAME FROM'
' data_dictionary.COLLATIONS')
for row in self._compat_fetchall(rs, charset):
collations[row[0]] = row[1]
return collations
def _detect_ansiquotes(self, connection):
"""Detect and adjust for the ANSI_QUOTES sql mode."""
self._server_ansiquotes = False
self._backslash_escapes = False

View File

@@ -1,48 +0,0 @@
"""
.. dialect:: drizzle+mysqldb
:name: MySQL-Python
:dbapi: mysqldb
:connectstring: drizzle+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
:url: http://sourceforge.net/projects/mysql-python
"""
from sqlalchemy.dialects.drizzle.base import (
DrizzleDialect,
DrizzleExecutionContext,
DrizzleCompiler,
DrizzleIdentifierPreparer)
from sqlalchemy.connectors.mysqldb import (
MySQLDBExecutionContext,
MySQLDBCompiler,
MySQLDBIdentifierPreparer,
MySQLDBConnector)
class DrizzleExecutionContext_mysqldb(MySQLDBExecutionContext,
DrizzleExecutionContext):
pass
class DrizzleCompiler_mysqldb(MySQLDBCompiler, DrizzleCompiler):
pass
class DrizzleIdentifierPreparer_mysqldb(MySQLDBIdentifierPreparer,
DrizzleIdentifierPreparer):
pass
class DrizzleDialect_mysqldb(MySQLDBConnector, DrizzleDialect):
execution_ctx_cls = DrizzleExecutionContext_mysqldb
statement_compiler = DrizzleCompiler_mysqldb
preparer = DrizzleIdentifierPreparer_mysqldb
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
return 'utf8'
dialect = DrizzleDialect_mysqldb

View File

@@ -1,20 +0,0 @@
# firebird/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from sqlalchemy.dialects.firebird import base, kinterbasdb, fdb
base.dialect = fdb.dialect
from sqlalchemy.dialects.firebird.base import \
SMALLINT, BIGINT, FLOAT, FLOAT, DATE, TIME, \
TEXT, NUMERIC, FLOAT, TIMESTAMP, VARCHAR, CHAR, BLOB,\
dialect
__all__ = (
'SMALLINT', 'BIGINT', 'FLOAT', 'FLOAT', 'DATE', 'TIME',
'TEXT', 'NUMERIC', 'FLOAT', 'TIMESTAMP', 'VARCHAR', 'CHAR', 'BLOB',
'dialect'
)

View File

@@ -1,736 +0,0 @@
# firebird/base.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: firebird
:name: Firebird
Firebird Dialects
-----------------
Firebird offers two distinct dialects_ (not to be confused with a
SQLAlchemy ``Dialect``):
dialect 1
This is the old syntax and behaviour, inherited from Interbase pre-6.0.
dialect 3
This is the newer and supported syntax, introduced in Interbase 6.0.
The SQLAlchemy Firebird dialect detects these versions and
adjusts its representation of SQL accordingly. However,
support for dialect 1 is not well tested and probably has
incompatibilities.
Locking Behavior
----------------
Firebird locks tables aggressively. For this reason, a DROP TABLE may
hang until other transactions are released. SQLAlchemy does its best
to release transactions as quickly as possible. The most common cause
of hanging transactions is a non-fully consumed result set, i.e.::
result = engine.execute("select * from table")
row = result.fetchone()
return
Where above, the ``ResultProxy`` has not been fully consumed. The
connection will be returned to the pool and the transactional state
rolled back once the Python garbage collector reclaims the objects
which hold onto the connection, which often occurs asynchronously.
The above use case can be alleviated by calling ``first()`` on the
``ResultProxy`` which will fetch the first row and immediately close
all remaining cursor/connection resources.
RETURNING support
-----------------
Firebird 2.0 supports returning a result set from inserts, and 2.1
extends that to deletes and updates. This is generically exposed by
the SQLAlchemy ``returning()`` method, such as::
# INSERT..RETURNING
result = table.insert().returning(table.c.col1, table.c.col2).\\
values(name='foo')
print result.fetchall()
# UPDATE..RETURNING
raises = empl.update().returning(empl.c.id, empl.c.salary).\\
where(empl.c.sales>100).\\
values(dict(salary=empl.c.salary * 1.1))
print raises.fetchall()
.. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html
"""
import datetime
from sqlalchemy import schema as sa_schema
from sqlalchemy import exc, types as sqltypes, sql, util
from sqlalchemy.sql import expression
from sqlalchemy.engine import base, default, reflection
from sqlalchemy.sql import compiler
from sqlalchemy.types import (BIGINT, BLOB, DATE, FLOAT, INTEGER, NUMERIC,
SMALLINT, TEXT, TIME, TIMESTAMP, Integer)
RESERVED_WORDS = set([
"active", "add", "admin", "after", "all", "alter", "and", "any", "as",
"asc", "ascending", "at", "auto", "avg", "before", "begin", "between",
"bigint", "bit_length", "blob", "both", "by", "case", "cast", "char",
"character", "character_length", "char_length", "check", "close",
"collate", "column", "commit", "committed", "computed", "conditional",
"connect", "constraint", "containing", "count", "create", "cross",
"cstring", "current", "current_connection", "current_date",
"current_role", "current_time", "current_timestamp",
"current_transaction", "current_user", "cursor", "database", "date",
"day", "dec", "decimal", "declare", "default", "delete", "desc",
"descending", "disconnect", "distinct", "do", "domain", "double",
"drop", "else", "end", "entry_point", "escape", "exception",
"execute", "exists", "exit", "external", "extract", "fetch", "file",
"filter", "float", "for", "foreign", "from", "full", "function",
"gdscode", "generator", "gen_id", "global", "grant", "group",
"having", "hour", "if", "in", "inactive", "index", "inner",
"input_type", "insensitive", "insert", "int", "integer", "into", "is",
"isolation", "join", "key", "leading", "left", "length", "level",
"like", "long", "lower", "manual", "max", "maximum_segment", "merge",
"min", "minute", "module_name", "month", "names", "national",
"natural", "nchar", "no", "not", "null", "numeric", "octet_length",
"of", "on", "only", "open", "option", "or", "order", "outer",
"output_type", "overflow", "page", "pages", "page_size", "parameter",
"password", "plan", "position", "post_event", "precision", "primary",
"privileges", "procedure", "protected", "rdb$db_key", "read", "real",
"record_version", "recreate", "recursive", "references", "release",
"reserv", "reserving", "retain", "returning_values", "returns",
"revoke", "right", "rollback", "rows", "row_count", "savepoint",
"schema", "second", "segment", "select", "sensitive", "set", "shadow",
"shared", "singular", "size", "smallint", "snapshot", "some", "sort",
"sqlcode", "stability", "start", "starting", "starts", "statistics",
"sub_type", "sum", "suspend", "table", "then", "time", "timestamp",
"to", "trailing", "transaction", "trigger", "trim", "uncommitted",
"union", "unique", "update", "upper", "user", "using", "value",
"values", "varchar", "variable", "varying", "view", "wait", "when",
"where", "while", "with", "work", "write", "year",
])
class _StringType(sqltypes.String):
"""Base for Firebird string types."""
def __init__(self, charset=None, **kw):
self.charset = charset
super(_StringType, self).__init__(**kw)
class VARCHAR(_StringType, sqltypes.VARCHAR):
"""Firebird VARCHAR type"""
__visit_name__ = 'VARCHAR'
def __init__(self, length=None, **kwargs):
super(VARCHAR, self).__init__(length=length, **kwargs)
class CHAR(_StringType, sqltypes.CHAR):
"""Firebird CHAR type"""
__visit_name__ = 'CHAR'
def __init__(self, length=None, **kwargs):
super(CHAR, self).__init__(length=length, **kwargs)
class _FBDateTime(sqltypes.DateTime):
def bind_processor(self, dialect):
def process(value):
if type(value) == datetime.date:
return datetime.datetime(value.year, value.month, value.day)
else:
return value
return process
colspecs = {
sqltypes.DateTime: _FBDateTime
}
ischema_names = {
'SHORT': SMALLINT,
'LONG': INTEGER,
'QUAD': FLOAT,
'FLOAT': FLOAT,
'DATE': DATE,
'TIME': TIME,
'TEXT': TEXT,
'INT64': BIGINT,
'DOUBLE': FLOAT,
'TIMESTAMP': TIMESTAMP,
'VARYING': VARCHAR,
'CSTRING': CHAR,
'BLOB': BLOB,
}
# TODO: date conversion types (should be implemented as _FBDateTime,
# _FBDate, etc. as bind/result functionality is required)
class FBTypeCompiler(compiler.GenericTypeCompiler):
def visit_boolean(self, type_):
return self.visit_SMALLINT(type_)
def visit_datetime(self, type_):
return self.visit_TIMESTAMP(type_)
def visit_TEXT(self, type_):
return "BLOB SUB_TYPE 1"
def visit_BLOB(self, type_):
return "BLOB SUB_TYPE 0"
def _extend_string(self, type_, basic):
charset = getattr(type_, 'charset', None)
if charset is None:
return basic
else:
return '%s CHARACTER SET %s' % (basic, charset)
def visit_CHAR(self, type_):
basic = super(FBTypeCompiler, self).visit_CHAR(type_)
return self._extend_string(type_, basic)
def visit_VARCHAR(self, type_):
if not type_.length:
raise exc.CompileError(
"VARCHAR requires a length on dialect %s" %
self.dialect.name)
basic = super(FBTypeCompiler, self).visit_VARCHAR(type_)
return self._extend_string(type_, basic)
class FBCompiler(sql.compiler.SQLCompiler):
"""Firebird specific idiosyncrasies"""
ansi_bind_rules = True
#def visit_contains_op_binary(self, binary, operator, **kw):
# cant use CONTAINING b.c. it's case insensitive.
#def visit_notcontains_op_binary(self, binary, operator, **kw):
# cant use NOT CONTAINING b.c. it's case insensitive.
def visit_now_func(self, fn, **kw):
return "CURRENT_TIMESTAMP"
def visit_startswith_op_binary(self, binary, operator, **kw):
return '%s STARTING WITH %s' % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw))
def visit_notstartswith_op_binary(self, binary, operator, **kw):
return '%s NOT STARTING WITH %s' % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw))
def visit_mod_binary(self, binary, operator, **kw):
return "mod(%s, %s)" % (
self.process(binary.left, **kw),
self.process(binary.right, **kw))
def visit_alias(self, alias, asfrom=False, **kwargs):
if self.dialect._version_two:
return super(FBCompiler, self).\
visit_alias(alias, asfrom=asfrom, **kwargs)
else:
# Override to not use the AS keyword which FB 1.5 does not like
if asfrom:
alias_name = isinstance(alias.name,
expression._truncated_label) and \
self._truncated_identifier("alias",
alias.name) or alias.name
return self.process(
alias.original, asfrom=asfrom, **kwargs) + \
" " + \
self.preparer.format_alias(alias, alias_name)
else:
return self.process(alias.original, **kwargs)
def visit_substring_func(self, func, **kw):
s = self.process(func.clauses.clauses[0])
start = self.process(func.clauses.clauses[1])
if len(func.clauses.clauses) > 2:
length = self.process(func.clauses.clauses[2])
return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
else:
return "SUBSTRING(%s FROM %s)" % (s, start)
def visit_length_func(self, function, **kw):
if self.dialect._version_two:
return "char_length" + self.function_argspec(function)
else:
return "strlen" + self.function_argspec(function)
visit_char_length_func = visit_length_func
def function_argspec(self, func, **kw):
# TODO: this probably will need to be
# narrowed to a fixed list, some no-arg functions
# may require parens - see similar example in the oracle
# dialect
if func.clauses is not None and len(func.clauses):
return self.process(func.clause_expr, **kw)
else:
return ""
def default_from(self):
return " FROM rdb$database"
def visit_sequence(self, seq):
return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
def get_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just
before column list Firebird puts the limit and offset right
after the ``SELECT``...
"""
result = ""
if select._limit:
result += "FIRST %s " % self.process(sql.literal(select._limit))
if select._offset:
result += "SKIP %s " % self.process(sql.literal(select._offset))
if select._distinct:
result += "DISTINCT "
return result
def limit_clause(self, select):
"""Already taken care of in the `get_select_precolumns` method."""
return ""
def returning_clause(self, stmt, returning_cols):
columns = [
self._label_select_column(None, c, True, False, {})
for c in expression._select_iterables(returning_cols)
]
return 'RETURNING ' + ', '.join(columns)
class FBDDLCompiler(sql.compiler.DDLCompiler):
"""Firebird syntactic idiosyncrasies"""
def visit_create_sequence(self, create):
"""Generate a ``CREATE GENERATOR`` statement for the sequence."""
# no syntax for these
# http://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html
if create.element.start is not None:
raise NotImplemented(
"Firebird SEQUENCE doesn't support START WITH")
if create.element.increment is not None:
raise NotImplemented(
"Firebird SEQUENCE doesn't support INCREMENT BY")
if self.dialect._version_two:
return "CREATE SEQUENCE %s" % \
self.preparer.format_sequence(create.element)
else:
return "CREATE GENERATOR %s" % \
self.preparer.format_sequence(create.element)
def visit_drop_sequence(self, drop):
"""Generate a ``DROP GENERATOR`` statement for the sequence."""
if self.dialect._version_two:
return "DROP SEQUENCE %s" % \
self.preparer.format_sequence(drop.element)
else:
return "DROP GENERATOR %s" % \
self.preparer.format_sequence(drop.element)
class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
"""Install Firebird specific reserved words."""
reserved_words = RESERVED_WORDS
illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union(['_'])
def __init__(self, dialect):
super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
class FBExecutionContext(default.DefaultExecutionContext):
def fire_sequence(self, seq, type_):
"""Get the next value from the sequence using ``gen_id()``."""
return self._execute_scalar(
"SELECT gen_id(%s, 1) FROM rdb$database" %
self.dialect.identifier_preparer.format_sequence(seq),
type_
)
class FBDialect(default.DefaultDialect):
"""Firebird dialect"""
name = 'firebird'
max_identifier_length = 31
supports_sequences = True
sequences_optional = False
supports_default_values = True
postfetch_lastrowid = False
supports_native_boolean = False
requires_name_normalize = True
supports_empty_insert = False
statement_compiler = FBCompiler
ddl_compiler = FBDDLCompiler
preparer = FBIdentifierPreparer
type_compiler = FBTypeCompiler
execution_ctx_cls = FBExecutionContext
colspecs = colspecs
ischema_names = ischema_names
# defaults to dialect ver. 3,
# will be autodetected off upon
# first connect
_version_two = True
def initialize(self, connection):
super(FBDialect, self).initialize(connection)
self._version_two = ('firebird' in self.server_version_info and \
self.server_version_info >= (2, )
) or \
('interbase' in self.server_version_info and \
self.server_version_info >= (6, )
)
if not self._version_two:
# TODO: whatever other pre < 2.0 stuff goes here
self.ischema_names = ischema_names.copy()
self.ischema_names['TIMESTAMP'] = sqltypes.DATE
self.colspecs = {
sqltypes.DateTime: sqltypes.DATE
}
self.implicit_returning = self._version_two and \
self.__dict__.get('implicit_returning', True)
def normalize_name(self, name):
# Remove trailing spaces: FB uses a CHAR() type,
# that is padded with spaces
name = name and name.rstrip()
if name is None:
return None
elif name.upper() == name and \
not self.identifier_preparer._requires_quotes(name.lower()):
return name.lower()
else:
return name
def denormalize_name(self, name):
if name is None:
return None
elif name.lower() == name and \
not self.identifier_preparer._requires_quotes(name.lower()):
return name.upper()
else:
return name
def has_table(self, connection, table_name, schema=None):
"""Return ``True`` if the given table exists, ignoring
the `schema`."""
tblqry = """
SELECT 1 AS has_table FROM rdb$database
WHERE EXISTS (SELECT rdb$relation_name
FROM rdb$relations
WHERE rdb$relation_name=?)
"""
c = connection.execute(tblqry, [self.denormalize_name(table_name)])
return c.first() is not None
def has_sequence(self, connection, sequence_name, schema=None):
"""Return ``True`` if the given sequence (generator) exists."""
genqry = """
SELECT 1 AS has_sequence FROM rdb$database
WHERE EXISTS (SELECT rdb$generator_name
FROM rdb$generators
WHERE rdb$generator_name=?)
"""
c = connection.execute(genqry, [self.denormalize_name(sequence_name)])
return c.first() is not None
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
# there are two queries commonly mentioned for this.
# this one, using view_blr, is at the Firebird FAQ among other places:
# http://www.firebirdfaq.org/faq174/
s = """
select rdb$relation_name
from rdb$relations
where rdb$view_blr is null
and (rdb$system_flag is null or rdb$system_flag = 0);
"""
# the other query is this one. It's not clear if there's really
# any difference between these two. This link:
# http://www.alberton.info/firebird_sql_meta_info.html#.Ur3vXfZGni8
# states them as interchangeable. Some discussion at [ticket:2898]
# SELECT DISTINCT rdb$relation_name
# FROM rdb$relation_fields
# WHERE rdb$system_flag=0 AND rdb$view_context IS NULL
return [self.normalize_name(row[0]) for row in connection.execute(s)]
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
# see http://www.firebirdfaq.org/faq174/
s = """
select rdb$relation_name
from rdb$relations
where rdb$view_blr is not null
and (rdb$system_flag is null or rdb$system_flag = 0);
"""
return [self.normalize_name(row[0]) for row in connection.execute(s)]
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
qry = """
SELECT rdb$view_source AS view_source
FROM rdb$relations
WHERE rdb$relation_name=?
"""
rp = connection.execute(qry, [self.denormalize_name(view_name)])
row = rp.first()
if row:
return row['view_source']
else:
return None
@reflection.cache
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
# Query to extract the PK/FK constrained fields of the given table
keyqry = """
SELECT se.rdb$field_name AS fname
FROM rdb$relation_constraints rc
JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name
WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
"""
tablename = self.denormalize_name(table_name)
# get primary key fields
c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()]
return {'constrained_columns': pkfields, 'name': None}
@reflection.cache
def get_column_sequence(self, connection,
table_name, column_name,
schema=None, **kw):
tablename = self.denormalize_name(table_name)
colname = self.denormalize_name(column_name)
# Heuristic-query to determine the generator associated to a PK field
genqry = """
SELECT trigdep.rdb$depended_on_name AS fgenerator
FROM rdb$dependencies tabdep
JOIN rdb$dependencies trigdep
ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
AND trigdep.rdb$depended_on_type=14
AND trigdep.rdb$dependent_type=2
JOIN rdb$triggers trig ON
trig.rdb$trigger_name=tabdep.rdb$dependent_name
WHERE tabdep.rdb$depended_on_name=?
AND tabdep.rdb$depended_on_type=0
AND trig.rdb$trigger_type=1
AND tabdep.rdb$field_name=?
AND (SELECT count(*)
FROM rdb$dependencies trigdep2
WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
"""
genr = connection.execute(genqry, [tablename, colname]).first()
if genr is not None:
return dict(name=self.normalize_name(genr['fgenerator']))
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
# Query to extract the details of all the fields of the given table
tblqry = """
SELECT r.rdb$field_name AS fname,
r.rdb$null_flag AS null_flag,
t.rdb$type_name AS ftype,
f.rdb$field_sub_type AS stype,
f.rdb$field_length/
COALESCE(cs.rdb$bytes_per_character,1) AS flen,
f.rdb$field_precision AS fprec,
f.rdb$field_scale AS fscale,
COALESCE(r.rdb$default_source,
f.rdb$default_source) AS fdefault
FROM rdb$relation_fields r
JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name
JOIN rdb$types t
ON t.rdb$type=f.rdb$field_type AND
t.rdb$field_name='RDB$FIELD_TYPE'
LEFT JOIN rdb$character_sets cs ON
f.rdb$character_set_id=cs.rdb$character_set_id
WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
ORDER BY r.rdb$field_position
"""
# get the PK, used to determine the eventual associated sequence
pk_constraint = self.get_pk_constraint(connection, table_name)
pkey_cols = pk_constraint['constrained_columns']
tablename = self.denormalize_name(table_name)
# get all of the fields for this table
c = connection.execute(tblqry, [tablename])
cols = []
while True:
row = c.fetchone()
if row is None:
break
name = self.normalize_name(row['fname'])
orig_colname = row['fname']
# get the data type
colspec = row['ftype'].rstrip()
coltype = self.ischema_names.get(colspec)
if coltype is None:
util.warn("Did not recognize type '%s' of column '%s'" %
(colspec, name))
coltype = sqltypes.NULLTYPE
elif issubclass(coltype, Integer) and row['fprec'] != 0:
coltype = NUMERIC(
precision=row['fprec'],
scale=row['fscale'] * -1)
elif colspec in ('VARYING', 'CSTRING'):
coltype = coltype(row['flen'])
elif colspec == 'TEXT':
coltype = TEXT(row['flen'])
elif colspec == 'BLOB':
if row['stype'] == 1:
coltype = TEXT()
else:
coltype = BLOB()
else:
coltype = coltype()
# does it have a default value?
defvalue = None
if row['fdefault'] is not None:
# the value comes down as "DEFAULT 'value'": there may be
# more than one whitespace around the "DEFAULT" keyword
# and it may also be lower case
# (see also http://tracker.firebirdsql.org/browse/CORE-356)
defexpr = row['fdefault'].lstrip()
assert defexpr[:8].rstrip().upper() == \
'DEFAULT', "Unrecognized default value: %s" % \
defexpr
defvalue = defexpr[8:].strip()
if defvalue == 'NULL':
# Redundant
defvalue = None
col_d = {
'name': name,
'type': coltype,
'nullable': not bool(row['null_flag']),
'default': defvalue,
'autoincrement': defvalue is None
}
if orig_colname.lower() == orig_colname:
col_d['quote'] = True
# if the PK is a single field, try to see if its linked to
# a sequence thru a trigger
if len(pkey_cols) == 1 and name == pkey_cols[0]:
seq_d = self.get_column_sequence(connection, tablename, name)
if seq_d is not None:
col_d['sequence'] = seq_d
cols.append(col_d)
return cols
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
# Query to extract the details of each UK/FK of the given table
fkqry = """
SELECT rc.rdb$constraint_name AS cname,
cse.rdb$field_name AS fname,
ix2.rdb$relation_name AS targetrname,
se.rdb$field_name AS targetfname
FROM rdb$relation_constraints rc
JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name
JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key
JOIN rdb$index_segments cse ON
cse.rdb$index_name=ix1.rdb$index_name
JOIN rdb$index_segments se
ON se.rdb$index_name=ix2.rdb$index_name
AND se.rdb$field_position=cse.rdb$field_position
WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
ORDER BY se.rdb$index_name, se.rdb$field_position
"""
tablename = self.denormalize_name(table_name)
c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
fks = util.defaultdict(lambda: {
'name': None,
'constrained_columns': [],
'referred_schema': None,
'referred_table': None,
'referred_columns': []
})
for row in c:
cname = self.normalize_name(row['cname'])
fk = fks[cname]
if not fk['name']:
fk['name'] = cname
fk['referred_table'] = self.normalize_name(row['targetrname'])
fk['constrained_columns'].append(
self.normalize_name(row['fname']))
fk['referred_columns'].append(
self.normalize_name(row['targetfname']))
return list(fks.values())
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
qry = """
SELECT ix.rdb$index_name AS index_name,
ix.rdb$unique_flag AS unique_flag,
ic.rdb$field_name AS field_name
FROM rdb$indices ix
JOIN rdb$index_segments ic
ON ix.rdb$index_name=ic.rdb$index_name
LEFT OUTER JOIN rdb$relation_constraints
ON rdb$relation_constraints.rdb$index_name =
ic.rdb$index_name
WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL
AND rdb$relation_constraints.rdb$constraint_type IS NULL
ORDER BY index_name, ic.rdb$field_position
"""
c = connection.execute(qry, [self.denormalize_name(table_name)])
indexes = util.defaultdict(dict)
for row in c:
indexrec = indexes[row['index_name']]
if 'name' not in indexrec:
indexrec['name'] = self.normalize_name(row['index_name'])
indexrec['column_names'] = []
indexrec['unique'] = bool(row['unique_flag'])
indexrec['column_names'].append(
self.normalize_name(row['field_name']))
return list(indexes.values())

View File

@@ -1,115 +0,0 @@
# firebird/fdb.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: firebird+fdb
:name: fdb
:dbapi: pyodbc
:connectstring: firebird+fdb://user:password@host:port/path/to/db[?key=value&key=value...]
:url: http://pypi.python.org/pypi/fdb/
fdb is a kinterbasdb compatible DBAPI for Firebird.
.. versionadded:: 0.8 - Support for the fdb Firebird driver.
.. versionchanged:: 0.9 - The fdb dialect is now the default dialect
under the ``firebird://`` URL space, as ``fdb`` is now the official
Python driver for Firebird.
Arguments
----------
The ``fdb`` dialect is based on the :mod:`sqlalchemy.dialects.firebird.kinterbasdb`
dialect, however does not accept every argument that Kinterbasdb does.
* ``enable_rowcount`` - True by default, setting this to False disables
the usage of "cursor.rowcount" with the
Kinterbasdb dialect, which SQLAlchemy ordinarily calls upon automatically
after any UPDATE or DELETE statement. When disabled, SQLAlchemy's
ResultProxy will return -1 for result.rowcount. The rationale here is
that Kinterbasdb requires a second round trip to the database when
.rowcount is called - since SQLA's resultproxy automatically closes
the cursor after a non-result-returning statement, rowcount must be
called, if at all, before the result object is returned. Additionally,
cursor.rowcount may not return correct results with older versions
of Firebird, and setting this flag to False will also cause the
SQLAlchemy ORM to ignore its usage. The behavior can also be controlled on a
per-execution basis using the ``enable_rowcount`` option with
:meth:`.Connection.execution_options`::
conn = engine.connect().execution_options(enable_rowcount=True)
r = conn.execute(stmt)
print r.rowcount
* ``retaining`` - False by default. Setting this to True will pass the
``retaining=True`` keyword argument to the ``.commit()`` and ``.rollback()``
methods of the DBAPI connection, which can improve performance in some
situations, but apparently with significant caveats.
Please read the fdb and/or kinterbasdb DBAPI documentation in order to
understand the implications of this flag.
.. versionadded:: 0.8.2 - ``retaining`` keyword argument specifying
transaction retaining behavior - in 0.8 it defaults to ``True``
for backwards compatibility.
.. versionchanged:: 0.9.0 - the ``retaining`` flag defaults to ``False``.
In 0.8 it defaulted to ``True``.
.. seealso::
http://pythonhosted.org/fdb/usage-guide.html#retaining-transactions - information
on the "retaining" flag.
"""
from .kinterbasdb import FBDialect_kinterbasdb
from ... import util
class FBDialect_fdb(FBDialect_kinterbasdb):
def __init__(self, enable_rowcount=True,
retaining=False, **kwargs):
super(FBDialect_fdb, self).__init__(
enable_rowcount=enable_rowcount,
retaining=retaining, **kwargs)
@classmethod
def dbapi(cls):
return __import__('fdb')
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
if opts.get('port'):
opts['host'] = "%s/%s" % (opts['host'], opts['port'])
del opts['port']
opts.update(url.query)
util.coerce_kw_type(opts, 'type_conv', int)
return ([], opts)
def _get_server_version_info(self, connection):
"""Get the version of the Firebird server used by a connection.
Returns a tuple of (`major`, `minor`, `build`), three integers
representing the version of the attached server.
"""
# This is the simpler approach (the other uses the services api),
# that for backward compatibility reasons returns a string like
# LI-V6.3.3.12981 Firebird 2.0
# where the first version is a fake one resembling the old
# Interbase signature.
isc_info_firebird_version = 103
fbconn = connection.connection
version = fbconn.db_info(isc_info_firebird_version)
return self._parse_version_info(version)
dialect = FBDialect_fdb

View File

@@ -1,179 +0,0 @@
# firebird/kinterbasdb.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: firebird+kinterbasdb
:name: kinterbasdb
:dbapi: kinterbasdb
:connectstring: firebird+kinterbasdb://user:password@host:port/path/to/db[?key=value&key=value...]
:url: http://firebirdsql.org/index.php?op=devel&sub=python
Arguments
----------
The Kinterbasdb backend accepts the ``enable_rowcount`` and ``retaining``
arguments accepted by the :mod:`sqlalchemy.dialects.firebird.fdb` dialect. In addition, it
also accepts the following:
* ``type_conv`` - select the kind of mapping done on the types: by default
SQLAlchemy uses 200 with Unicode, datetime and decimal support. See
the linked documents below for further information.
* ``concurrency_level`` - set the backend policy with regards to threading
issues: by default SQLAlchemy uses policy 1. See the linked documents
below for futher information.
.. seealso::
http://sourceforge.net/projects/kinterbasdb
http://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_param_conv_dynamic_type_translation
http://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurrency
"""
from .base import FBDialect, FBExecutionContext
from ... import util, types as sqltypes
from re import match
import decimal
class _kinterbasdb_numeric(object):
def bind_processor(self, dialect):
def process(value):
if isinstance(value, decimal.Decimal):
return str(value)
else:
return value
return process
class _FBNumeric_kinterbasdb(_kinterbasdb_numeric, sqltypes.Numeric):
pass
class _FBFloat_kinterbasdb(_kinterbasdb_numeric, sqltypes.Float):
pass
class FBExecutionContext_kinterbasdb(FBExecutionContext):
@property
def rowcount(self):
if self.execution_options.get('enable_rowcount',
self.dialect.enable_rowcount):
return self.cursor.rowcount
else:
return -1
class FBDialect_kinterbasdb(FBDialect):
driver = 'kinterbasdb'
supports_sane_rowcount = False
supports_sane_multi_rowcount = False
execution_ctx_cls = FBExecutionContext_kinterbasdb
supports_native_decimal = True
colspecs = util.update_copy(
FBDialect.colspecs,
{
sqltypes.Numeric: _FBNumeric_kinterbasdb,
sqltypes.Float: _FBFloat_kinterbasdb,
}
)
def __init__(self, type_conv=200, concurrency_level=1,
enable_rowcount=True,
retaining=False, **kwargs):
super(FBDialect_kinterbasdb, self).__init__(**kwargs)
self.enable_rowcount = enable_rowcount
self.type_conv = type_conv
self.concurrency_level = concurrency_level
self.retaining = retaining
if enable_rowcount:
self.supports_sane_rowcount = True
@classmethod
def dbapi(cls):
return __import__('kinterbasdb')
def do_execute(self, cursor, statement, parameters, context=None):
# kinterbase does not accept a None, but wants an empty list
# when there are no arguments.
cursor.execute(statement, parameters or [])
def do_rollback(self, dbapi_connection):
dbapi_connection.rollback(self.retaining)
def do_commit(self, dbapi_connection):
dbapi_connection.commit(self.retaining)
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
if opts.get('port'):
opts['host'] = "%s/%s" % (opts['host'], opts['port'])
del opts['port']
opts.update(url.query)
util.coerce_kw_type(opts, 'type_conv', int)
type_conv = opts.pop('type_conv', self.type_conv)
concurrency_level = opts.pop('concurrency_level',
self.concurrency_level)
if self.dbapi is not None:
initialized = getattr(self.dbapi, 'initialized', None)
if initialized is None:
# CVS rev 1.96 changed the name of the attribute:
# http://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/
# Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96
initialized = getattr(self.dbapi, '_initialized', False)
if not initialized:
self.dbapi.init(type_conv=type_conv,
concurrency_level=concurrency_level)
return ([], opts)
def _get_server_version_info(self, connection):
"""Get the version of the Firebird server used by a connection.
Returns a tuple of (`major`, `minor`, `build`), three integers
representing the version of the attached server.
"""
# This is the simpler approach (the other uses the services api),
# that for backward compatibility reasons returns a string like
# LI-V6.3.3.12981 Firebird 2.0
# where the first version is a fake one resembling the old
# Interbase signature.
fbconn = connection.connection
version = fbconn.server_version
return self._parse_version_info(version)
def _parse_version_info(self, version):
m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?', version)
if not m:
raise AssertionError(
"Could not determine version from string '%s'" % version)
if m.group(5) != None:
return tuple([int(x) for x in m.group(6, 7, 4)] + ['firebird'])
else:
return tuple([int(x) for x in m.group(1, 2, 3)] + ['interbase'])
def is_disconnect(self, e, connection, cursor):
if isinstance(e, (self.dbapi.OperationalError,
self.dbapi.ProgrammingError)):
msg = str(e)
return ('Unable to complete network request to host' in msg or
'Invalid connection state' in msg or
'Invalid cursor state' in msg or
'connection shutdown' in msg)
else:
return False
dialect = FBDialect_kinterbasdb

View File

@@ -1,26 +0,0 @@
# mssql/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, \
pymssql, zxjdbc, mxodbc
base.dialect = pyodbc.dialect
from sqlalchemy.dialects.mssql.base import \
INTEGER, BIGINT, SMALLINT, TINYINT, VARCHAR, NVARCHAR, CHAR, \
NCHAR, TEXT, NTEXT, DECIMAL, NUMERIC, FLOAT, DATETIME,\
DATETIME2, DATETIMEOFFSET, DATE, TIME, SMALLDATETIME, \
BINARY, VARBINARY, BIT, REAL, IMAGE, TIMESTAMP,\
MONEY, SMALLMONEY, UNIQUEIDENTIFIER, SQL_VARIANT, dialect
__all__ = (
'INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT', 'VARCHAR', 'NVARCHAR', 'CHAR',
'NCHAR', 'TEXT', 'NTEXT', 'DECIMAL', 'NUMERIC', 'FLOAT', 'DATETIME',
'DATETIME2', 'DATETIMEOFFSET', 'DATE', 'TIME', 'SMALLDATETIME',
'BINARY', 'VARBINARY', 'BIT', 'REAL', 'IMAGE', 'TIMESTAMP',
'MONEY', 'SMALLMONEY', 'UNIQUEIDENTIFIER', 'SQL_VARIANT', 'dialect'
)

View File

@@ -1,79 +0,0 @@
# mssql/adodbapi.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mssql+adodbapi
:name: adodbapi
:dbapi: adodbapi
:connectstring: mssql+adodbapi://<username>:<password>@<dsnname>
:url: http://adodbapi.sourceforge.net/
.. note::
The adodbapi dialect is not implemented SQLAlchemy versions 0.6 and
above at this time.
"""
import datetime
from sqlalchemy import types as sqltypes, util
from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect
import sys
class MSDateTime_adodbapi(MSDateTime):
def result_processor(self, dialect, coltype):
def process(value):
# adodbapi will return datetimes with empty time
# values as datetime.date() objects.
# Promote them back to full datetime.datetime()
if type(value) is datetime.date:
return datetime.datetime(value.year, value.month, value.day)
return value
return process
class MSDialect_adodbapi(MSDialect):
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
supports_unicode = sys.maxunicode == 65535
supports_unicode_statements = True
driver = 'adodbapi'
@classmethod
def import_dbapi(cls):
import adodbapi as module
return module
colspecs = util.update_copy(
MSDialect.colspecs,
{
sqltypes.DateTime: MSDateTime_adodbapi
}
)
def create_connect_args(self, url):
keys = url.query
connectors = ["Provider=SQLOLEDB"]
if 'port' in keys:
connectors.append("Data Source=%s, %s" %
(keys.get("host"), keys.get("port")))
else:
connectors.append("Data Source=%s" % keys.get("host"))
connectors.append("Initial Catalog=%s" % keys.get("database"))
user = keys.get("user")
if user:
connectors.append("User Id=%s" % user)
connectors.append("Password=%s" % keys.get("password", ""))
else:
connectors.append("Integrated Security=SSPI")
return [[";".join(connectors)], {}]
def is_disconnect(self, e, connection, cursor):
return isinstance(e, self.dbapi.adodbapi.DatabaseError) and \
"'connection failure'" in str(e)
dialect = MSDialect_adodbapi

File diff suppressed because it is too large Load Diff

View File

@@ -1,114 +0,0 @@
# mssql/information_schema.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
# TODO: should be using the sys. catalog with SQL Server, not information schema
from ... import Table, MetaData, Column
from ...types import String, Unicode, UnicodeText, Integer, TypeDecorator
from ... import cast
from ... import util
from ...sql import expression
from ...ext.compiler import compiles
ischema = MetaData()
class CoerceUnicode(TypeDecorator):
impl = Unicode
def process_bind_param(self, value, dialect):
if util.py2k and isinstance(value, util.binary_type):
value = value.decode(dialect.encoding)
return value
def bind_expression(self, bindvalue):
return _cast_on_2005(bindvalue)
class _cast_on_2005(expression.ColumnElement):
def __init__(self, bindvalue):
self.bindvalue = bindvalue
@compiles(_cast_on_2005)
def _compile(element, compiler, **kw):
from . import base
if compiler.dialect.server_version_info < base.MS_2005_VERSION:
return compiler.process(element.bindvalue, **kw)
else:
return compiler.process(cast(element.bindvalue, Unicode), **kw)
schemata = Table("SCHEMATA", ischema,
Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"),
schema="INFORMATION_SCHEMA")
tables = Table("TABLES", ischema,
Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("TABLE_TYPE", String(convert_unicode=True), key="table_type"),
schema="INFORMATION_SCHEMA")
columns = Table("COLUMNS", ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("IS_NULLABLE", Integer, key="is_nullable"),
Column("DATA_TYPE", String, key="data_type"),
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"),
Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
Column("COLUMN_DEFAULT", Integer, key="column_default"),
Column("COLLATION_NAME", String, key="collation_name"),
schema="INFORMATION_SCHEMA")
constraints = Table("TABLE_CONSTRAINTS", ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
Column("CONSTRAINT_TYPE", String(convert_unicode=True), key="constraint_type"),
schema="INFORMATION_SCHEMA")
column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
schema="INFORMATION_SCHEMA")
key_constraints = Table("KEY_COLUMN_USAGE", ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
schema="INFORMATION_SCHEMA")
ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema,
Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"),
Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
# TODO: is CATLOG misspelled ?
Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode,
key="unique_constraint_catalog"),
Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode,
key="unique_constraint_schema"),
Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode,
key="unique_constraint_name"),
Column("MATCH_OPTION", String, key="match_option"),
Column("UPDATE_RULE", String, key="update_rule"),
Column("DELETE_RULE", String, key="delete_rule"),
schema="INFORMATION_SCHEMA")
views = Table("VIEWS", ischema,
Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"),
Column("CHECK_OPTION", String, key="check_option"),
Column("IS_UPDATABLE", String, key="is_updatable"),
schema="INFORMATION_SCHEMA")

View File

@@ -1,111 +0,0 @@
# mssql/mxodbc.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mssql+mxodbc
:name: mxODBC
:dbapi: mxodbc
:connectstring: mssql+mxodbc://<username>:<password>@<dsnname>
:url: http://www.egenix.com/
Execution Modes
---------------
mxODBC features two styles of statement execution, using the
``cursor.execute()`` and ``cursor.executedirect()`` methods (the second being
an extension to the DBAPI specification). The former makes use of a particular
API call specific to the SQL Server Native Client ODBC driver known
SQLDescribeParam, while the latter does not.
mxODBC apparently only makes repeated use of a single prepared statement
when SQLDescribeParam is used. The advantage to prepared statement reuse is
one of performance. The disadvantage is that SQLDescribeParam has a limited
set of scenarios in which bind parameters are understood, including that they
cannot be placed within the argument lists of function calls, anywhere outside
the FROM, or even within subqueries within the FROM clause - making the usage
of bind parameters within SELECT statements impossible for all but the most
simplistic statements.
For this reason, the mxODBC dialect uses the "native" mode by default only for
INSERT, UPDATE, and DELETE statements, and uses the escaped string mode for
all other statements.
This behavior can be controlled via
:meth:`~sqlalchemy.sql.expression.Executable.execution_options` using the
``native_odbc_execute`` flag with a value of ``True`` or ``False``, where a
value of ``True`` will unconditionally use native bind parameters and a value
of ``False`` will unconditionally use string-escaped parameters.
"""
from ... import types as sqltypes
from ...connectors.mxodbc import MxODBCConnector
from .pyodbc import MSExecutionContext_pyodbc, _MSNumeric_pyodbc
from .base import (MSDialect,
MSSQLStrictCompiler,
_MSDateTime, _MSDate, _MSTime)
class _MSNumeric_mxodbc(_MSNumeric_pyodbc):
"""Include pyodbc's numeric processor.
"""
class _MSDate_mxodbc(_MSDate):
def bind_processor(self, dialect):
def process(value):
if value is not None:
return "%s-%s-%s" % (value.year, value.month, value.day)
else:
return None
return process
class _MSTime_mxodbc(_MSTime):
def bind_processor(self, dialect):
def process(value):
if value is not None:
return "%s:%s:%s" % (value.hour, value.minute, value.second)
else:
return None
return process
class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
"""
The pyodbc execution context is useful for enabling
SELECT SCOPE_IDENTITY in cases where OUTPUT clause
does not work (tables with insert triggers).
"""
#todo - investigate whether the pyodbc execution context
# is really only being used in cases where OUTPUT
# won't work.
class MSDialect_mxodbc(MxODBCConnector, MSDialect):
# this is only needed if "native ODBC" mode is used,
# which is now disabled by default.
#statement_compiler = MSSQLStrictCompiler
execution_ctx_cls = MSExecutionContext_mxodbc
# flag used by _MSNumeric_mxodbc
_need_decimal_fix = True
colspecs = {
sqltypes.Numeric: _MSNumeric_mxodbc,
sqltypes.DateTime: _MSDateTime,
sqltypes.Date: _MSDate_mxodbc,
sqltypes.Time: _MSTime_mxodbc,
}
def __init__(self, description_encoding=None, **params):
super(MSDialect_mxodbc, self).__init__(**params)
self.description_encoding = description_encoding
dialect = MSDialect_mxodbc

View File

@@ -1,100 +0,0 @@
# mssql/pymssql.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mssql+pymssql
:name: pymssql
:dbapi: pymssql
:connectstring: mssql+pymssql://<username>:<password>@<freetds_name>?charset=utf8
:url: http://pymssql.sourceforge.net/
Limitations
-----------
pymssql inherits a lot of limitations from FreeTDS, including:
* no support for multibyte schema identifiers
* poor support for large decimals
* poor support for binary fields
* poor support for VARCHAR/CHAR fields over 255 characters
Please consult the pymssql documentation for further information.
"""
from .base import MSDialect
from ... import types as sqltypes, util, processors
import re
class _MSNumeric_pymssql(sqltypes.Numeric):
def result_processor(self, dialect, type_):
if not self.asdecimal:
return processors.to_float
else:
return sqltypes.Numeric.result_processor(self, dialect, type_)
class MSDialect_pymssql(MSDialect):
supports_sane_rowcount = False
driver = 'pymssql'
colspecs = util.update_copy(
MSDialect.colspecs,
{
sqltypes.Numeric: _MSNumeric_pymssql,
sqltypes.Float: sqltypes.Float,
}
)
@classmethod
def dbapi(cls):
module = __import__('pymssql')
# pymmsql doesn't have a Binary method. we use string
# TODO: monkeypatching here is less than ideal
module.Binary = lambda x: x if hasattr(x, 'decode') else str(x)
client_ver = tuple(int(x) for x in module.__version__.split("."))
if client_ver < (1, ):
util.warn("The pymssql dialect expects at least "
"the 1.0 series of the pymssql DBAPI.")
return module
def __init__(self, **params):
super(MSDialect_pymssql, self).__init__(**params)
self.use_scope_identity = True
def _get_server_version_info(self, connection):
vers = connection.scalar("select @@version")
m = re.match(
r"Microsoft SQL Server.*? - (\d+).(\d+).(\d+).(\d+)", vers)
if m:
return tuple(int(x) for x in m.group(1, 2, 3, 4))
else:
return None
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
opts.update(url.query)
port = opts.pop('port', None)
if port and 'host' in opts:
opts['host'] = "%s:%s" % (opts['host'], port)
return [[], opts]
def is_disconnect(self, e, connection, cursor):
for msg in (
"Adaptive Server connection timed out",
"Net-Lib error during Connection reset by peer",
"message 20003", # connection timeout
"Error 10054",
"Not connected to any MS SQL server",
"Connection is closed"
):
if msg in str(e):
return True
else:
return False
dialect = MSDialect_pymssql

View File

@@ -1,260 +0,0 @@
# mssql/pyodbc.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mssql+pyodbc
:name: PyODBC
:dbapi: pyodbc
:connectstring: mssql+pyodbc://<username>:<password>@<dsnname>
:url: http://pypi.python.org/pypi/pyodbc/
Additional Connection Examples
-------------------------------
Examples of pyodbc connection string URLs:
* ``mssql+pyodbc://mydsn`` - connects using the specified DSN named ``mydsn``.
The connection string that is created will appear like::
dsn=mydsn;Trusted_Connection=Yes
* ``mssql+pyodbc://user:pass@mydsn`` - connects using the DSN named
``mydsn`` passing in the ``UID`` and ``PWD`` information. The
connection string that is created will appear like::
dsn=mydsn;UID=user;PWD=pass
* ``mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english`` - connects
using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD``
information, plus the additional connection configuration option
``LANGUAGE``. The connection string that is created will appear
like::
dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english
* ``mssql+pyodbc://user:pass@host/db`` - connects using a connection
that would appear like::
DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass
* ``mssql+pyodbc://user:pass@host:123/db`` - connects using a connection
string which includes the port
information using the comma syntax. This will create the following
connection string::
DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass
* ``mssql+pyodbc://user:pass@host/db?port=123`` - connects using a connection
string that includes the port
information as a separate ``port`` keyword. This will create the
following connection string::
DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123
* ``mssql+pyodbc://user:pass@host/db?driver=MyDriver`` - connects using a connection
string that includes a custom
ODBC driver name. This will create the following connection string::
DRIVER={MyDriver};Server=host;Database=db;UID=user;PWD=pass
If you require a connection string that is outside the options
presented above, use the ``odbc_connect`` keyword to pass in a
urlencoded connection string. What gets passed in will be urldecoded
and passed directly.
For example::
mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
would create the following connection string::
dsn=mydsn;Database=db
Encoding your connection string can be easily accomplished through
the python shell. For example::
>>> import urllib
>>> urllib.quote_plus('dsn=mydsn;Database=db')
'dsn%3Dmydsn%3BDatabase%3Ddb'
Unicode Binds
-------------
The current state of PyODBC on a unix backend with FreeTDS and/or
EasySoft is poor regarding unicode; different OS platforms and versions of UnixODBC
versus IODBC versus FreeTDS/EasySoft versus PyODBC itself dramatically
alter how strings are received. The PyODBC dialect attempts to use all the information
it knows to determine whether or not a Python unicode literal can be
passed directly to the PyODBC driver or not; while SQLAlchemy can encode
these to bytestrings first, some users have reported that PyODBC mis-handles
bytestrings for certain encodings and requires a Python unicode object,
while the author has observed widespread cases where a Python unicode
is completely misinterpreted by PyODBC, particularly when dealing with
the information schema tables used in table reflection, and the value
must first be encoded to a bytestring.
It is for this reason that whether or not unicode literals for bound
parameters be sent to PyODBC can be controlled using the
``supports_unicode_binds`` parameter to ``create_engine()``. When
left at its default of ``None``, the PyODBC dialect will use its
best guess as to whether or not the driver deals with unicode literals
well. When ``False``, unicode literals will be encoded first, and when
``True`` unicode literals will be passed straight through. This is an interim
flag that hopefully should not be needed when the unicode situation stabilizes
for unix + PyODBC.
.. versionadded:: 0.7.7
``supports_unicode_binds`` parameter to ``create_engine()``\ .
"""
from .base import MSExecutionContext, MSDialect
from ...connectors.pyodbc import PyODBCConnector
from ... import types as sqltypes, util
import decimal
class _ms_numeric_pyodbc(object):
"""Turns Decimals with adjusted() < 0 or > 7 into strings.
The routines here are needed for older pyodbc versions
as well as current mxODBC versions.
"""
def bind_processor(self, dialect):
super_process = super(_ms_numeric_pyodbc, self).\
bind_processor(dialect)
if not dialect._need_decimal_fix:
return super_process
def process(value):
if self.asdecimal and \
isinstance(value, decimal.Decimal):
adjusted = value.adjusted()
if adjusted < 0:
return self._small_dec_to_string(value)
elif adjusted > 7:
return self._large_dec_to_string(value)
if super_process:
return super_process(value)
else:
return value
return process
# these routines needed for older versions of pyodbc.
# as of 2.1.8 this logic is integrated.
def _small_dec_to_string(self, value):
return "%s0.%s%s" % (
(value < 0 and '-' or ''),
'0' * (abs(value.adjusted()) - 1),
"".join([str(nint) for nint in value.as_tuple()[1]]))
def _large_dec_to_string(self, value):
_int = value.as_tuple()[1]
if 'E' in str(value):
result = "%s%s%s" % (
(value < 0 and '-' or ''),
"".join([str(s) for s in _int]),
"0" * (value.adjusted() - (len(_int) - 1)))
else:
if (len(_int) - 1) > value.adjusted():
result = "%s%s.%s" % (
(value < 0 and '-' or ''),
"".join(
[str(s) for s in _int][0:value.adjusted() + 1]),
"".join(
[str(s) for s in _int][value.adjusted() + 1:]))
else:
result = "%s%s" % (
(value < 0 and '-' or ''),
"".join(
[str(s) for s in _int][0:value.adjusted() + 1]))
return result
class _MSNumeric_pyodbc(_ms_numeric_pyodbc, sqltypes.Numeric):
pass
class _MSFloat_pyodbc(_ms_numeric_pyodbc, sqltypes.Float):
pass
class MSExecutionContext_pyodbc(MSExecutionContext):
_embedded_scope_identity = False
def pre_exec(self):
"""where appropriate, issue "select scope_identity()" in the same
statement.
Background on why "scope_identity()" is preferable to "@@identity":
http://msdn.microsoft.com/en-us/library/ms190315.aspx
Background on why we attempt to embed "scope_identity()" into the same
statement as the INSERT:
http://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values?
"""
super(MSExecutionContext_pyodbc, self).pre_exec()
# don't embed the scope_identity select into an
# "INSERT .. DEFAULT VALUES"
if self._select_lastrowid and \
self.dialect.use_scope_identity and \
len(self.parameters[0]):
self._embedded_scope_identity = True
self.statement += "; select scope_identity()"
def post_exec(self):
if self._embedded_scope_identity:
# Fetch the last inserted id from the manipulated statement
# We may have to skip over a number of result sets with
# no data (due to triggers, etc.)
while True:
try:
# fetchall() ensures the cursor is consumed
# without closing it (FreeTDS particularly)
row = self.cursor.fetchall()[0]
break
except self.dialect.dbapi.Error as e:
# no way around this - nextset() consumes the previous set
# so we need to just keep flipping
self.cursor.nextset()
self._lastrowid = int(row[0])
else:
super(MSExecutionContext_pyodbc, self).post_exec()
class MSDialect_pyodbc(PyODBCConnector, MSDialect):
execution_ctx_cls = MSExecutionContext_pyodbc
pyodbc_driver_name = 'SQL Server'
colspecs = util.update_copy(
MSDialect.colspecs,
{
sqltypes.Numeric: _MSNumeric_pyodbc,
sqltypes.Float: _MSFloat_pyodbc
}
)
def __init__(self, description_encoding=None, **params):
super(MSDialect_pyodbc, self).__init__(**params)
self.description_encoding = description_encoding
self.use_scope_identity = self.use_scope_identity and \
self.dbapi and \
hasattr(self.dbapi.Cursor, 'nextset')
self._need_decimal_fix = self.dbapi and \
self._dbapi_version() < (2, 1, 8)
dialect = MSDialect_pyodbc

View File

@@ -1,65 +0,0 @@
# mssql/zxjdbc.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mssql+zxjdbc
:name: zxJDBC for Jython
:dbapi: zxjdbc
:connectstring: mssql+zxjdbc://user:pass@host:port/dbname[?key=value&key=value...]
:driverurl: http://jtds.sourceforge.net/
"""
from ...connectors.zxJDBC import ZxJDBCConnector
from .base import MSDialect, MSExecutionContext
from ... import engine
class MSExecutionContext_zxjdbc(MSExecutionContext):
_embedded_scope_identity = False
def pre_exec(self):
super(MSExecutionContext_zxjdbc, self).pre_exec()
# scope_identity after the fact returns null in jTDS so we must
# embed it
if self._select_lastrowid and self.dialect.use_scope_identity:
self._embedded_scope_identity = True
self.statement += "; SELECT scope_identity()"
def post_exec(self):
if self._embedded_scope_identity:
while True:
try:
row = self.cursor.fetchall()[0]
break
except self.dialect.dbapi.Error:
self.cursor.nextset()
self._lastrowid = int(row[0])
if (self.isinsert or self.isupdate or self.isdelete) and \
self.compiled.returning:
self._result_proxy = engine.FullyBufferedResultProxy(self)
if self._enable_identity_insert:
table = self.dialect.identifier_preparer.format_table(
self.compiled.statement.table)
self.cursor.execute("SET IDENTITY_INSERT %s OFF" % table)
class MSDialect_zxjdbc(ZxJDBCConnector, MSDialect):
jdbc_db_name = 'jtds:sqlserver'
jdbc_driver_name = 'net.sourceforge.jtds.jdbc.Driver'
execution_ctx_cls = MSExecutionContext_zxjdbc
def _get_server_version_info(self, connection):
return tuple(
int(x)
for x in connection.connection.dbversion.split('.')
)
dialect = MSDialect_zxjdbc

View File

@@ -1,28 +0,0 @@
# mysql/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from . import base, mysqldb, oursql, \
pyodbc, zxjdbc, mysqlconnector, pymysql,\
gaerdbms, cymysql
# default dialect
base.dialect = mysqldb.dialect
from .base import \
BIGINT, BINARY, BIT, BLOB, BOOLEAN, CHAR, DATE, DATETIME, \
DECIMAL, DOUBLE, ENUM, DECIMAL,\
FLOAT, INTEGER, INTEGER, LONGBLOB, LONGTEXT, MEDIUMBLOB, \
MEDIUMINT, MEDIUMTEXT, NCHAR, \
NVARCHAR, NUMERIC, SET, SMALLINT, REAL, TEXT, TIME, TIMESTAMP, \
TINYBLOB, TINYINT, TINYTEXT,\
VARBINARY, VARCHAR, YEAR, dialect
__all__ = (
'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', 'DOUBLE',
'ENUM', 'DECIMAL', 'FLOAT', 'INTEGER', 'INTEGER', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT',
'MEDIUMTEXT', 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME', 'TIMESTAMP',
'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR', 'YEAR', 'dialect'
)

File diff suppressed because it is too large Load Diff

View File

@@ -1,69 +0,0 @@
# mysql/cymysql.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+cymysql
:name: CyMySQL
:dbapi: cymysql
:connectstring: mysql+cymysql://<username>:<password>@<host>/<dbname>[?<options>]
:url: https://github.com/nakagami/CyMySQL
"""
from .mysqldb import MySQLDialect_mysqldb
from .base import (BIT, MySQLDialect)
from ... import util
class _cymysqlBIT(BIT):
def result_processor(self, dialect, coltype):
"""Convert a MySQL's 64 bit, variable length binary string to a long.
"""
def process(value):
if value is not None:
v = 0
for i in util.iterbytes(value):
v = v << 8 | i
return v
return value
return process
class MySQLDialect_cymysql(MySQLDialect_mysqldb):
driver = 'cymysql'
description_encoding = None
supports_sane_rowcount = True
supports_sane_multi_rowcount = False
supports_unicode_statements = True
colspecs = util.update_copy(
MySQLDialect.colspecs,
{
BIT: _cymysqlBIT,
}
)
@classmethod
def dbapi(cls):
return __import__('cymysql')
def _extract_error_code(self, exception):
return exception.errno
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.OperationalError):
return self._extract_error_code(e) in \
(2006, 2013, 2014, 2045, 2055)
elif isinstance(e, self.dbapi.InterfaceError):
# if underlying connection is closed,
# this is the error you get
return True
else:
return False
dialect = MySQLDialect_cymysql

View File

@@ -1,84 +0,0 @@
# mysql/gaerdbms.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+gaerdbms
:name: Google Cloud SQL
:dbapi: rdbms
:connectstring: mysql+gaerdbms:///<dbname>?instance=<instancename>
:url: https://developers.google.com/appengine/docs/python/cloud-sql/developers-guide
This dialect is based primarily on the :mod:`.mysql.mysqldb` dialect with minimal
changes.
.. versionadded:: 0.7.8
Pooling
-------
Google App Engine connections appear to be randomly recycled,
so the dialect does not pool connections. The :class:`.NullPool`
implementation is installed within the :class:`.Engine` by
default.
"""
import os
from .mysqldb import MySQLDialect_mysqldb
from ...pool import NullPool
import re
def _is_dev_environment():
return os.environ.get('SERVER_SOFTWARE', '').startswith('Development/')
class MySQLDialect_gaerdbms(MySQLDialect_mysqldb):
@classmethod
def dbapi(cls):
# from django:
# http://code.google.com/p/googleappengine/source/
# browse/trunk/python/google/storage/speckle/
# python/django/backend/base.py#118
# see also [ticket:2649]
# see also http://stackoverflow.com/q/14224679/34549
from google.appengine.api import apiproxy_stub_map
if _is_dev_environment():
from google.appengine.api import rdbms_mysqldb
return rdbms_mysqldb
elif apiproxy_stub_map.apiproxy.GetStub('rdbms'):
from google.storage.speckle.python.api import rdbms_apiproxy
return rdbms_apiproxy
else:
from google.storage.speckle.python.api import rdbms_googleapi
return rdbms_googleapi
@classmethod
def get_pool_class(cls, url):
# Cloud SQL connections die at any moment
return NullPool
def create_connect_args(self, url):
opts = url.translate_connect_args()
if not _is_dev_environment():
# 'dsn' and 'instance' are because we are skipping
# the traditional google.api.rdbms wrapper
opts['dsn'] = ''
opts['instance'] = url.query['instance']
return [], opts
def _extract_error_code(self, exception):
match = re.compile(r"^(\d+)L?:|^\((\d+)L?,").match(str(exception))
# The rdbms api will wrap then re-raise some types of errors
# making this regex return no matches.
code = match.group(1) or match.group(2) if match else None
if code:
return int(code)
dialect = MySQLDialect_gaerdbms

View File

@@ -1,128 +0,0 @@
# mysql/mysqlconnector.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+mysqlconnector
:name: MySQL Connector/Python
:dbapi: myconnpy
:connectstring: mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
:url: https://launchpad.net/myconnpy
"""
from .base import (MySQLDialect,
MySQLExecutionContext, MySQLCompiler, MySQLIdentifierPreparer,
BIT)
from ... import util
class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
def get_lastrowid(self):
return self.cursor.lastrowid
class MySQLCompiler_mysqlconnector(MySQLCompiler):
def visit_mod_binary(self, binary, operator, **kw):
return self.process(binary.left, **kw) + " %% " + \
self.process(binary.right, **kw)
def post_process_text(self, text):
return text.replace('%', '%%')
class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace("%", "%%")
class _myconnpyBIT(BIT):
def result_processor(self, dialect, coltype):
"""MySQL-connector already converts mysql bits, so."""
return None
class MySQLDialect_mysqlconnector(MySQLDialect):
driver = 'mysqlconnector'
supports_unicode_statements = True
supports_unicode_binds = True
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
supports_native_decimal = True
default_paramstyle = 'format'
execution_ctx_cls = MySQLExecutionContext_mysqlconnector
statement_compiler = MySQLCompiler_mysqlconnector
preparer = MySQLIdentifierPreparer_mysqlconnector
colspecs = util.update_copy(
MySQLDialect.colspecs,
{
BIT: _myconnpyBIT,
}
)
@classmethod
def dbapi(cls):
from mysql import connector
return connector
def create_connect_args(self, url):
opts = url.translate_connect_args(username='user')
opts.update(url.query)
util.coerce_kw_type(opts, 'buffered', bool)
util.coerce_kw_type(opts, 'raise_on_warnings', bool)
opts.setdefault('buffered', True)
opts.setdefault('raise_on_warnings', True)
# FOUND_ROWS must be set in ClientFlag to enable
# supports_sane_rowcount.
if self.dbapi is not None:
try:
from mysql.connector.constants import ClientFlag
client_flags = opts.get('client_flags', ClientFlag.get_default())
client_flags |= ClientFlag.FOUND_ROWS
opts['client_flags'] = client_flags
except:
pass
return [[], opts]
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = dbapi_con.get_server_version()
return tuple(version)
def _detect_charset(self, connection):
return connection.connection.charset
def _extract_error_code(self, exception):
return exception.errno
def is_disconnect(self, e, connection, cursor):
errnos = (2006, 2013, 2014, 2045, 2055, 2048)
exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError)
if isinstance(e, exceptions):
return e.errno in errnos or \
"MySQL Connection not available." in str(e)
else:
return False
def _compat_fetchall(self, rp, charset=None):
return rp.fetchall()
def _compat_fetchone(self, rp, charset=None):
return rp.fetchone()
dialect = MySQLDialect_mysqlconnector

View File

@@ -1,78 +0,0 @@
# mysql/mysqldb.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+mysqldb
:name: MySQL-Python
:dbapi: mysqldb
:connectstring: mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
:url: http://sourceforge.net/projects/mysql-python
Unicode
-------
MySQLdb will accommodate Python ``unicode`` objects if the
``use_unicode=1`` parameter, or the ``charset`` parameter,
is passed as a connection argument.
Without this setting, many MySQL server installations default to
a ``latin1`` encoding for client connections, which has the effect
of all data being converted into ``latin1``, even if you have ``utf8``
or another character set configured on your tables
and columns. With versions 4.1 and higher, you can change the connection
character set either through server configuration or by including the
``charset`` parameter. The ``charset``
parameter as received by MySQL-Python also has the side-effect of
enabling ``use_unicode=1``::
# set client encoding to utf8; all strings come back as unicode
create_engine('mysql+mysqldb:///mydb?charset=utf8')
Manually configuring ``use_unicode=0`` will cause MySQL-python to
return encoded strings::
# set client encoding to utf8; all strings come back as utf8 str
create_engine('mysql+mysqldb:///mydb?charset=utf8&use_unicode=0')
Known Issues
-------------
MySQL-python version 1.2.2 has a serious memory leak related
to unicode conversion, a feature which is disabled via ``use_unicode=0``.
It is strongly advised to use the latest version of MySQL-Python.
"""
from .base import (MySQLDialect, MySQLExecutionContext,
MySQLCompiler, MySQLIdentifierPreparer)
from ...connectors.mysqldb import (
MySQLDBExecutionContext,
MySQLDBCompiler,
MySQLDBIdentifierPreparer,
MySQLDBConnector
)
class MySQLExecutionContext_mysqldb(MySQLDBExecutionContext, MySQLExecutionContext):
pass
class MySQLCompiler_mysqldb(MySQLDBCompiler, MySQLCompiler):
pass
class MySQLIdentifierPreparer_mysqldb(MySQLDBIdentifierPreparer, MySQLIdentifierPreparer):
pass
class MySQLDialect_mysqldb(MySQLDBConnector, MySQLDialect):
execution_ctx_cls = MySQLExecutionContext_mysqldb
statement_compiler = MySQLCompiler_mysqldb
preparer = MySQLIdentifierPreparer_mysqldb
dialect = MySQLDialect_mysqldb

View File

@@ -1,261 +0,0 @@
# mysql/oursql.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+oursql
:name: OurSQL
:dbapi: oursql
:connectstring: mysql+oursql://<user>:<password>@<host>[:<port>]/<dbname>
:url: http://packages.python.org/oursql/
Unicode
-------
oursql defaults to using ``utf8`` as the connection charset, but other
encodings may be used instead. Like the MySQL-Python driver, unicode support
can be completely disabled::
# oursql sets the connection charset to utf8 automatically; all strings come
# back as utf8 str
create_engine('mysql+oursql:///mydb?use_unicode=0')
To not automatically use ``utf8`` and instead use whatever the connection
defaults to, there is a separate parameter::
# use the default connection charset; all strings come back as unicode
create_engine('mysql+oursql:///mydb?default_charset=1')
# use latin1 as the connection charset; all strings come back as unicode
create_engine('mysql+oursql:///mydb?charset=latin1')
"""
import re
from .base import (BIT, MySQLDialect, MySQLExecutionContext)
from ... import types as sqltypes, util
class _oursqlBIT(BIT):
def result_processor(self, dialect, coltype):
"""oursql already converts mysql bits, so."""
return None
class MySQLExecutionContext_oursql(MySQLExecutionContext):
@property
def plain_query(self):
return self.execution_options.get('_oursql_plain_query', False)
class MySQLDialect_oursql(MySQLDialect):
driver = 'oursql'
if util.py2k:
supports_unicode_binds = True
supports_unicode_statements = True
supports_native_decimal = True
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
execution_ctx_cls = MySQLExecutionContext_oursql
colspecs = util.update_copy(
MySQLDialect.colspecs,
{
sqltypes.Time: sqltypes.Time,
BIT: _oursqlBIT,
}
)
@classmethod
def dbapi(cls):
return __import__('oursql')
def do_execute(self, cursor, statement, parameters, context=None):
"""Provide an implementation of *cursor.execute(statement, parameters)*."""
if context and context.plain_query:
cursor.execute(statement, plain_query=True)
else:
cursor.execute(statement, parameters)
def do_begin(self, connection):
connection.cursor().execute('BEGIN', plain_query=True)
def _xa_query(self, connection, query, xid):
if util.py2k:
arg = connection.connection._escape_string(xid)
else:
charset = self._connection_charset
arg = connection.connection._escape_string(xid.encode(charset)).decode(charset)
arg = "'%s'" % arg
connection.execution_options(_oursql_plain_query=True).execute(query % arg)
# Because mysql is bad, these methods have to be
# reimplemented to use _PlainQuery. Basically, some queries
# refuse to return any data if they're run through
# the parameterized query API, or refuse to be parameterized
# in the first place.
def do_begin_twophase(self, connection, xid):
self._xa_query(connection, 'XA BEGIN %s', xid)
def do_prepare_twophase(self, connection, xid):
self._xa_query(connection, 'XA END %s', xid)
self._xa_query(connection, 'XA PREPARE %s', xid)
def do_rollback_twophase(self, connection, xid, is_prepared=True,
recover=False):
if not is_prepared:
self._xa_query(connection, 'XA END %s', xid)
self._xa_query(connection, 'XA ROLLBACK %s', xid)
def do_commit_twophase(self, connection, xid, is_prepared=True,
recover=False):
if not is_prepared:
self.do_prepare_twophase(connection, xid)
self._xa_query(connection, 'XA COMMIT %s', xid)
# Q: why didn't we need all these "plain_query" overrides earlier ?
# am i on a newer/older version of OurSQL ?
def has_table(self, connection, table_name, schema=None):
return MySQLDialect.has_table(
self,
connection.connect().execution_options(_oursql_plain_query=True),
table_name,
schema
)
def get_table_options(self, connection, table_name, schema=None, **kw):
return MySQLDialect.get_table_options(
self,
connection.connect().execution_options(_oursql_plain_query=True),
table_name,
schema=schema,
**kw
)
def get_columns(self, connection, table_name, schema=None, **kw):
return MySQLDialect.get_columns(
self,
connection.connect().execution_options(_oursql_plain_query=True),
table_name,
schema=schema,
**kw
)
def get_view_names(self, connection, schema=None, **kw):
return MySQLDialect.get_view_names(
self,
connection.connect().execution_options(_oursql_plain_query=True),
schema=schema,
**kw
)
def get_table_names(self, connection, schema=None, **kw):
return MySQLDialect.get_table_names(
self,
connection.connect().execution_options(_oursql_plain_query=True),
schema
)
def get_schema_names(self, connection, **kw):
return MySQLDialect.get_schema_names(
self,
connection.connect().execution_options(_oursql_plain_query=True),
**kw
)
def initialize(self, connection):
return MySQLDialect.initialize(
self,
connection.execution_options(_oursql_plain_query=True)
)
def _show_create_table(self, connection, table, charset=None,
full_name=None):
return MySQLDialect._show_create_table(
self,
connection.contextual_connect(close_with_result=True).
execution_options(_oursql_plain_query=True),
table, charset, full_name
)
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.ProgrammingError):
return e.errno is None and 'cursor' not in e.args[1] and e.args[1].endswith('closed')
else:
return e.errno in (2006, 2013, 2014, 2045, 2055)
def create_connect_args(self, url):
opts = url.translate_connect_args(database='db', username='user',
password='passwd')
opts.update(url.query)
util.coerce_kw_type(opts, 'port', int)
util.coerce_kw_type(opts, 'compress', bool)
util.coerce_kw_type(opts, 'autoping', bool)
util.coerce_kw_type(opts, 'raise_on_warnings', bool)
util.coerce_kw_type(opts, 'default_charset', bool)
if opts.pop('default_charset', False):
opts['charset'] = None
else:
util.coerce_kw_type(opts, 'charset', str)
opts['use_unicode'] = opts.get('use_unicode', True)
util.coerce_kw_type(opts, 'use_unicode', bool)
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
# supports_sane_rowcount.
opts.setdefault('found_rows', True)
ssl = {}
for key in ['ssl_ca', 'ssl_key', 'ssl_cert',
'ssl_capath', 'ssl_cipher']:
if key in opts:
ssl[key[4:]] = opts[key]
util.coerce_kw_type(ssl, key[4:], str)
del opts[key]
if ssl:
opts['ssl'] = ssl
return [[], opts]
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
for n in r.split(dbapi_con.server_info):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)
def _extract_error_code(self, exception):
return exception.errno
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
return connection.connection.charset
def _compat_fetchall(self, rp, charset=None):
"""oursql isn't super-broken like MySQLdb, yaaay."""
return rp.fetchall()
def _compat_fetchone(self, rp, charset=None):
"""oursql isn't super-broken like MySQLdb, yaaay."""
return rp.fetchone()
def _compat_first(self, rp, charset=None):
return rp.first()
dialect = MySQLDialect_oursql

View File

@@ -1,44 +0,0 @@
# mysql/pymysql.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+pymysql
:name: PyMySQL
:dbapi: pymysql
:connectstring: mysql+pymysql://<username>:<password>@<host>/<dbname>[?<options>]
:url: http://code.google.com/p/pymysql/
MySQL-Python Compatibility
--------------------------
The pymysql DBAPI is a pure Python port of the MySQL-python (MySQLdb) driver,
and targets 100% compatibility. Most behavioral notes for MySQL-python apply to
the pymysql driver as well.
"""
from .mysqldb import MySQLDialect_mysqldb
from ...util import py3k
class MySQLDialect_pymysql(MySQLDialect_mysqldb):
driver = 'pymysql'
description_encoding = None
if py3k:
supports_unicode_statements = True
@classmethod
def dbapi(cls):
return __import__('pymysql')
if py3k:
def _extract_error_code(self, exception):
if isinstance(exception.args[0], Exception):
exception = exception.args[0]
return exception.args[0]
dialect = MySQLDialect_pymysql

View File

@@ -1,80 +0,0 @@
# mysql/pyodbc.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+pyodbc
:name: PyODBC
:dbapi: pyodbc
:connectstring: mysql+pyodbc://<username>:<password>@<dsnname>
:url: http://pypi.python.org/pypi/pyodbc/
Limitations
-----------
The mysql-pyodbc dialect is subject to unresolved character encoding issues
which exist within the current ODBC drivers available.
(see http://code.google.com/p/pyodbc/issues/detail?id=25). Consider usage
of OurSQL, MySQLdb, or MySQL-connector/Python.
"""
from .base import MySQLDialect, MySQLExecutionContext
from ...connectors.pyodbc import PyODBCConnector
from ... import util
import re
class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
def get_lastrowid(self):
cursor = self.create_cursor()
cursor.execute("SELECT LAST_INSERT_ID()")
lastrowid = cursor.fetchone()[0]
cursor.close()
return lastrowid
class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
supports_unicode_statements = False
execution_ctx_cls = MySQLExecutionContext_pyodbc
pyodbc_driver_name = "MySQL"
def __init__(self, **kw):
# deal with http://code.google.com/p/pyodbc/issues/detail?id=25
kw.setdefault('convert_unicode', True)
super(MySQLDialect_pyodbc, self).__init__(**kw)
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
# Prefer 'character_set_results' for the current connection over the
# value in the driver. SET NAMES or individual variable SETs will
# change the charset without updating the driver's view of the world.
#
# If it's decided that issuing that sort of SQL leaves you SOL, then
# this can prefer the driver value.
rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])
for key in ('character_set_connection', 'character_set'):
if opts.get(key, None):
return opts[key]
util.warn("Could not detect the connection character set. Assuming latin1.")
return 'latin1'
def _extract_error_code(self, exception):
m = re.compile(r"\((\d+)\)").search(str(exception.args))
c = m.group(1)
if c:
return int(c)
else:
return None
dialect = MySQLDialect_pyodbc

View File

@@ -1,111 +0,0 @@
# mysql/zxjdbc.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+zxjdbc
:name: zxjdbc for Jython
:dbapi: zxjdbc
:connectstring: mysql+zxjdbc://<user>:<password>@<hostname>[:<port>]/<database>
:driverurl: http://dev.mysql.com/downloads/connector/j/
Character Sets
--------------
SQLAlchemy zxjdbc dialects pass unicode straight through to the
zxjdbc/JDBC layer. To allow multiple character sets to be sent from the
MySQL Connector/J JDBC driver, by default SQLAlchemy sets its
``characterEncoding`` connection property to ``UTF-8``. It may be
overriden via a ``create_engine`` URL parameter.
"""
import re
from ... import types as sqltypes, util
from ...connectors.zxJDBC import ZxJDBCConnector
from .base import BIT, MySQLDialect, MySQLExecutionContext
class _ZxJDBCBit(BIT):
def result_processor(self, dialect, coltype):
"""Converts boolean or byte arrays from MySQL Connector/J to longs."""
def process(value):
if value is None:
return value
if isinstance(value, bool):
return int(value)
v = 0
for i in value:
v = v << 8 | (i & 0xff)
value = v
return value
return process
class MySQLExecutionContext_zxjdbc(MySQLExecutionContext):
def get_lastrowid(self):
cursor = self.create_cursor()
cursor.execute("SELECT LAST_INSERT_ID()")
lastrowid = cursor.fetchone()[0]
cursor.close()
return lastrowid
class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect):
jdbc_db_name = 'mysql'
jdbc_driver_name = 'com.mysql.jdbc.Driver'
execution_ctx_cls = MySQLExecutionContext_zxjdbc
colspecs = util.update_copy(
MySQLDialect.colspecs,
{
sqltypes.Time: sqltypes.Time,
BIT: _ZxJDBCBit
}
)
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
# Prefer 'character_set_results' for the current connection over the
# value in the driver. SET NAMES or individual variable SETs will
# change the charset without updating the driver's view of the world.
#
# If it's decided that issuing that sort of SQL leaves you SOL, then
# this can prefer the driver value.
rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
opts = dict((row[0], row[1]) for row in self._compat_fetchall(rs))
for key in ('character_set_connection', 'character_set'):
if opts.get(key, None):
return opts[key]
util.warn("Could not detect the connection character set. Assuming latin1.")
return 'latin1'
def _driver_kwargs(self):
"""return kw arg dict to be sent to connect()."""
return dict(characterEncoding='UTF-8', yearIsDateType='false')
def _extract_error_code(self, exception):
# e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist
# [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' ()
m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.args))
c = m.group(1)
if c:
return int(c)
def _get_server_version_info(self, connection):
dbapi_con = connection.connection
version = []
r = re.compile('[.\-]')
for n in r.split(dbapi_con.dbversion):
try:
version.append(int(n))
except ValueError:
version.append(n)
return tuple(version)
dialect = MySQLDialect_zxjdbc

View File

@@ -1,23 +0,0 @@
# oracle/__init__.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from sqlalchemy.dialects.oracle import base, cx_oracle, zxjdbc
base.dialect = cx_oracle.dialect
from sqlalchemy.dialects.oracle.base import \
VARCHAR, NVARCHAR, CHAR, DATE, DATETIME, NUMBER,\
BLOB, BFILE, CLOB, NCLOB, TIMESTAMP, RAW,\
FLOAT, DOUBLE_PRECISION, LONG, dialect, INTERVAL,\
VARCHAR2, NVARCHAR2, ROWID, dialect
__all__ = (
'VARCHAR', 'NVARCHAR', 'CHAR', 'DATE', 'DATETIME', 'NUMBER',
'BLOB', 'BFILE', 'CLOB', 'NCLOB', 'TIMESTAMP', 'RAW',
'FLOAT', 'DOUBLE_PRECISION', 'LONG', 'dialect', 'INTERVAL',
'VARCHAR2', 'NVARCHAR2', 'ROWID'
)

Some files were not shown because too many files have changed in this diff Show More