Update SQLAlchemy

This commit is contained in:
Ruud
2013-06-14 11:00:06 +02:00
parent 267ecfacab
commit 4aa6700ceb
124 changed files with 6500 additions and 5207 deletions
+1 -1
View File
@@ -1,5 +1,5 @@
# sql/__init__.py
# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file>
# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+249 -174
View File
@@ -1,5 +1,5 @@
# sql/compiler.py
# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file>
# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -17,7 +17,7 @@ strings
:class:`~sqlalchemy.sql.compiler.GenericTypeCompiler` - renders
type specification strings.
To generate user-defined SQL strings, see
To generate user-defined SQL strings, see
:module:`~sqlalchemy.ext.compiler`.
"""
@@ -29,6 +29,7 @@ from sqlalchemy.sql import operators, functions, util as sql_util, \
visitors
from sqlalchemy.sql import expression as sql
import decimal
import itertools
RESERVED_WORDS = set([
'all', 'analyse', 'analyze', 'and', 'any', 'array',
@@ -59,7 +60,7 @@ BIND_TEMPLATES = {
'pyformat':"%%(%(name)s)s",
'qmark':"?",
'format':"%%s",
'numeric':":%(position)s",
'numeric':":[_POSITION]",
'named':":%(name)s"
}
@@ -214,7 +215,7 @@ class SQLCompiler(engine.Compiled):
driver/DB enforces this
"""
def __init__(self, dialect, statement, column_keys=None,
def __init__(self, dialect, statement, column_keys=None,
inline=False, **kwargs):
"""Construct a new ``DefaultCompiler`` object.
@@ -252,16 +253,14 @@ class SQLCompiler(engine.Compiled):
# column targeting
self.result_map = {}
# collect CTEs to tack on top of a SELECT
self.ctes = util.OrderedDict()
self.ctes_recursive = False
# true if the paramstyle is positional
self.positional = dialect.positional
if self.positional:
self.positiontup = []
self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
self.ctes = None
# an IdentifierPreparer that formats the quoting of identifiers
self.preparer = dialect.identifier_preparer
self.label_length = dialect.label_length \
@@ -276,7 +275,29 @@ class SQLCompiler(engine.Compiled):
self.truncated_names = {}
engine.Compiled.__init__(self, dialect, statement, **kwargs)
if self.positional and dialect.paramstyle == 'numeric':
self._apply_numbered_params()
@util.memoized_instancemethod
def _init_cte_state(self):
"""Initialize collections related to CTEs only if
a CTE is located, to save on the overhead of
these collections otherwise.
"""
# collect CTEs to tack on top of a SELECT
self.ctes = util.OrderedDict()
self.ctes_by_name = {}
self.ctes_recursive = False
if self.positional:
self.cte_positional = []
def _apply_numbered_params(self):
poscount = itertools.count(1)
self.string = re.sub(
r'\[_POSITION\]',
lambda m:str(util.next(poscount)),
self.string)
@util.memoized_property
def _bind_processors(self):
@@ -309,11 +330,11 @@ class SQLCompiler(engine.Compiled):
if _group_number:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r, "
"in parameter group %d" %
"in parameter group %d" %
(bindparam.key, _group_number))
else:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r"
"A value is required for bind parameter %r"
% bindparam.key)
else:
pd[name] = bindparam.effective_value
@@ -325,18 +346,18 @@ class SQLCompiler(engine.Compiled):
if _group_number:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r, "
"in parameter group %d" %
"in parameter group %d" %
(bindparam.key, _group_number))
else:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r"
"A value is required for bind parameter %r"
% bindparam.key)
pd[self.bind_names[bindparam]] = bindparam.effective_value
return pd
@property
def params(self):
"""Return the bind param dictionary embedded into this
"""Return the bind param dictionary embedded into this
compiled object, for those values that are present."""
return self.construct_params(_check=False)
@@ -352,8 +373,8 @@ class SQLCompiler(engine.Compiled):
def visit_grouping(self, grouping, asfrom=False, **kwargs):
return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
def visit_label(self, label, result_map=None,
within_label_clause=False,
def visit_label(self, label, result_map=None,
within_label_clause=False,
within_columns_clause=False, **kw):
# only render labels within the columns clause
# or ORDER BY clause of a select. dialect-specific compilers
@@ -366,20 +387,20 @@ class SQLCompiler(engine.Compiled):
if result_map is not None:
result_map[labelname.lower()] = (
label.name,
(label, label.element, labelname, ) +
label.name,
(label, label.element, labelname, ) +
label._alt_names,
label.type)
return label.element._compiler_dispatch(self,
return label.element._compiler_dispatch(self,
within_columns_clause=True,
within_label_clause=True,
within_label_clause=True,
**kw) + \
OPERATORS[operators.as_] + \
self.preparer.format_label(label, labelname)
else:
return label.element._compiler_dispatch(self,
within_columns_clause=False,
return label.element._compiler_dispatch(self,
within_columns_clause=False,
**kw)
def visit_column(self, column, result_map=None, **kwargs):
@@ -393,8 +414,8 @@ class SQLCompiler(engine.Compiled):
name = self._truncated_identifier("colident", name)
if result_map is not None:
result_map[name.lower()] = (orig_name,
(column, name, column.key),
result_map[name.lower()] = (orig_name,
(column, name, column.key),
column.type)
if is_literal:
@@ -408,7 +429,7 @@ class SQLCompiler(engine.Compiled):
else:
if table.schema:
schema_prefix = self.preparer.quote_schema(
table.schema,
table.schema,
table.quote_schema) + '.'
else:
schema_prefix = ''
@@ -448,7 +469,7 @@ class SQLCompiler(engine.Compiled):
if name in textclause.bindparams:
return self.process(textclause.bindparams[name])
else:
return self.bindparam_string(name)
return self.bindparam_string(name, **kwargs)
# un-escape any \:params
return BIND_PARAMS_ESC.sub(lambda m: m.group(1),
@@ -472,8 +493,8 @@ class SQLCompiler(engine.Compiled):
else:
sep = OPERATORS[clauselist.operator]
return sep.join(
s for s in
(c._compiler_dispatch(self, **kwargs)
s for s in
(c._compiler_dispatch(self, **kwargs)
for c in clauselist.clauses)
if s)
@@ -499,21 +520,21 @@ class SQLCompiler(engine.Compiled):
cast.typeclause._compiler_dispatch(self, **kwargs))
def visit_over(self, over, **kwargs):
x ="%s OVER (" % over.func._compiler_dispatch(self, **kwargs)
if over.partition_by is not None:
x += "PARTITION BY %s" % \
over.partition_by._compiler_dispatch(self, **kwargs)
if over.order_by is not None:
x += " "
if over.order_by is not None:
x += "ORDER BY %s" % \
over.order_by._compiler_dispatch(self, **kwargs)
x += ")"
return x
return "%s OVER (%s)" % (
over.func._compiler_dispatch(self, **kwargs),
' '.join(
'%s BY %s' % (word, clause._compiler_dispatch(self, **kwargs))
for word, clause in (
('PARTITION', over.partition_by),
('ORDER', over.order_by)
)
if clause is not None and len(clause)
)
)
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
return "EXTRACT(%s FROM %s)" % (field,
return "EXTRACT(%s FROM %s)" % (field,
extract.expr._compiler_dispatch(self, **kwargs))
def visit_function(self, func, result_map=None, **kwargs):
@@ -526,7 +547,7 @@ class SQLCompiler(engine.Compiled):
else:
name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s")
return ".".join(list(func.packagenames) + [name]) % \
{'expr':self.function_argspec(func, **kwargs)}
{'expr': self.function_argspec(func, **kwargs)}
def visit_next_value_func(self, next_value, **kw):
return self.visit_sequence(next_value.sequence)
@@ -539,16 +560,17 @@ class SQLCompiler(engine.Compiled):
def function_argspec(self, func, **kwargs):
return func.clause_expr._compiler_dispatch(self, **kwargs)
def visit_compound_select(self, cs, asfrom=False,
parens=True, compound_index=1, **kwargs):
def visit_compound_select(self, cs, asfrom=False,
parens=True, compound_index=0, **kwargs):
entry = self.stack and self.stack[-1] or {}
self.stack.append({'from':entry.get('from', None), 'iswrapper':True})
self.stack.append({'from': entry.get('from', None),
'iswrapper': not entry})
keyword = self.compound_keywords.get(cs.keyword)
text = (" " + keyword + " ").join(
(c._compiler_dispatch(self,
asfrom=asfrom, parens=False,
(c._compiler_dispatch(self,
asfrom=asfrom, parens=False,
compound_index=i, **kwargs)
for i, c in enumerate(cs.selects))
)
@@ -562,6 +584,10 @@ class SQLCompiler(engine.Compiled):
text += (cs._limit is not None or cs._offset is not None) and \
self.limit_clause(cs) or ""
if self.ctes and \
compound_index == 0 and not entry:
text = self._render_cte_clause() + text
self.stack.pop(-1)
if asfrom and parens:
return "(" + text + ")"
@@ -585,8 +611,8 @@ class SQLCompiler(engine.Compiled):
return self._operator_dispatch(binary.operator,
binary,
lambda opstr: binary.left._compiler_dispatch(self, **kw) +
opstr +
lambda opstr: binary.left._compiler_dispatch(self, **kw) +
opstr +
binary.right._compiler_dispatch(
self, **kw),
**kw
@@ -595,36 +621,36 @@ class SQLCompiler(engine.Compiled):
def visit_like_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
return '%s LIKE %s' % (
binary.left._compiler_dispatch(self, **kw),
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
+ (escape and
+ (escape and
(' ESCAPE ' + self.render_literal_value(escape, None))
or '')
def visit_notlike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
return '%s NOT LIKE %s' % (
binary.left._compiler_dispatch(self, **kw),
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
+ (escape and
+ (escape and
(' ESCAPE ' + self.render_literal_value(escape, None))
or '')
def visit_ilike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
return 'lower(%s) LIKE lower(%s)' % (
binary.left._compiler_dispatch(self, **kw),
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
+ (escape and
+ (escape and
(' ESCAPE ' + self.render_literal_value(escape, None))
or '')
def visit_notilike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
return 'lower(%s) NOT LIKE lower(%s)' % (
binary.left._compiler_dispatch(self, **kw),
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
+ (escape and
+ (escape and
(' ESCAPE ' + self.render_literal_value(escape, None))
or '')
@@ -668,7 +694,7 @@ class SQLCompiler(engine.Compiled):
"bindparam() name '%s' is reserved "
"for automatic usage in the VALUES or SET "
"clause of this "
"insert/update statement. Please use a "
"insert/update statement. Please use a "
"name other than column name when using bindparam() "
"with insert() or update() (for example, 'b_%s')."
% (bindparam.key, bindparam.key)
@@ -676,7 +702,7 @@ class SQLCompiler(engine.Compiled):
self.binds[bindparam.key] = self.binds[name] = bindparam
return self.bindparam_string(name)
return self.bindparam_string(name, **kwargs)
def render_literal_bindparam(self, bindparam, **kw):
value = bindparam.value
@@ -688,7 +714,7 @@ class SQLCompiler(engine.Compiled):
def render_literal_value(self, value, type_):
"""Render the value of a bind parameter as a quoted literal.
This is used for statement sections that do not accept bind paramters
This is used for statement sections that do not accept bind parameters
on the target driver/database.
This should be implemented by subclasses using the quoting services
@@ -746,20 +772,45 @@ class SQLCompiler(engine.Compiled):
self.anon_map[derived] = anonymous_counter + 1
return derived + "_" + str(anonymous_counter)
def bindparam_string(self, name):
def bindparam_string(self, name, positional_names=None, **kw):
if self.positional:
self.positiontup.append(name)
return self.bindtemplate % {
'name':name, 'position':len(self.positiontup)}
else:
return self.bindtemplate % {'name':name}
if positional_names is not None:
positional_names.append(name)
else:
self.positiontup.append(name)
return self.bindtemplate % {'name':name}
def visit_cte(self, cte, asfrom=False, ashint=False,
fromhints=None,
**kwargs):
self._init_cte_state()
if self.positional:
kwargs['positional_names'] = self.cte_positional
def visit_cte(self, cte, asfrom=False, ashint=False,
fromhints=None, **kwargs):
if isinstance(cte.name, sql._truncated_label):
cte_name = self._truncated_identifier("alias", cte.name)
else:
cte_name = cte.name
if cte_name in self.ctes_by_name:
existing_cte = self.ctes_by_name[cte_name]
# we've generated a same-named CTE that we are enclosed in,
# or this is the same CTE. just return the name.
if cte in existing_cte._restates or cte is existing_cte:
return cte_name
elif existing_cte in cte._restates:
# we've generated a same-named CTE that is
# enclosed in us - we take precedence, so
# discard the text for the "inner".
del self.ctes[existing_cte]
else:
raise exc.CompileError(
"Multiple, unrelated CTEs found with "
"the same name: %r" %
cte_name)
self.ctes_by_name[cte_name] = cte
if cte.cte_alias:
if isinstance(cte.cte_alias, sql._truncated_label):
cte_alias = self._truncated_identifier("alias", cte.cte_alias)
@@ -776,10 +827,13 @@ class SQLCompiler(engine.Compiled):
col_source = cte.original.selects[0]
else:
assert False
recur_cols = [c.key for c in util.unique_list(col_source.inner_columns)
recur_cols = [c for c in
util.unique_list(col_source.inner_columns)
if c is not None]
text += "(%s)" % (", ".join(recur_cols))
text += "(%s)" % (", ".join(
self.preparer.format_column(ident)
for ident in recur_cols))
text += " AS \n" + \
cte.original._compiler_dispatch(
self, asfrom=True, **kwargs
@@ -793,7 +847,7 @@ class SQLCompiler(engine.Compiled):
return self.preparer.format_alias(cte, cte_name)
return text
def visit_alias(self, alias, asfrom=False, ashint=False,
def visit_alias(self, alias, asfrom=False, ashint=False,
fromhints=None, **kwargs):
if asfrom or ashint:
if isinstance(alias.name, sql._truncated_label):
@@ -804,7 +858,7 @@ class SQLCompiler(engine.Compiled):
if ashint:
return self.preparer.format_alias(alias, alias_name)
elif asfrom:
ret = alias.original._compiler_dispatch(self,
ret = alias.original._compiler_dispatch(self,
asfrom=True, **kwargs) + \
" AS " + \
self.preparer.format_alias(alias, alias_name)
@@ -828,8 +882,8 @@ class SQLCompiler(engine.Compiled):
select.use_labels and \
column._label:
return _CompileLabel(
column,
column._label,
column,
column._label,
alt_names=(column._key_label, )
)
@@ -839,9 +893,9 @@ class SQLCompiler(engine.Compiled):
not column.is_literal and \
column.table is not None and \
not isinstance(column.table, sql.Select):
return _CompileLabel(column, sql._as_truncated(column.name),
return _CompileLabel(column, sql._as_truncated(column.name),
alt_names=(column.key,))
elif not isinstance(column,
elif not isinstance(column,
(sql._UnaryExpression, sql._TextClause)) \
and (not hasattr(column, 'name') or \
isinstance(column, sql.Function)):
@@ -858,9 +912,10 @@ class SQLCompiler(engine.Compiled):
def get_crud_hint_text(self, table, text):
return None
def visit_select(self, select, asfrom=False, parens=True,
iswrapper=False, fromhints=None,
compound_index=1, **kwargs):
def visit_select(self, select, asfrom=False, parens=True,
iswrapper=False, fromhints=None,
compound_index=0,
positional_names=None, **kwargs):
entry = self.stack and self.stack[-1] or {}
@@ -875,13 +930,18 @@ class SQLCompiler(engine.Compiled):
# to outermost if existingfroms: correlate_froms =
# correlate_froms.union(existingfroms)
self.stack.append({'from': correlate_froms, 'iswrapper'
: iswrapper})
populate_result_map = compound_index == 0 and (
not entry or \
entry.get('iswrapper', False)
)
if compound_index==1 and not entry or entry.get('iswrapper', False):
column_clause_args = {'result_map':self.result_map}
self.stack.append({'from': correlate_froms, 'iswrapper': iswrapper})
if populate_result_map:
column_clause_args = {'result_map': self.result_map,
'positional_names': positional_names}
else:
column_clause_args = {}
column_clause_args = {'positional_names': positional_names}
# the actual list of columns to print in the SELECT column list.
inner_columns = [
@@ -889,7 +949,7 @@ class SQLCompiler(engine.Compiled):
self.label_select_column(select, co, asfrom=asfrom).\
_compiler_dispatch(self,
within_columns_clause=True,
**column_clause_args)
**column_clause_args)
for co in util.unique_list(select.inner_columns)
]
if c is not None
@@ -902,9 +962,9 @@ class SQLCompiler(engine.Compiled):
(from_, hinttext % {
'name':from_._compiler_dispatch(
self, ashint=True)
})
for (from_, dialect), hinttext in
select._hints.iteritems()
})
for (from_, dialect), hinttext in
select._hints.iteritems()
if dialect in ('*', self.dialect.name)
])
hint_text = self.get_select_hint_text(byfrom)
@@ -913,7 +973,7 @@ class SQLCompiler(engine.Compiled):
if select._prefixes:
text += " ".join(
x._compiler_dispatch(self, **kwargs)
x._compiler_dispatch(self, **kwargs)
for x in select._prefixes) + " "
text += self.get_select_precolumns(select)
text += ', '.join(inner_columns)
@@ -922,13 +982,13 @@ class SQLCompiler(engine.Compiled):
text += " \nFROM "
if select._hints:
text += ', '.join([f._compiler_dispatch(self,
asfrom=True, fromhints=byfrom,
**kwargs)
text += ', '.join([f._compiler_dispatch(self,
asfrom=True, fromhints=byfrom,
**kwargs)
for f in froms])
else:
text += ', '.join([f._compiler_dispatch(self,
asfrom=True, **kwargs)
text += ', '.join([f._compiler_dispatch(self,
asfrom=True, **kwargs)
for f in froms])
else:
text += self.default_from()
@@ -957,13 +1017,8 @@ class SQLCompiler(engine.Compiled):
text += self.for_update_clause(select)
if self.ctes and \
compound_index==1 and not entry:
cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
cte_text += ", \n".join(
[txt for txt in self.ctes.values()]
)
cte_text += "\n "
text = cte_text + text
compound_index == 0 and not entry:
text = self._render_cte_clause() + text
self.stack.pop(-1)
@@ -972,6 +1027,16 @@ class SQLCompiler(engine.Compiled):
else:
return text
def _render_cte_clause(self):
if self.positional:
self.positiontup = self.cte_positional + self.positiontup
cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
cte_text += ", \n".join(
[txt for txt in self.ctes.values()]
)
cte_text += "\n "
return cte_text
def get_cte_preamble(self, recursive):
if recursive:
return "WITH RECURSIVE"
@@ -1008,7 +1073,7 @@ class SQLCompiler(engine.Compiled):
text += " OFFSET " + self.process(sql.literal(select._offset))
return text
def visit_table(self, table, asfrom=False, ashint=False,
def visit_table(self, table, asfrom=False, ashint=False,
fromhints=None, **kwargs):
if asfrom or ashint:
if getattr(table, "schema", None):
@@ -1028,10 +1093,10 @@ class SQLCompiler(engine.Compiled):
def visit_join(self, join, asfrom=False, **kwargs):
return (
join.left._compiler_dispatch(self, asfrom=True, **kwargs) +
(join.isouter and " LEFT OUTER JOIN " or " JOIN ") +
join.right._compiler_dispatch(self, asfrom=True, **kwargs) +
" ON " +
join.left._compiler_dispatch(self, asfrom=True, **kwargs) +
(join.isouter and " LEFT OUTER JOIN " or " JOIN ") +
join.right._compiler_dispatch(self, asfrom=True, **kwargs) +
" ON " +
join.onclause._compiler_dispatch(self, **kwargs)
)
@@ -1043,7 +1108,7 @@ class SQLCompiler(engine.Compiled):
not self.dialect.supports_default_values and \
not self.dialect.supports_empty_insert:
raise exc.CompileError("The version of %s you are using does "
"not support empty inserts." %
"not support empty inserts." %
self.dialect.name)
preparer = self.preparer
@@ -1061,13 +1126,13 @@ class SQLCompiler(engine.Compiled):
if insert_stmt._hints:
dialect_hints = dict([
(table, hint_text)
for (table, dialect), hint_text in
for (table, dialect), hint_text in
insert_stmt._hints.items()
if dialect in ('*', self.dialect.name)
])
if insert_stmt.table in dialect_hints:
text += " " + self.get_crud_hint_text(
insert_stmt.table,
insert_stmt.table,
dialect_hints[insert_stmt.table]
)
@@ -1098,7 +1163,7 @@ class SQLCompiler(engine.Compiled):
"""Provide a hook for MySQL to add LIMIT to the UPDATE"""
return None
def update_tables_clause(self, update_stmt, from_table,
def update_tables_clause(self, update_stmt, from_table,
extra_froms, **kw):
"""Provide a hook to override the initial table clause
in an UPDATE statement.
@@ -1108,19 +1173,19 @@ class SQLCompiler(engine.Compiled):
"""
return self.preparer.format_table(from_table)
def update_from_clause(self, update_stmt,
from_table, extra_froms,
def update_from_clause(self, update_stmt,
from_table, extra_froms,
from_hints,
**kw):
"""Provide a hook to override the generation of an
"""Provide a hook to override the generation of an
UPDATE..FROM clause.
MySQL overrides this.
MySQL and MSSQL override this.
"""
return "FROM " + ', '.join(
t._compiler_dispatch(self, asfrom=True,
fromhints=from_hints, **kw)
t._compiler_dispatch(self, asfrom=True,
fromhints=from_hints, **kw)
for t in extra_froms)
def visit_update(self, update_stmt, **kw):
@@ -1133,20 +1198,20 @@ class SQLCompiler(engine.Compiled):
colparams = self._get_colparams(update_stmt, extra_froms)
text = "UPDATE " + self.update_tables_clause(
update_stmt,
update_stmt.table,
update_stmt,
update_stmt.table,
extra_froms, **kw)
if update_stmt._hints:
dialect_hints = dict([
(table, hint_text)
for (table, dialect), hint_text in
for (table, dialect), hint_text in
update_stmt._hints.items()
if dialect in ('*', self.dialect.name)
])
if update_stmt.table in dialect_hints:
text += " " + self.get_crud_hint_text(
update_stmt.table,
update_stmt.table,
dialect_hints[update_stmt.table]
)
else:
@@ -1155,12 +1220,12 @@ class SQLCompiler(engine.Compiled):
text += ' SET '
if extra_froms and self.render_table_with_column_in_update_from:
text += ', '.join(
self.visit_column(c[0]) +
self.visit_column(c[0]) +
'=' + c[1] for c in colparams
)
else:
text += ', '.join(
self.preparer.quote(c[0].name, c[0].quote) +
self.preparer.quote(c[0].name, c[0].quote) +
'=' + c[1] for c in colparams
)
@@ -1172,9 +1237,9 @@ class SQLCompiler(engine.Compiled):
if extra_froms:
extra_from_text = self.update_from_clause(
update_stmt,
update_stmt.table,
extra_froms,
update_stmt,
update_stmt.table,
extra_froms,
dialect_hints, **kw)
if extra_from_text:
text += " " + extra_from_text
@@ -1195,7 +1260,7 @@ class SQLCompiler(engine.Compiled):
return text
def _create_crud_bind_param(self, col, value, required=False):
bindparam = sql.bindparam(col.key, value,
bindparam = sql.bindparam(col.key, value,
type_=col.type, required=required)
bindparam._is_crud = True
return bindparam._compiler_dispatch(self)
@@ -1220,8 +1285,8 @@ class SQLCompiler(engine.Compiled):
# compiled params - return binds for all columns
if self.column_keys is None and stmt.parameters is None:
return [
(c, self._create_crud_bind_param(c,
None, required=True))
(c, self._create_crud_bind_param(c,
None, required=True))
for c in stmt.table.columns
]
@@ -1233,8 +1298,8 @@ class SQLCompiler(engine.Compiled):
parameters = {}
else:
parameters = dict((sql._column_as_key(key), required)
for key in self.column_keys
if not stmt.parameters or
for key in self.column_keys
if not stmt.parameters or
key not in stmt.parameters)
if stmt.parameters is not None:
@@ -1255,7 +1320,7 @@ class SQLCompiler(engine.Compiled):
postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
check_columns = {}
# special logic that only occurs for multi-table UPDATE
# special logic that only occurs for multi-table UPDATE
# statements
if extra_tables and stmt.parameters:
assert self.isupdate
@@ -1274,7 +1339,7 @@ class SQLCompiler(engine.Compiled):
value = self.process(value.self_group())
values.append((c, value))
# determine tables which are actually
# to be updated - process onupdate and
# to be updated - process onupdate and
# server_onupdate for these
for t in affected_tables:
for c in t.c:
@@ -1295,7 +1360,7 @@ class SQLCompiler(engine.Compiled):
self.postfetch.append(c)
# iterating through columns at the top to maintain ordering.
# otherwise we might iterate through individual sets of
# otherwise we might iterate through individual sets of
# "defaults", "primary key cols", etc.
for c in stmt.table.columns:
if c.key in parameters and c.key not in check_columns:
@@ -1315,8 +1380,8 @@ class SQLCompiler(engine.Compiled):
if c.primary_key and \
need_pks and \
(
implicit_returning or
not postfetch_lastrowid or
implicit_returning or
not postfetch_lastrowid or
c is not stmt.table._autoincrement_column
):
@@ -1402,7 +1467,7 @@ class SQLCompiler(engine.Compiled):
).difference(check_columns)
if check:
util.warn(
"Unconsumed column names: %s" %
"Unconsumed column names: %s" %
(", ".join(check))
)
@@ -1417,13 +1482,13 @@ class SQLCompiler(engine.Compiled):
if delete_stmt._hints:
dialect_hints = dict([
(table, hint_text)
for (table, dialect), hint_text in
for (table, dialect), hint_text in
delete_stmt._hints.items()
if dialect in ('*', self.dialect.name)
])
if delete_stmt.table in dialect_hints:
text += " " + self.get_crud_hint_text(
delete_stmt.table,
delete_stmt.table,
dialect_hints[delete_stmt.table]
)
else:
@@ -1517,7 +1582,7 @@ class DDLCompiler(engine.Compiled):
text += separator
separator = ", \n"
text += "\t" + self.get_column_specification(
column,
column,
first_pk=column.primary_key and \
not first_pk
)
@@ -1529,16 +1594,16 @@ class DDLCompiler(engine.Compiled):
text += " " + const
except exc.CompileError, ce:
# Py3K
#raise exc.CompileError("(in table '%s', column '%s'): %s"
#raise exc.CompileError("(in table '%s', column '%s'): %s"
# % (
# table.description,
# column.name,
# table.description,
# column.name,
# ce.args[0]
# )) from ce
# Py2K
raise exc.CompileError("(in table '%s', column '%s'): %s"
raise exc.CompileError("(in table '%s', column '%s'): %s"
% (
table.description,
table.description,
column.name,
ce.args[0]
)), None, sys.exc_info()[2]
@@ -1559,17 +1624,17 @@ class DDLCompiler(engine.Compiled):
if table.primary_key:
constraints.append(table.primary_key)
constraints.extend([c for c in table._sorted_constraints
constraints.extend([c for c in table._sorted_constraints
if c is not table.primary_key])
return ", \n\t".join(p for p in
(self.process(constraint)
for constraint in constraints
(self.process(constraint)
for constraint in constraints
if (
constraint._create_rule is None or
constraint._create_rule(self))
and (
not self.dialect.supports_alter or
not self.dialect.supports_alter or
not getattr(constraint, 'use_alter', False)
)) if p is not None
)
@@ -1582,13 +1647,12 @@ class DDLCompiler(engine.Compiled):
max = self.dialect.max_index_name_length or \
self.dialect.max_identifier_length
if len(ident) > max:
return ident[0:max - 8] + \
ident = ident[0:max - 8] + \
"_" + util.md5_hex(ident)[-4:]
else:
return ident
else:
self.dialect.validate_identifier(ident)
return ident
return ident
def visit_create_index(self, create):
index = create.element
@@ -1597,7 +1661,7 @@ class DDLCompiler(engine.Compiled):
if index.unique:
text += "UNIQUE "
text += "INDEX %s ON %s (%s)" \
% (preparer.quote(self._index_identifier(index.name),
% (preparer.quote(self._index_identifier(index.name),
index.quote),
preparer.format_table(index.table),
', '.join(preparer.quote(c.name, c.quote)
@@ -1606,9 +1670,20 @@ class DDLCompiler(engine.Compiled):
def visit_drop_index(self, drop):
index = drop.element
return "\nDROP INDEX " + \
self.preparer.quote(
self._index_identifier(index.name), index.quote)
if index.table is not None and index.table.schema:
schema = index.table.schema
schema_name = self.preparer.quote_schema(schema,
index.table.quote_schema)
else:
schema_name = None
index_name = self.preparer.quote(
self._index_identifier(index.name),
index.quote)
if schema_name:
index_name = schema_name + "." + index_name
return "\nDROP INDEX " + index_name
def visit_add_constraint(self, create):
preparer = self.preparer
@@ -1723,7 +1798,7 @@ class DDLCompiler(engine.Compiled):
text += "CONSTRAINT %s " % \
self.preparer.format_constraint(constraint)
text += "UNIQUE (%s)" % (
', '.join(self.preparer.quote(c.name, c.quote)
', '.join(self.preparer.quote(c.name, c.quote)
for c in constraint))
text += self.define_constraint_deferrability(constraint)
return text
@@ -1769,7 +1844,7 @@ class GenericTypeCompiler(engine.TypeCompiler):
{'precision': type_.precision}
else:
return "NUMERIC(%(precision)s, %(scale)s)" % \
{'precision': type_.precision,
{'precision': type_.precision,
'scale' : type_.scale}
def visit_DECIMAL(self, type_):
@@ -1826,25 +1901,25 @@ class GenericTypeCompiler(engine.TypeCompiler):
def visit_large_binary(self, type_):
return self.visit_BLOB(type_)
def visit_boolean(self, type_):
def visit_boolean(self, type_):
return self.visit_BOOLEAN(type_)
def visit_time(self, type_):
def visit_time(self, type_):
return self.visit_TIME(type_)
def visit_datetime(self, type_):
def visit_datetime(self, type_):
return self.visit_DATETIME(type_)
def visit_date(self, type_):
def visit_date(self, type_):
return self.visit_DATE(type_)
def visit_big_integer(self, type_):
def visit_big_integer(self, type_):
return self.visit_BIGINT(type_)
def visit_small_integer(self, type_):
def visit_small_integer(self, type_):
return self.visit_SMALLINT(type_)
def visit_integer(self, type_):
def visit_integer(self, type_):
return self.visit_INTEGER(type_)
def visit_real(self, type_):
@@ -1853,19 +1928,19 @@ class GenericTypeCompiler(engine.TypeCompiler):
def visit_float(self, type_):
return self.visit_FLOAT(type_)
def visit_numeric(self, type_):
def visit_numeric(self, type_):
return self.visit_NUMERIC(type_)
def visit_string(self, type_):
def visit_string(self, type_):
return self.visit_VARCHAR(type_)
def visit_unicode(self, type_):
def visit_unicode(self, type_):
return self.visit_VARCHAR(type_)
def visit_text(self, type_):
def visit_text(self, type_):
return self.visit_TEXT(type_)
def visit_unicode_text(self, type_):
def visit_unicode_text(self, type_):
return self.visit_TEXT(type_)
def visit_enum(self, type_):
@@ -1889,7 +1964,7 @@ class IdentifierPreparer(object):
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
def __init__(self, dialect, initial_quote='"',
def __init__(self, dialect, initial_quote='"',
final_quote=None, escape_quote='"', omit_schema=False):
"""Construct a new ``IdentifierPreparer`` object.
@@ -1953,7 +2028,7 @@ class IdentifierPreparer(object):
def quote_schema(self, schema, force):
"""Quote a schema.
Subclasses should override this to provide database-dependent
Subclasses should override this to provide database-dependent
quoting behavior.
"""
return self.quote(schema, force)
@@ -2010,7 +2085,7 @@ class IdentifierPreparer(object):
return self.quote(name, quote)
def format_column(self, column, use_table=False,
def format_column(self, column, use_table=False,
name=None, table_name=None):
"""Prepare a quoted column name."""
@@ -2019,14 +2094,14 @@ class IdentifierPreparer(object):
if not getattr(column, 'is_literal', False):
if use_table:
return self.format_table(
column.table, use_schema=False,
column.table, use_schema=False,
name=table_name) + "." + \
self.quote(name, column.quote)
else:
return self.quote(name, column.quote)
else:
# literal textual elements get stuck into ColumnClause alot,
# which shouldnt get quoted
# literal textual elements get stuck into ColumnClause a lot,
# which shouldn't get quoted
if use_table:
return self.format_table(column.table,
File diff suppressed because it is too large Load Diff
+3 -3
View File
@@ -1,5 +1,5 @@
# sql/functions.py
# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file>
# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -33,11 +33,11 @@ class GenericFunction(Function):
class next_value(Function):
"""Represent the 'next value', given a :class:`.Sequence`
as it's single argument.
Compiles into the appropriate function on each backend,
or will raise NotImplementedError if used on a backend
that does not provide support for sequences.
"""
type = sqltypes.Integer()
name = "next_value"
+120 -90
View File
@@ -1,5 +1,5 @@
# sql/operators.py
# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file>
# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -21,25 +21,25 @@ from sqlalchemy.util import symbol
class Operators(object):
"""Base of comparison and logical operators.
Implements base methods :meth:`operate` and :meth:`reverse_operate`,
as well as :meth:`__and__`, :meth:`__or__`, :meth:`__invert__`.
Usually is used via its most common subclass
:class:`.ColumnOperators`.
"""
def __and__(self, other):
"""Implement the ``&`` operator.
When used with SQL expressions, results in an
AND operation, equivalent to
:func:`~.expression.and_`, that is::
a & b
is equivalent to::
from sqlalchemy import and_
and_(a, b)
@@ -47,7 +47,7 @@ class Operators(object):
operator precedence; the ``&`` operator has the highest precedence.
The operands should be enclosed in parenthesis if they contain
further sub expressions::
(a == 2) & (b == 4)
"""
@@ -55,15 +55,15 @@ class Operators(object):
def __or__(self, other):
"""Implement the ``|`` operator.
When used with SQL expressions, results in an
OR operation, equivalent to
:func:`~.expression.or_`, that is::
a | b
is equivalent to::
from sqlalchemy import or_
or_(a, b)
@@ -71,7 +71,7 @@ class Operators(object):
operator precedence; the ``|`` operator has the highest precedence.
The operands should be enclosed in parenthesis if they contain
further sub expressions::
(a == 2) | (b == 4)
"""
@@ -79,15 +79,15 @@ class Operators(object):
def __invert__(self):
"""Implement the ``~`` operator.
When used with SQL expressions, results in a
NOT operation, equivalent to
When used with SQL expressions, results in a
NOT operation, equivalent to
:func:`~.expression.not_`, that is::
~a
is equivalent to::
from sqlalchemy import not_
not_(a)
@@ -123,16 +123,16 @@ class Operators(object):
def operate(self, op, *other, **kwargs):
"""Operate on an argument.
This is the lowest level of operation, raises
:class:`NotImplementedError` by default.
Overriding this on a subclass can allow common
behavior to be applied to all operations.
Overriding this on a subclass can allow common
behavior to be applied to all operations.
For example, overriding :class:`.ColumnOperators`
to apply ``func.lower()`` to the left and right
to apply ``func.lower()`` to the left and right
side::
class MyComparator(ColumnOperators):
def operate(self, op, other):
return op(func.lower(self), func.lower(other))
@@ -142,48 +142,48 @@ class Operators(object):
be a single scalar for most operations.
:param \**kwargs: modifiers. These may be passed by special
operators such as :meth:`ColumnOperators.contains`.
"""
raise NotImplementedError(str(op))
def reverse_operate(self, op, other, **kwargs):
"""Reverse operate on an argument.
Usage is the same as :meth:`operate`.
"""
raise NotImplementedError(str(op))
class ColumnOperators(Operators):
"""Defines comparison and math operations.
By default all methods call down to
:meth:`Operators.operate` or :meth:`Operators.reverse_operate`
passing in the appropriate operator function from the
passing in the appropriate operator function from the
Python builtin ``operator`` module or
a SQLAlchemy-specific operator function from
a SQLAlchemy-specific operator function from
:mod:`sqlalchemy.expression.operators`. For example
the ``__eq__`` function::
def __eq__(self, other):
return self.operate(operators.eq, other)
Where ``operators.eq`` is essentially::
def eq(a, b):
return a == b
A SQLAlchemy construct like :class:`.ColumnElement` ultimately
overrides :meth:`.Operators.operate` and others
to return further :class:`.ClauseElement` constructs,
to return further :class:`.ClauseElement` constructs,
so that the ``==`` operation above is replaced by a clause
construct.
The docstrings here will describe column-oriented
behavior of each operator. For ORM-based operators
on related objects and collections, see :class:`.RelationshipProperty.Comparator`.
"""
timetuple = None
@@ -191,17 +191,17 @@ class ColumnOperators(Operators):
def __lt__(self, other):
"""Implement the ``<`` operator.
In a column context, produces the clause ``a < b``.
"""
return self.operate(lt, other)
def __le__(self, other):
"""Implement the ``<=`` operator.
In a column context, produces the clause ``a <= b``.
"""
return self.operate(le, other)
@@ -209,7 +209,7 @@ class ColumnOperators(Operators):
def __eq__(self, other):
"""Implement the ``==`` operator.
In a column context, produces the clause ``a = b``.
If the target is ``None``, produces ``a IS NULL``.
@@ -221,98 +221,128 @@ class ColumnOperators(Operators):
In a column context, produces the clause ``a != b``.
If the target is ``None``, produces ``a IS NOT NULL``.
"""
return self.operate(ne, other)
def __gt__(self, other):
"""Implement the ``>`` operator.
In a column context, produces the clause ``a > b``.
"""
return self.operate(gt, other)
def __ge__(self, other):
"""Implement the ``>=`` operator.
In a column context, produces the clause ``a >= b``.
"""
return self.operate(ge, other)
def __neg__(self):
"""Implement the ``-`` operator.
In a column context, produces the clause ``-a``.
"""
return self.operate(neg)
def concat(self, other):
"""Implement the 'concat' operator.
In a column context, produces the clause ``a || b``,
or uses the ``concat()`` operator on MySQL.
"""
return self.operate(concat_op, other)
def like(self, other, escape=None):
"""Implement the ``like`` operator.
In a column context, produces the clause ``a LIKE other``.
"""
return self.operate(like_op, other, escape=escape)
def ilike(self, other, escape=None):
"""Implement the ``ilike`` operator.
In a column context, produces the clause ``a ILIKE other``.
"""
return self.operate(ilike_op, other, escape=escape)
def in_(self, other):
"""Implement the ``in`` operator.
In a column context, produces the clause ``a IN other``.
"other" may be a tuple/list of column expressions,
or a :func:`~.expression.select` construct.
"""
return self.operate(in_op, other)
def is_(self, other):
"""Implement the ``IS`` operator.
Normally, ``IS`` is generated automatically when comparing to a
value of ``None``, which resolves to ``NULL``. However, explicit
usage of ``IS`` may be desirable if comparing to boolean values
on certain platforms.
.. versionadded:: 0.7.9
.. seealso:: :meth:`.ColumnOperators.isnot`
"""
return self.operate(is_, other)
def isnot(self, other):
"""Implement the ``IS NOT`` operator.
Normally, ``IS NOT`` is generated automatically when comparing to a
value of ``None``, which resolves to ``NULL``. However, explicit
usage of ``IS NOT`` may be desirable if comparing to boolean values
on certain platforms.
.. versionadded:: 0.7.9
.. seealso:: :meth:`.ColumnOperators.is_`
"""
return self.operate(isnot, other)
def startswith(self, other, **kwargs):
"""Implement the ``startwith`` operator.
In a column context, produces the clause ``LIKE '<other>%'``
"""
return self.operate(startswith_op, other, **kwargs)
def endswith(self, other, **kwargs):
"""Implement the 'endswith' operator.
In a column context, produces the clause ``LIKE '%<other>'``
"""
return self.operate(endswith_op, other, **kwargs)
def contains(self, other, **kwargs):
"""Implement the 'contains' operator.
In a column context, produces the clause ``LIKE '%<other>%'``
"""
return self.operate(contains_op, other, **kwargs)
def match(self, other, **kwargs):
"""Implements the 'match' operator.
In a column context, this produces a MATCH clause, i.e.
``MATCH '<other>'``. The allowed contents of ``other``
In a column context, this produces a MATCH clause, i.e.
``MATCH '<other>'``. The allowed contents of ``other``
are database backend specific.
"""
@@ -347,7 +377,7 @@ class ColumnOperators(Operators):
"""Implement the ``+`` operator in reverse.
See :meth:`__add__`.
"""
return self.reverse_operate(add, other)
@@ -355,7 +385,7 @@ class ColumnOperators(Operators):
"""Implement the ``-`` operator in reverse.
See :meth:`__sub__`.
"""
return self.reverse_operate(sub, other)
@@ -363,7 +393,7 @@ class ColumnOperators(Operators):
"""Implement the ``*`` operator in reverse.
See :meth:`__mul__`.
"""
return self.reverse_operate(mul, other)
@@ -371,7 +401,7 @@ class ColumnOperators(Operators):
"""Implement the ``/`` operator in reverse.
See :meth:`__div__`.
"""
return self.reverse_operate(div, other)
@@ -386,61 +416,61 @@ class ColumnOperators(Operators):
def __add__(self, other):
"""Implement the ``+`` operator.
In a column context, produces the clause ``a + b``
if the parent object has non-string affinity.
If the parent object has a string affinity,
If the parent object has a string affinity,
produces the concatenation operator, ``a || b`` -
see :meth:`concat`.
"""
return self.operate(add, other)
def __sub__(self, other):
"""Implement the ``-`` operator.
In a column context, produces the clause ``a - b``.
"""
return self.operate(sub, other)
def __mul__(self, other):
"""Implement the ``*`` operator.
In a column context, produces the clause ``a * b``.
"""
return self.operate(mul, other)
def __div__(self, other):
"""Implement the ``/`` operator.
In a column context, produces the clause ``a / b``.
"""
return self.operate(div, other)
def __mod__(self, other):
"""Implement the ``%`` operator.
In a column context, produces the clause ``a % b``.
"""
return self.operate(mod, other)
def __truediv__(self, other):
"""Implement the ``//`` operator.
In a column context, produces the clause ``a / b``.
"""
return self.operate(truediv, other)
def __rtruediv__(self, other):
"""Implement the ``//`` operator in reverse.
See :meth:`__truediv__`.
"""
return self.reverse_operate(truediv, other)
@@ -469,13 +499,13 @@ def like_op(a, b, escape=None):
return a.like(b, escape=escape)
def notlike_op(a, b, escape=None):
raise NotImplementedError()
return ~a.like(b, escape=escape)
def ilike_op(a, b, escape=None):
return a.ilike(b, escape=escape)
def notilike_op(a, b, escape=None):
raise NotImplementedError()
return ~a.ilike(b, escape=escape)
def between_op(a, b, c):
return a.between(b, c)
@@ -484,7 +514,7 @@ def in_op(a, b):
return a.in_(b)
def notin_op(a, b):
raise NotImplementedError()
return ~a.in_(b)
def distinct_op(a):
return a.distinct()
@@ -525,7 +555,7 @@ def is_commutative(op):
return op in _commutative
def is_ordering_modifier(op):
return op in (asc_op, desc_op,
return op in (asc_op, desc_op,
nullsfirst_op, nullslast_op)
_associative = _commutative.union([concat_op, and_, or_])
+65 -53
View File
@@ -1,5 +1,5 @@
# sql/util.py
# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file>
# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -27,8 +27,8 @@ def sort_tables(tables):
tuples.append((parent_table, child_table))
for table in tables:
visitors.traverse(table,
{'schema_visitor':True},
visitors.traverse(table,
{'schema_visitor':True},
{'foreign_key':visit_foreign_key})
tuples.extend(
@@ -38,9 +38,9 @@ def sort_tables(tables):
return list(topological.sort(tuples, tables))
def find_join_source(clauses, join_to):
"""Given a list of FROM clauses and a selectable,
return the first index and element from the list of
clauses which can be joined against the selectable. returns
"""Given a list of FROM clauses and a selectable,
return the first index and element from the list of
clauses which can be joined against the selectable. returns
None, None if no match is found.
e.g.::
@@ -62,8 +62,8 @@ def find_join_source(clauses, join_to):
else:
return None, None
def find_tables(clause, check_columns=False,
include_aliases=False, include_joins=False,
def find_tables(clause, check_columns=False,
include_aliases=False, include_joins=False,
include_selects=False, include_crud=False):
"""locate Table objects within the given expression."""
@@ -112,7 +112,7 @@ def unwrap_order_by(clause):
(
not isinstance(t, expression._UnaryExpression) or \
not operators.is_ordering_modifier(t.modifier)
):
):
cols.add(t)
else:
for c in t.get_children():
@@ -167,7 +167,7 @@ def _quote_ddl_expr(element):
class _repr_params(object):
"""A string view of bound parameters, truncating
display to the given number of 'multi' parameter sets.
"""
def __init__(self, params, batches):
self.params = params
@@ -187,7 +187,7 @@ class _repr_params(object):
def expression_as_ddl(clause):
"""Given a SQL expression, convert for usage in DDL, such as
"""Given a SQL expression, convert for usage in DDL, such as
CREATE INDEX and CHECK CONSTRAINT.
Converts bind params into quoted literals, column identifiers
@@ -259,7 +259,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None):
if left is None:
continue
for fk in sorted(
b.foreign_keys,
b.foreign_keys,
key=lambda fk:fk.parent._creation_order):
try:
col = fk.get_referent(left)
@@ -274,7 +274,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None):
constraints.add(fk.constraint)
if left is not b:
for fk in sorted(
left.foreign_keys,
left.foreign_keys,
key=lambda fk:fk.parent._creation_order):
try:
col = fk.get_referent(b)
@@ -317,12 +317,12 @@ def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None):
class Annotated(object):
"""clones a ClauseElement and applies an 'annotations' dictionary.
Unlike regular clones, this clone also mimics __hash__() and
Unlike regular clones, this clone also mimics __hash__() and
__cmp__() of the original element so that it takes its place
in hashed collections.
A reference to the original element is maintained, for the important
reason of keeping its hash value current. When GC'ed, the
reason of keeping its hash value current. When GC'ed, the
hash value may be reused, causing conflicts.
"""
@@ -338,13 +338,13 @@ class Annotated(object):
try:
cls = annotated_classes[element.__class__]
except KeyError:
cls = annotated_classes[element.__class__] = type.__new__(type,
"Annotated%s" % element.__class__.__name__,
cls = annotated_classes[element.__class__] = type.__new__(type,
"Annotated%s" % element.__class__.__name__,
(Annotated, element.__class__), {})
return object.__new__(cls)
def __init__(self, element, values):
# force FromClause to generate their internal
# force FromClause to generate their internal
# collections into __dict__
if isinstance(element, expression.FromClause):
element.c
@@ -404,22 +404,30 @@ for cls in expression.__dict__.values() + [schema.Column, schema.Table]:
exec "annotated_classes[cls] = Annotated%s" % (cls.__name__)
def _deep_annotate(element, annotations, exclude=None):
"""Deep copy the given ClauseElement, annotating each element with the given annotations dictionary.
"""Deep copy the given ClauseElement, annotating each element
with the given annotations dictionary.
Elements within the exclude collection will be cloned but not annotated.
"""
cloned = util.column_dict()
def clone(elem):
# check if element is present in the exclude list.
# take into account proxying relationships.
if exclude and \
if elem in cloned:
return cloned[elem]
elif exclude and \
hasattr(elem, 'proxy_set') and \
elem.proxy_set.intersection(exclude):
elem = elem._clone()
newelem = elem._clone()
elif annotations != elem._annotations:
elem = elem._annotate(annotations.copy())
elem._copy_internals(clone=clone)
return elem
newelem = elem._annotate(annotations)
else:
newelem = elem
newelem._copy_internals(clone=clone)
cloned[elem] = newelem
return newelem
if element is not None:
element = clone(element)
@@ -428,26 +436,30 @@ def _deep_annotate(element, annotations, exclude=None):
def _deep_deannotate(element):
"""Deep copy the given element, removing all annotations."""
cloned = util.column_dict()
def clone(elem):
elem = elem._deannotate()
elem._copy_internals(clone=clone)
return elem
if elem not in cloned:
newelem = elem._deannotate()
newelem._copy_internals(clone=clone)
cloned[elem] = newelem
return cloned[elem]
if element is not None:
element = clone(element)
return element
def _shallow_annotate(element, annotations):
"""Annotate the given ClauseElement and copy its internals so that
internal objects refer to the new annotated object.
def _shallow_annotate(element, annotations):
"""Annotate the given ClauseElement and copy its internals so that
internal objects refer to the new annotated object.
Basically used to apply a "dont traverse" annotation to a
selectable, without digging throughout the whole
structure wasting time.
"""
element = element._annotate(annotations)
element._copy_internals()
return element
Basically used to apply a "dont traverse" annotation to a
selectable, without digging throughout the whole
structure wasting time.
"""
element = element._annotate(annotations)
element._copy_internals()
return element
def splice_joins(left, right, stop_on=None):
if left is None:
@@ -526,7 +538,7 @@ def reduce_columns(columns, *clauses, **kw):
return expression.ColumnSet(columns.difference(omit))
def criterion_as_pairs(expression, consider_as_foreign_keys=None,
def criterion_as_pairs(expression, consider_as_foreign_keys=None,
consider_as_referenced_keys=None, any_operator=False):
"""traverse an expression and locate binary criterion pairs."""
@@ -544,20 +556,20 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
if consider_as_foreign_keys:
if binary.left in consider_as_foreign_keys and \
(binary.right is binary.left or
(binary.right is binary.left or
binary.right not in consider_as_foreign_keys):
pairs.append((binary.right, binary.left))
elif binary.right in consider_as_foreign_keys and \
(binary.left is binary.right or
(binary.left is binary.right or
binary.left not in consider_as_foreign_keys):
pairs.append((binary.left, binary.right))
elif consider_as_referenced_keys:
if binary.left in consider_as_referenced_keys and \
(binary.right is binary.left or
(binary.right is binary.left or
binary.right not in consider_as_referenced_keys):
pairs.append((binary.left, binary.right))
elif binary.right in consider_as_referenced_keys and \
(binary.left is binary.right or
(binary.left is binary.right or
binary.left not in consider_as_referenced_keys):
pairs.append((binary.right, binary.left))
else:
@@ -574,17 +586,17 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
def folded_equivalents(join, equivs=None):
"""Return a list of uniquely named columns.
The column list of the given Join will be narrowed
The column list of the given Join will be narrowed
down to a list of all equivalently-named,
equated columns folded into one column, where 'equated' means they are
equated to each other in the ON clause of this join.
This function is used by Join.select(fold_equivalents=True).
Deprecated. This function is used for a certain kind of
Deprecated. This function is used for a certain kind of
"polymorphic_union" which is designed to achieve joined
table inheritance where the base table has no "discriminator"
column; [ticket:1131] will provide a better way to
column; [ticket:1131] will provide a better way to
achieve this.
"""
@@ -679,12 +691,12 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET):
newcol = self.selectable.corresponding_column(
col,
col,
require_embedded=require_embedded)
if newcol is None and col in self.equivalents and col not in _seen:
for equiv in self.equivalents[col]:
newcol = self._corresponding_column(equiv,
require_embedded=require_embedded,
newcol = self._corresponding_column(equiv,
require_embedded=require_embedded,
_seen=_seen.union([col]))
if newcol is not None:
return newcol
@@ -710,14 +722,14 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
class ColumnAdapter(ClauseAdapter):
"""Extends ClauseAdapter with extra utility functions.
Provides the ability to "wrap" this ClauseAdapter
Provides the ability to "wrap" this ClauseAdapter
around another, a columns dictionary which returns
adapted elements given an original, and an
adapted elements given an original, and an
adapted_row() factory.
"""
def __init__(self, selectable, equivalents=None,
chain_to=None, include=None,
def __init__(self, selectable, equivalents=None,
chain_to=None, include=None,
exclude=None, adapt_required=False):
ClauseAdapter.__init__(self, selectable, equivalents, include, exclude)
if chain_to:
@@ -753,7 +765,7 @@ class ColumnAdapter(ClauseAdapter):
c = c.label(None)
# adapt_required indicates that if we got the same column
# back which we put in (i.e. it passed through),
# back which we put in (i.e. it passed through),
# it's not correct. this is used by eagerloading which
# knows that all columns and expressions need to be adapted
# to a result row, and a "passthrough" is definitely targeting
+18 -18
View File
@@ -1,5 +1,5 @@
# sql/visitors.py
# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file>
# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -8,15 +8,15 @@
SQLAlchemy schema and expression constructs rely on a Python-centric
version of the classic "visitor" pattern as the primary way in which
they apply functionality. The most common use of this pattern
is statement compilation, where individual expression classes match
up to rendering methods that produce a string result. Beyond this,
the visitor system is also used to inspect expressions for various
information and patterns, as well as for usage in
they apply functionality. The most common use of this pattern
is statement compilation, where individual expression classes match
up to rendering methods that produce a string result. Beyond this,
the visitor system is also used to inspect expressions for various
information and patterns, as well as for usage in
some kinds of expression transformation. Other kinds of transformation
use a non-visitor traversal system.
For many examples of how the visit system is used, see the
For many examples of how the visit system is used, see the
sqlalchemy.sql.util and the sqlalchemy.sql.compiler modules.
For an introduction to clause adaption, see
http://techspot.zzzeek.org/2008/01/23/expression-transformations/
@@ -28,18 +28,18 @@ import re
from sqlalchemy import util
import operator
__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor',
'CloningVisitor', 'ReplacingCloningVisitor', 'iterate',
__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor',
'CloningVisitor', 'ReplacingCloningVisitor', 'iterate',
'iterate_depthfirst', 'traverse_using', 'traverse',
'cloned_traverse', 'replacement_traverse']
class VisitableType(type):
"""Metaclass which assigns a `_compiler_dispatch` method to classes
having a `__visit_name__` attribute.
The _compiler_dispatch attribute becomes an instance method which
looks approximately like the following::
def _compiler_dispatch (self, visitor, **kw):
'''Look for an attribute named "visit_" + self.__visit_name__
on the visitor, and call it with the same kw params.'''
@@ -92,7 +92,7 @@ class Visitable(object):
__metaclass__ = VisitableType
class ClauseVisitor(object):
"""Base class for visitor objects which can traverse using
"""Base class for visitor objects which can traverse using
the traverse() function.
"""
@@ -144,7 +144,7 @@ class ClauseVisitor(object):
return self
class CloningVisitor(ClauseVisitor):
"""Base class for visitor objects which can traverse using
"""Base class for visitor objects which can traverse using
the cloned_traverse() function.
"""
@@ -160,7 +160,7 @@ class CloningVisitor(ClauseVisitor):
return cloned_traverse(obj, self.__traverse_options__, self._visitor_dict)
class ReplacingCloningVisitor(CloningVisitor):
"""Base class for visitor objects which can traverse using
"""Base class for visitor objects which can traverse using
the replacement_traverse() function.
"""
@@ -168,8 +168,8 @@ class ReplacingCloningVisitor(CloningVisitor):
def replace(self, elem):
"""receive pre-copied elements during a cloning traversal.
If the method returns a new element, the element is used
instead of creating a simple copy of the element. Traversal
If the method returns a new element, the element is used
instead of creating a simple copy of the element. Traversal
will halt on the newly returned element if it is re-encountered.
"""
return None
@@ -232,7 +232,7 @@ def traverse_depthfirst(obj, opts, visitors):
return traverse_using(iterate_depthfirst(obj, opts), obj, visitors)
def cloned_traverse(obj, opts, visitors):
"""clone the given expression structure, allowing
"""clone the given expression structure, allowing
modifications by visitors."""
cloned = util.column_dict()
@@ -256,7 +256,7 @@ def cloned_traverse(obj, opts, visitors):
def replacement_traverse(obj, opts, replace):
"""clone the given expression structure, allowing element
"""clone the given expression structure, allowing element
replacement by a given replacement function."""
cloned = util.column_dict()