Update SQLAlchemy

This commit is contained in:
Ruud
2013-06-14 11:00:06 +02:00
parent 267ecfacab
commit 4aa6700ceb
124 changed files with 6500 additions and 5207 deletions
+249 -174
View File
@@ -1,5 +1,5 @@
# sql/compiler.py
# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file>
# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -17,7 +17,7 @@ strings
:class:`~sqlalchemy.sql.compiler.GenericTypeCompiler` - renders
type specification strings.
To generate user-defined SQL strings, see
To generate user-defined SQL strings, see
:module:`~sqlalchemy.ext.compiler`.
"""
@@ -29,6 +29,7 @@ from sqlalchemy.sql import operators, functions, util as sql_util, \
visitors
from sqlalchemy.sql import expression as sql
import decimal
import itertools
RESERVED_WORDS = set([
'all', 'analyse', 'analyze', 'and', 'any', 'array',
@@ -59,7 +60,7 @@ BIND_TEMPLATES = {
'pyformat':"%%(%(name)s)s",
'qmark':"?",
'format':"%%s",
'numeric':":%(position)s",
'numeric':":[_POSITION]",
'named':":%(name)s"
}
@@ -214,7 +215,7 @@ class SQLCompiler(engine.Compiled):
driver/DB enforces this
"""
def __init__(self, dialect, statement, column_keys=None,
def __init__(self, dialect, statement, column_keys=None,
inline=False, **kwargs):
"""Construct a new ``DefaultCompiler`` object.
@@ -252,16 +253,14 @@ class SQLCompiler(engine.Compiled):
# column targeting
self.result_map = {}
# collect CTEs to tack on top of a SELECT
self.ctes = util.OrderedDict()
self.ctes_recursive = False
# true if the paramstyle is positional
self.positional = dialect.positional
if self.positional:
self.positiontup = []
self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
self.ctes = None
# an IdentifierPreparer that formats the quoting of identifiers
self.preparer = dialect.identifier_preparer
self.label_length = dialect.label_length \
@@ -276,7 +275,29 @@ class SQLCompiler(engine.Compiled):
self.truncated_names = {}
engine.Compiled.__init__(self, dialect, statement, **kwargs)
if self.positional and dialect.paramstyle == 'numeric':
self._apply_numbered_params()
@util.memoized_instancemethod
def _init_cte_state(self):
"""Initialize collections related to CTEs only if
a CTE is located, to save on the overhead of
these collections otherwise.
"""
# collect CTEs to tack on top of a SELECT
self.ctes = util.OrderedDict()
self.ctes_by_name = {}
self.ctes_recursive = False
if self.positional:
self.cte_positional = []
def _apply_numbered_params(self):
poscount = itertools.count(1)
self.string = re.sub(
r'\[_POSITION\]',
lambda m:str(util.next(poscount)),
self.string)
@util.memoized_property
def _bind_processors(self):
@@ -309,11 +330,11 @@ class SQLCompiler(engine.Compiled):
if _group_number:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r, "
"in parameter group %d" %
"in parameter group %d" %
(bindparam.key, _group_number))
else:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r"
"A value is required for bind parameter %r"
% bindparam.key)
else:
pd[name] = bindparam.effective_value
@@ -325,18 +346,18 @@ class SQLCompiler(engine.Compiled):
if _group_number:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r, "
"in parameter group %d" %
"in parameter group %d" %
(bindparam.key, _group_number))
else:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r"
"A value is required for bind parameter %r"
% bindparam.key)
pd[self.bind_names[bindparam]] = bindparam.effective_value
return pd
@property
def params(self):
"""Return the bind param dictionary embedded into this
"""Return the bind param dictionary embedded into this
compiled object, for those values that are present."""
return self.construct_params(_check=False)
@@ -352,8 +373,8 @@ class SQLCompiler(engine.Compiled):
def visit_grouping(self, grouping, asfrom=False, **kwargs):
return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
def visit_label(self, label, result_map=None,
within_label_clause=False,
def visit_label(self, label, result_map=None,
within_label_clause=False,
within_columns_clause=False, **kw):
# only render labels within the columns clause
# or ORDER BY clause of a select. dialect-specific compilers
@@ -366,20 +387,20 @@ class SQLCompiler(engine.Compiled):
if result_map is not None:
result_map[labelname.lower()] = (
label.name,
(label, label.element, labelname, ) +
label.name,
(label, label.element, labelname, ) +
label._alt_names,
label.type)
return label.element._compiler_dispatch(self,
return label.element._compiler_dispatch(self,
within_columns_clause=True,
within_label_clause=True,
within_label_clause=True,
**kw) + \
OPERATORS[operators.as_] + \
self.preparer.format_label(label, labelname)
else:
return label.element._compiler_dispatch(self,
within_columns_clause=False,
return label.element._compiler_dispatch(self,
within_columns_clause=False,
**kw)
def visit_column(self, column, result_map=None, **kwargs):
@@ -393,8 +414,8 @@ class SQLCompiler(engine.Compiled):
name = self._truncated_identifier("colident", name)
if result_map is not None:
result_map[name.lower()] = (orig_name,
(column, name, column.key),
result_map[name.lower()] = (orig_name,
(column, name, column.key),
column.type)
if is_literal:
@@ -408,7 +429,7 @@ class SQLCompiler(engine.Compiled):
else:
if table.schema:
schema_prefix = self.preparer.quote_schema(
table.schema,
table.schema,
table.quote_schema) + '.'
else:
schema_prefix = ''
@@ -448,7 +469,7 @@ class SQLCompiler(engine.Compiled):
if name in textclause.bindparams:
return self.process(textclause.bindparams[name])
else:
return self.bindparam_string(name)
return self.bindparam_string(name, **kwargs)
# un-escape any \:params
return BIND_PARAMS_ESC.sub(lambda m: m.group(1),
@@ -472,8 +493,8 @@ class SQLCompiler(engine.Compiled):
else:
sep = OPERATORS[clauselist.operator]
return sep.join(
s for s in
(c._compiler_dispatch(self, **kwargs)
s for s in
(c._compiler_dispatch(self, **kwargs)
for c in clauselist.clauses)
if s)
@@ -499,21 +520,21 @@ class SQLCompiler(engine.Compiled):
cast.typeclause._compiler_dispatch(self, **kwargs))
def visit_over(self, over, **kwargs):
x ="%s OVER (" % over.func._compiler_dispatch(self, **kwargs)
if over.partition_by is not None:
x += "PARTITION BY %s" % \
over.partition_by._compiler_dispatch(self, **kwargs)
if over.order_by is not None:
x += " "
if over.order_by is not None:
x += "ORDER BY %s" % \
over.order_by._compiler_dispatch(self, **kwargs)
x += ")"
return x
return "%s OVER (%s)" % (
over.func._compiler_dispatch(self, **kwargs),
' '.join(
'%s BY %s' % (word, clause._compiler_dispatch(self, **kwargs))
for word, clause in (
('PARTITION', over.partition_by),
('ORDER', over.order_by)
)
if clause is not None and len(clause)
)
)
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
return "EXTRACT(%s FROM %s)" % (field,
return "EXTRACT(%s FROM %s)" % (field,
extract.expr._compiler_dispatch(self, **kwargs))
def visit_function(self, func, result_map=None, **kwargs):
@@ -526,7 +547,7 @@ class SQLCompiler(engine.Compiled):
else:
name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s")
return ".".join(list(func.packagenames) + [name]) % \
{'expr':self.function_argspec(func, **kwargs)}
{'expr': self.function_argspec(func, **kwargs)}
def visit_next_value_func(self, next_value, **kw):
return self.visit_sequence(next_value.sequence)
@@ -539,16 +560,17 @@ class SQLCompiler(engine.Compiled):
def function_argspec(self, func, **kwargs):
return func.clause_expr._compiler_dispatch(self, **kwargs)
def visit_compound_select(self, cs, asfrom=False,
parens=True, compound_index=1, **kwargs):
def visit_compound_select(self, cs, asfrom=False,
parens=True, compound_index=0, **kwargs):
entry = self.stack and self.stack[-1] or {}
self.stack.append({'from':entry.get('from', None), 'iswrapper':True})
self.stack.append({'from': entry.get('from', None),
'iswrapper': not entry})
keyword = self.compound_keywords.get(cs.keyword)
text = (" " + keyword + " ").join(
(c._compiler_dispatch(self,
asfrom=asfrom, parens=False,
(c._compiler_dispatch(self,
asfrom=asfrom, parens=False,
compound_index=i, **kwargs)
for i, c in enumerate(cs.selects))
)
@@ -562,6 +584,10 @@ class SQLCompiler(engine.Compiled):
text += (cs._limit is not None or cs._offset is not None) and \
self.limit_clause(cs) or ""
if self.ctes and \
compound_index == 0 and not entry:
text = self._render_cte_clause() + text
self.stack.pop(-1)
if asfrom and parens:
return "(" + text + ")"
@@ -585,8 +611,8 @@ class SQLCompiler(engine.Compiled):
return self._operator_dispatch(binary.operator,
binary,
lambda opstr: binary.left._compiler_dispatch(self, **kw) +
opstr +
lambda opstr: binary.left._compiler_dispatch(self, **kw) +
opstr +
binary.right._compiler_dispatch(
self, **kw),
**kw
@@ -595,36 +621,36 @@ class SQLCompiler(engine.Compiled):
def visit_like_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
return '%s LIKE %s' % (
binary.left._compiler_dispatch(self, **kw),
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
+ (escape and
+ (escape and
(' ESCAPE ' + self.render_literal_value(escape, None))
or '')
def visit_notlike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
return '%s NOT LIKE %s' % (
binary.left._compiler_dispatch(self, **kw),
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
+ (escape and
+ (escape and
(' ESCAPE ' + self.render_literal_value(escape, None))
or '')
def visit_ilike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
return 'lower(%s) LIKE lower(%s)' % (
binary.left._compiler_dispatch(self, **kw),
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
+ (escape and
+ (escape and
(' ESCAPE ' + self.render_literal_value(escape, None))
or '')
def visit_notilike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
return 'lower(%s) NOT LIKE lower(%s)' % (
binary.left._compiler_dispatch(self, **kw),
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
+ (escape and
+ (escape and
(' ESCAPE ' + self.render_literal_value(escape, None))
or '')
@@ -668,7 +694,7 @@ class SQLCompiler(engine.Compiled):
"bindparam() name '%s' is reserved "
"for automatic usage in the VALUES or SET "
"clause of this "
"insert/update statement. Please use a "
"insert/update statement. Please use a "
"name other than column name when using bindparam() "
"with insert() or update() (for example, 'b_%s')."
% (bindparam.key, bindparam.key)
@@ -676,7 +702,7 @@ class SQLCompiler(engine.Compiled):
self.binds[bindparam.key] = self.binds[name] = bindparam
return self.bindparam_string(name)
return self.bindparam_string(name, **kwargs)
def render_literal_bindparam(self, bindparam, **kw):
value = bindparam.value
@@ -688,7 +714,7 @@ class SQLCompiler(engine.Compiled):
def render_literal_value(self, value, type_):
"""Render the value of a bind parameter as a quoted literal.
This is used for statement sections that do not accept bind paramters
This is used for statement sections that do not accept bind parameters
on the target driver/database.
This should be implemented by subclasses using the quoting services
@@ -746,20 +772,45 @@ class SQLCompiler(engine.Compiled):
self.anon_map[derived] = anonymous_counter + 1
return derived + "_" + str(anonymous_counter)
def bindparam_string(self, name):
def bindparam_string(self, name, positional_names=None, **kw):
if self.positional:
self.positiontup.append(name)
return self.bindtemplate % {
'name':name, 'position':len(self.positiontup)}
else:
return self.bindtemplate % {'name':name}
if positional_names is not None:
positional_names.append(name)
else:
self.positiontup.append(name)
return self.bindtemplate % {'name':name}
def visit_cte(self, cte, asfrom=False, ashint=False,
fromhints=None,
**kwargs):
self._init_cte_state()
if self.positional:
kwargs['positional_names'] = self.cte_positional
def visit_cte(self, cte, asfrom=False, ashint=False,
fromhints=None, **kwargs):
if isinstance(cte.name, sql._truncated_label):
cte_name = self._truncated_identifier("alias", cte.name)
else:
cte_name = cte.name
if cte_name in self.ctes_by_name:
existing_cte = self.ctes_by_name[cte_name]
# we've generated a same-named CTE that we are enclosed in,
# or this is the same CTE. just return the name.
if cte in existing_cte._restates or cte is existing_cte:
return cte_name
elif existing_cte in cte._restates:
# we've generated a same-named CTE that is
# enclosed in us - we take precedence, so
# discard the text for the "inner".
del self.ctes[existing_cte]
else:
raise exc.CompileError(
"Multiple, unrelated CTEs found with "
"the same name: %r" %
cte_name)
self.ctes_by_name[cte_name] = cte
if cte.cte_alias:
if isinstance(cte.cte_alias, sql._truncated_label):
cte_alias = self._truncated_identifier("alias", cte.cte_alias)
@@ -776,10 +827,13 @@ class SQLCompiler(engine.Compiled):
col_source = cte.original.selects[0]
else:
assert False
recur_cols = [c.key for c in util.unique_list(col_source.inner_columns)
recur_cols = [c for c in
util.unique_list(col_source.inner_columns)
if c is not None]
text += "(%s)" % (", ".join(recur_cols))
text += "(%s)" % (", ".join(
self.preparer.format_column(ident)
for ident in recur_cols))
text += " AS \n" + \
cte.original._compiler_dispatch(
self, asfrom=True, **kwargs
@@ -793,7 +847,7 @@ class SQLCompiler(engine.Compiled):
return self.preparer.format_alias(cte, cte_name)
return text
def visit_alias(self, alias, asfrom=False, ashint=False,
def visit_alias(self, alias, asfrom=False, ashint=False,
fromhints=None, **kwargs):
if asfrom or ashint:
if isinstance(alias.name, sql._truncated_label):
@@ -804,7 +858,7 @@ class SQLCompiler(engine.Compiled):
if ashint:
return self.preparer.format_alias(alias, alias_name)
elif asfrom:
ret = alias.original._compiler_dispatch(self,
ret = alias.original._compiler_dispatch(self,
asfrom=True, **kwargs) + \
" AS " + \
self.preparer.format_alias(alias, alias_name)
@@ -828,8 +882,8 @@ class SQLCompiler(engine.Compiled):
select.use_labels and \
column._label:
return _CompileLabel(
column,
column._label,
column,
column._label,
alt_names=(column._key_label, )
)
@@ -839,9 +893,9 @@ class SQLCompiler(engine.Compiled):
not column.is_literal and \
column.table is not None and \
not isinstance(column.table, sql.Select):
return _CompileLabel(column, sql._as_truncated(column.name),
return _CompileLabel(column, sql._as_truncated(column.name),
alt_names=(column.key,))
elif not isinstance(column,
elif not isinstance(column,
(sql._UnaryExpression, sql._TextClause)) \
and (not hasattr(column, 'name') or \
isinstance(column, sql.Function)):
@@ -858,9 +912,10 @@ class SQLCompiler(engine.Compiled):
def get_crud_hint_text(self, table, text):
return None
def visit_select(self, select, asfrom=False, parens=True,
iswrapper=False, fromhints=None,
compound_index=1, **kwargs):
def visit_select(self, select, asfrom=False, parens=True,
iswrapper=False, fromhints=None,
compound_index=0,
positional_names=None, **kwargs):
entry = self.stack and self.stack[-1] or {}
@@ -875,13 +930,18 @@ class SQLCompiler(engine.Compiled):
# to outermost if existingfroms: correlate_froms =
# correlate_froms.union(existingfroms)
self.stack.append({'from': correlate_froms, 'iswrapper'
: iswrapper})
populate_result_map = compound_index == 0 and (
not entry or \
entry.get('iswrapper', False)
)
if compound_index==1 and not entry or entry.get('iswrapper', False):
column_clause_args = {'result_map':self.result_map}
self.stack.append({'from': correlate_froms, 'iswrapper': iswrapper})
if populate_result_map:
column_clause_args = {'result_map': self.result_map,
'positional_names': positional_names}
else:
column_clause_args = {}
column_clause_args = {'positional_names': positional_names}
# the actual list of columns to print in the SELECT column list.
inner_columns = [
@@ -889,7 +949,7 @@ class SQLCompiler(engine.Compiled):
self.label_select_column(select, co, asfrom=asfrom).\
_compiler_dispatch(self,
within_columns_clause=True,
**column_clause_args)
**column_clause_args)
for co in util.unique_list(select.inner_columns)
]
if c is not None
@@ -902,9 +962,9 @@ class SQLCompiler(engine.Compiled):
(from_, hinttext % {
'name':from_._compiler_dispatch(
self, ashint=True)
})
for (from_, dialect), hinttext in
select._hints.iteritems()
})
for (from_, dialect), hinttext in
select._hints.iteritems()
if dialect in ('*', self.dialect.name)
])
hint_text = self.get_select_hint_text(byfrom)
@@ -913,7 +973,7 @@ class SQLCompiler(engine.Compiled):
if select._prefixes:
text += " ".join(
x._compiler_dispatch(self, **kwargs)
x._compiler_dispatch(self, **kwargs)
for x in select._prefixes) + " "
text += self.get_select_precolumns(select)
text += ', '.join(inner_columns)
@@ -922,13 +982,13 @@ class SQLCompiler(engine.Compiled):
text += " \nFROM "
if select._hints:
text += ', '.join([f._compiler_dispatch(self,
asfrom=True, fromhints=byfrom,
**kwargs)
text += ', '.join([f._compiler_dispatch(self,
asfrom=True, fromhints=byfrom,
**kwargs)
for f in froms])
else:
text += ', '.join([f._compiler_dispatch(self,
asfrom=True, **kwargs)
text += ', '.join([f._compiler_dispatch(self,
asfrom=True, **kwargs)
for f in froms])
else:
text += self.default_from()
@@ -957,13 +1017,8 @@ class SQLCompiler(engine.Compiled):
text += self.for_update_clause(select)
if self.ctes and \
compound_index==1 and not entry:
cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
cte_text += ", \n".join(
[txt for txt in self.ctes.values()]
)
cte_text += "\n "
text = cte_text + text
compound_index == 0 and not entry:
text = self._render_cte_clause() + text
self.stack.pop(-1)
@@ -972,6 +1027,16 @@ class SQLCompiler(engine.Compiled):
else:
return text
def _render_cte_clause(self):
if self.positional:
self.positiontup = self.cte_positional + self.positiontup
cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
cte_text += ", \n".join(
[txt for txt in self.ctes.values()]
)
cte_text += "\n "
return cte_text
def get_cte_preamble(self, recursive):
if recursive:
return "WITH RECURSIVE"
@@ -1008,7 +1073,7 @@ class SQLCompiler(engine.Compiled):
text += " OFFSET " + self.process(sql.literal(select._offset))
return text
def visit_table(self, table, asfrom=False, ashint=False,
def visit_table(self, table, asfrom=False, ashint=False,
fromhints=None, **kwargs):
if asfrom or ashint:
if getattr(table, "schema", None):
@@ -1028,10 +1093,10 @@ class SQLCompiler(engine.Compiled):
def visit_join(self, join, asfrom=False, **kwargs):
return (
join.left._compiler_dispatch(self, asfrom=True, **kwargs) +
(join.isouter and " LEFT OUTER JOIN " or " JOIN ") +
join.right._compiler_dispatch(self, asfrom=True, **kwargs) +
" ON " +
join.left._compiler_dispatch(self, asfrom=True, **kwargs) +
(join.isouter and " LEFT OUTER JOIN " or " JOIN ") +
join.right._compiler_dispatch(self, asfrom=True, **kwargs) +
" ON " +
join.onclause._compiler_dispatch(self, **kwargs)
)
@@ -1043,7 +1108,7 @@ class SQLCompiler(engine.Compiled):
not self.dialect.supports_default_values and \
not self.dialect.supports_empty_insert:
raise exc.CompileError("The version of %s you are using does "
"not support empty inserts." %
"not support empty inserts." %
self.dialect.name)
preparer = self.preparer
@@ -1061,13 +1126,13 @@ class SQLCompiler(engine.Compiled):
if insert_stmt._hints:
dialect_hints = dict([
(table, hint_text)
for (table, dialect), hint_text in
for (table, dialect), hint_text in
insert_stmt._hints.items()
if dialect in ('*', self.dialect.name)
])
if insert_stmt.table in dialect_hints:
text += " " + self.get_crud_hint_text(
insert_stmt.table,
insert_stmt.table,
dialect_hints[insert_stmt.table]
)
@@ -1098,7 +1163,7 @@ class SQLCompiler(engine.Compiled):
"""Provide a hook for MySQL to add LIMIT to the UPDATE"""
return None
def update_tables_clause(self, update_stmt, from_table,
def update_tables_clause(self, update_stmt, from_table,
extra_froms, **kw):
"""Provide a hook to override the initial table clause
in an UPDATE statement.
@@ -1108,19 +1173,19 @@ class SQLCompiler(engine.Compiled):
"""
return self.preparer.format_table(from_table)
def update_from_clause(self, update_stmt,
from_table, extra_froms,
def update_from_clause(self, update_stmt,
from_table, extra_froms,
from_hints,
**kw):
"""Provide a hook to override the generation of an
"""Provide a hook to override the generation of an
UPDATE..FROM clause.
MySQL overrides this.
MySQL and MSSQL override this.
"""
return "FROM " + ', '.join(
t._compiler_dispatch(self, asfrom=True,
fromhints=from_hints, **kw)
t._compiler_dispatch(self, asfrom=True,
fromhints=from_hints, **kw)
for t in extra_froms)
def visit_update(self, update_stmt, **kw):
@@ -1133,20 +1198,20 @@ class SQLCompiler(engine.Compiled):
colparams = self._get_colparams(update_stmt, extra_froms)
text = "UPDATE " + self.update_tables_clause(
update_stmt,
update_stmt.table,
update_stmt,
update_stmt.table,
extra_froms, **kw)
if update_stmt._hints:
dialect_hints = dict([
(table, hint_text)
for (table, dialect), hint_text in
for (table, dialect), hint_text in
update_stmt._hints.items()
if dialect in ('*', self.dialect.name)
])
if update_stmt.table in dialect_hints:
text += " " + self.get_crud_hint_text(
update_stmt.table,
update_stmt.table,
dialect_hints[update_stmt.table]
)
else:
@@ -1155,12 +1220,12 @@ class SQLCompiler(engine.Compiled):
text += ' SET '
if extra_froms and self.render_table_with_column_in_update_from:
text += ', '.join(
self.visit_column(c[0]) +
self.visit_column(c[0]) +
'=' + c[1] for c in colparams
)
else:
text += ', '.join(
self.preparer.quote(c[0].name, c[0].quote) +
self.preparer.quote(c[0].name, c[0].quote) +
'=' + c[1] for c in colparams
)
@@ -1172,9 +1237,9 @@ class SQLCompiler(engine.Compiled):
if extra_froms:
extra_from_text = self.update_from_clause(
update_stmt,
update_stmt.table,
extra_froms,
update_stmt,
update_stmt.table,
extra_froms,
dialect_hints, **kw)
if extra_from_text:
text += " " + extra_from_text
@@ -1195,7 +1260,7 @@ class SQLCompiler(engine.Compiled):
return text
def _create_crud_bind_param(self, col, value, required=False):
bindparam = sql.bindparam(col.key, value,
bindparam = sql.bindparam(col.key, value,
type_=col.type, required=required)
bindparam._is_crud = True
return bindparam._compiler_dispatch(self)
@@ -1220,8 +1285,8 @@ class SQLCompiler(engine.Compiled):
# compiled params - return binds for all columns
if self.column_keys is None and stmt.parameters is None:
return [
(c, self._create_crud_bind_param(c,
None, required=True))
(c, self._create_crud_bind_param(c,
None, required=True))
for c in stmt.table.columns
]
@@ -1233,8 +1298,8 @@ class SQLCompiler(engine.Compiled):
parameters = {}
else:
parameters = dict((sql._column_as_key(key), required)
for key in self.column_keys
if not stmt.parameters or
for key in self.column_keys
if not stmt.parameters or
key not in stmt.parameters)
if stmt.parameters is not None:
@@ -1255,7 +1320,7 @@ class SQLCompiler(engine.Compiled):
postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
check_columns = {}
# special logic that only occurs for multi-table UPDATE
# special logic that only occurs for multi-table UPDATE
# statements
if extra_tables and stmt.parameters:
assert self.isupdate
@@ -1274,7 +1339,7 @@ class SQLCompiler(engine.Compiled):
value = self.process(value.self_group())
values.append((c, value))
# determine tables which are actually
# to be updated - process onupdate and
# to be updated - process onupdate and
# server_onupdate for these
for t in affected_tables:
for c in t.c:
@@ -1295,7 +1360,7 @@ class SQLCompiler(engine.Compiled):
self.postfetch.append(c)
# iterating through columns at the top to maintain ordering.
# otherwise we might iterate through individual sets of
# otherwise we might iterate through individual sets of
# "defaults", "primary key cols", etc.
for c in stmt.table.columns:
if c.key in parameters and c.key not in check_columns:
@@ -1315,8 +1380,8 @@ class SQLCompiler(engine.Compiled):
if c.primary_key and \
need_pks and \
(
implicit_returning or
not postfetch_lastrowid or
implicit_returning or
not postfetch_lastrowid or
c is not stmt.table._autoincrement_column
):
@@ -1402,7 +1467,7 @@ class SQLCompiler(engine.Compiled):
).difference(check_columns)
if check:
util.warn(
"Unconsumed column names: %s" %
"Unconsumed column names: %s" %
(", ".join(check))
)
@@ -1417,13 +1482,13 @@ class SQLCompiler(engine.Compiled):
if delete_stmt._hints:
dialect_hints = dict([
(table, hint_text)
for (table, dialect), hint_text in
for (table, dialect), hint_text in
delete_stmt._hints.items()
if dialect in ('*', self.dialect.name)
])
if delete_stmt.table in dialect_hints:
text += " " + self.get_crud_hint_text(
delete_stmt.table,
delete_stmt.table,
dialect_hints[delete_stmt.table]
)
else:
@@ -1517,7 +1582,7 @@ class DDLCompiler(engine.Compiled):
text += separator
separator = ", \n"
text += "\t" + self.get_column_specification(
column,
column,
first_pk=column.primary_key and \
not first_pk
)
@@ -1529,16 +1594,16 @@ class DDLCompiler(engine.Compiled):
text += " " + const
except exc.CompileError, ce:
# Py3K
#raise exc.CompileError("(in table '%s', column '%s'): %s"
#raise exc.CompileError("(in table '%s', column '%s'): %s"
# % (
# table.description,
# column.name,
# table.description,
# column.name,
# ce.args[0]
# )) from ce
# Py2K
raise exc.CompileError("(in table '%s', column '%s'): %s"
raise exc.CompileError("(in table '%s', column '%s'): %s"
% (
table.description,
table.description,
column.name,
ce.args[0]
)), None, sys.exc_info()[2]
@@ -1559,17 +1624,17 @@ class DDLCompiler(engine.Compiled):
if table.primary_key:
constraints.append(table.primary_key)
constraints.extend([c for c in table._sorted_constraints
constraints.extend([c for c in table._sorted_constraints
if c is not table.primary_key])
return ", \n\t".join(p for p in
(self.process(constraint)
for constraint in constraints
(self.process(constraint)
for constraint in constraints
if (
constraint._create_rule is None or
constraint._create_rule(self))
and (
not self.dialect.supports_alter or
not self.dialect.supports_alter or
not getattr(constraint, 'use_alter', False)
)) if p is not None
)
@@ -1582,13 +1647,12 @@ class DDLCompiler(engine.Compiled):
max = self.dialect.max_index_name_length or \
self.dialect.max_identifier_length
if len(ident) > max:
return ident[0:max - 8] + \
ident = ident[0:max - 8] + \
"_" + util.md5_hex(ident)[-4:]
else:
return ident
else:
self.dialect.validate_identifier(ident)
return ident
return ident
def visit_create_index(self, create):
index = create.element
@@ -1597,7 +1661,7 @@ class DDLCompiler(engine.Compiled):
if index.unique:
text += "UNIQUE "
text += "INDEX %s ON %s (%s)" \
% (preparer.quote(self._index_identifier(index.name),
% (preparer.quote(self._index_identifier(index.name),
index.quote),
preparer.format_table(index.table),
', '.join(preparer.quote(c.name, c.quote)
@@ -1606,9 +1670,20 @@ class DDLCompiler(engine.Compiled):
def visit_drop_index(self, drop):
index = drop.element
return "\nDROP INDEX " + \
self.preparer.quote(
self._index_identifier(index.name), index.quote)
if index.table is not None and index.table.schema:
schema = index.table.schema
schema_name = self.preparer.quote_schema(schema,
index.table.quote_schema)
else:
schema_name = None
index_name = self.preparer.quote(
self._index_identifier(index.name),
index.quote)
if schema_name:
index_name = schema_name + "." + index_name
return "\nDROP INDEX " + index_name
def visit_add_constraint(self, create):
preparer = self.preparer
@@ -1723,7 +1798,7 @@ class DDLCompiler(engine.Compiled):
text += "CONSTRAINT %s " % \
self.preparer.format_constraint(constraint)
text += "UNIQUE (%s)" % (
', '.join(self.preparer.quote(c.name, c.quote)
', '.join(self.preparer.quote(c.name, c.quote)
for c in constraint))
text += self.define_constraint_deferrability(constraint)
return text
@@ -1769,7 +1844,7 @@ class GenericTypeCompiler(engine.TypeCompiler):
{'precision': type_.precision}
else:
return "NUMERIC(%(precision)s, %(scale)s)" % \
{'precision': type_.precision,
{'precision': type_.precision,
'scale' : type_.scale}
def visit_DECIMAL(self, type_):
@@ -1826,25 +1901,25 @@ class GenericTypeCompiler(engine.TypeCompiler):
def visit_large_binary(self, type_):
return self.visit_BLOB(type_)
def visit_boolean(self, type_):
def visit_boolean(self, type_):
return self.visit_BOOLEAN(type_)
def visit_time(self, type_):
def visit_time(self, type_):
return self.visit_TIME(type_)
def visit_datetime(self, type_):
def visit_datetime(self, type_):
return self.visit_DATETIME(type_)
def visit_date(self, type_):
def visit_date(self, type_):
return self.visit_DATE(type_)
def visit_big_integer(self, type_):
def visit_big_integer(self, type_):
return self.visit_BIGINT(type_)
def visit_small_integer(self, type_):
def visit_small_integer(self, type_):
return self.visit_SMALLINT(type_)
def visit_integer(self, type_):
def visit_integer(self, type_):
return self.visit_INTEGER(type_)
def visit_real(self, type_):
@@ -1853,19 +1928,19 @@ class GenericTypeCompiler(engine.TypeCompiler):
def visit_float(self, type_):
return self.visit_FLOAT(type_)
def visit_numeric(self, type_):
def visit_numeric(self, type_):
return self.visit_NUMERIC(type_)
def visit_string(self, type_):
def visit_string(self, type_):
return self.visit_VARCHAR(type_)
def visit_unicode(self, type_):
def visit_unicode(self, type_):
return self.visit_VARCHAR(type_)
def visit_text(self, type_):
def visit_text(self, type_):
return self.visit_TEXT(type_)
def visit_unicode_text(self, type_):
def visit_unicode_text(self, type_):
return self.visit_TEXT(type_)
def visit_enum(self, type_):
@@ -1889,7 +1964,7 @@ class IdentifierPreparer(object):
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
def __init__(self, dialect, initial_quote='"',
def __init__(self, dialect, initial_quote='"',
final_quote=None, escape_quote='"', omit_schema=False):
"""Construct a new ``IdentifierPreparer`` object.
@@ -1953,7 +2028,7 @@ class IdentifierPreparer(object):
def quote_schema(self, schema, force):
"""Quote a schema.
Subclasses should override this to provide database-dependent
Subclasses should override this to provide database-dependent
quoting behavior.
"""
return self.quote(schema, force)
@@ -2010,7 +2085,7 @@ class IdentifierPreparer(object):
return self.quote(name, quote)
def format_column(self, column, use_table=False,
def format_column(self, column, use_table=False,
name=None, table_name=None):
"""Prepare a quoted column name."""
@@ -2019,14 +2094,14 @@ class IdentifierPreparer(object):
if not getattr(column, 'is_literal', False):
if use_table:
return self.format_table(
column.table, use_schema=False,
column.table, use_schema=False,
name=table_name) + "." + \
self.quote(name, column.quote)
else:
return self.quote(name, column.quote)
else:
# literal textual elements get stuck into ColumnClause alot,
# which shouldnt get quoted
# literal textual elements get stuck into ColumnClause a lot,
# which shouldn't get quoted
if use_table:
return self.format_table(column.table,