Update SQLAlchemy
This commit is contained in:
+249
-174
@@ -1,5 +1,5 @@
|
||||
# sql/compiler.py
|
||||
# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
@@ -17,7 +17,7 @@ strings
|
||||
:class:`~sqlalchemy.sql.compiler.GenericTypeCompiler` - renders
|
||||
type specification strings.
|
||||
|
||||
To generate user-defined SQL strings, see
|
||||
To generate user-defined SQL strings, see
|
||||
:module:`~sqlalchemy.ext.compiler`.
|
||||
|
||||
"""
|
||||
@@ -29,6 +29,7 @@ from sqlalchemy.sql import operators, functions, util as sql_util, \
|
||||
visitors
|
||||
from sqlalchemy.sql import expression as sql
|
||||
import decimal
|
||||
import itertools
|
||||
|
||||
RESERVED_WORDS = set([
|
||||
'all', 'analyse', 'analyze', 'and', 'any', 'array',
|
||||
@@ -59,7 +60,7 @@ BIND_TEMPLATES = {
|
||||
'pyformat':"%%(%(name)s)s",
|
||||
'qmark':"?",
|
||||
'format':"%%s",
|
||||
'numeric':":%(position)s",
|
||||
'numeric':":[_POSITION]",
|
||||
'named':":%(name)s"
|
||||
}
|
||||
|
||||
@@ -214,7 +215,7 @@ class SQLCompiler(engine.Compiled):
|
||||
driver/DB enforces this
|
||||
"""
|
||||
|
||||
def __init__(self, dialect, statement, column_keys=None,
|
||||
def __init__(self, dialect, statement, column_keys=None,
|
||||
inline=False, **kwargs):
|
||||
"""Construct a new ``DefaultCompiler`` object.
|
||||
|
||||
@@ -252,16 +253,14 @@ class SQLCompiler(engine.Compiled):
|
||||
# column targeting
|
||||
self.result_map = {}
|
||||
|
||||
# collect CTEs to tack on top of a SELECT
|
||||
self.ctes = util.OrderedDict()
|
||||
self.ctes_recursive = False
|
||||
|
||||
# true if the paramstyle is positional
|
||||
self.positional = dialect.positional
|
||||
if self.positional:
|
||||
self.positiontup = []
|
||||
self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
|
||||
|
||||
self.ctes = None
|
||||
|
||||
# an IdentifierPreparer that formats the quoting of identifiers
|
||||
self.preparer = dialect.identifier_preparer
|
||||
self.label_length = dialect.label_length \
|
||||
@@ -276,7 +275,29 @@ class SQLCompiler(engine.Compiled):
|
||||
self.truncated_names = {}
|
||||
engine.Compiled.__init__(self, dialect, statement, **kwargs)
|
||||
|
||||
if self.positional and dialect.paramstyle == 'numeric':
|
||||
self._apply_numbered_params()
|
||||
|
||||
@util.memoized_instancemethod
|
||||
def _init_cte_state(self):
|
||||
"""Initialize collections related to CTEs only if
|
||||
a CTE is located, to save on the overhead of
|
||||
these collections otherwise.
|
||||
|
||||
"""
|
||||
# collect CTEs to tack on top of a SELECT
|
||||
self.ctes = util.OrderedDict()
|
||||
self.ctes_by_name = {}
|
||||
self.ctes_recursive = False
|
||||
if self.positional:
|
||||
self.cte_positional = []
|
||||
|
||||
def _apply_numbered_params(self):
|
||||
poscount = itertools.count(1)
|
||||
self.string = re.sub(
|
||||
r'\[_POSITION\]',
|
||||
lambda m:str(util.next(poscount)),
|
||||
self.string)
|
||||
|
||||
@util.memoized_property
|
||||
def _bind_processors(self):
|
||||
@@ -309,11 +330,11 @@ class SQLCompiler(engine.Compiled):
|
||||
if _group_number:
|
||||
raise exc.InvalidRequestError(
|
||||
"A value is required for bind parameter %r, "
|
||||
"in parameter group %d" %
|
||||
"in parameter group %d" %
|
||||
(bindparam.key, _group_number))
|
||||
else:
|
||||
raise exc.InvalidRequestError(
|
||||
"A value is required for bind parameter %r"
|
||||
"A value is required for bind parameter %r"
|
||||
% bindparam.key)
|
||||
else:
|
||||
pd[name] = bindparam.effective_value
|
||||
@@ -325,18 +346,18 @@ class SQLCompiler(engine.Compiled):
|
||||
if _group_number:
|
||||
raise exc.InvalidRequestError(
|
||||
"A value is required for bind parameter %r, "
|
||||
"in parameter group %d" %
|
||||
"in parameter group %d" %
|
||||
(bindparam.key, _group_number))
|
||||
else:
|
||||
raise exc.InvalidRequestError(
|
||||
"A value is required for bind parameter %r"
|
||||
"A value is required for bind parameter %r"
|
||||
% bindparam.key)
|
||||
pd[self.bind_names[bindparam]] = bindparam.effective_value
|
||||
return pd
|
||||
|
||||
@property
|
||||
def params(self):
|
||||
"""Return the bind param dictionary embedded into this
|
||||
"""Return the bind param dictionary embedded into this
|
||||
compiled object, for those values that are present."""
|
||||
return self.construct_params(_check=False)
|
||||
|
||||
@@ -352,8 +373,8 @@ class SQLCompiler(engine.Compiled):
|
||||
def visit_grouping(self, grouping, asfrom=False, **kwargs):
|
||||
return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
|
||||
|
||||
def visit_label(self, label, result_map=None,
|
||||
within_label_clause=False,
|
||||
def visit_label(self, label, result_map=None,
|
||||
within_label_clause=False,
|
||||
within_columns_clause=False, **kw):
|
||||
# only render labels within the columns clause
|
||||
# or ORDER BY clause of a select. dialect-specific compilers
|
||||
@@ -366,20 +387,20 @@ class SQLCompiler(engine.Compiled):
|
||||
|
||||
if result_map is not None:
|
||||
result_map[labelname.lower()] = (
|
||||
label.name,
|
||||
(label, label.element, labelname, ) +
|
||||
label.name,
|
||||
(label, label.element, labelname, ) +
|
||||
label._alt_names,
|
||||
label.type)
|
||||
|
||||
return label.element._compiler_dispatch(self,
|
||||
return label.element._compiler_dispatch(self,
|
||||
within_columns_clause=True,
|
||||
within_label_clause=True,
|
||||
within_label_clause=True,
|
||||
**kw) + \
|
||||
OPERATORS[operators.as_] + \
|
||||
self.preparer.format_label(label, labelname)
|
||||
else:
|
||||
return label.element._compiler_dispatch(self,
|
||||
within_columns_clause=False,
|
||||
return label.element._compiler_dispatch(self,
|
||||
within_columns_clause=False,
|
||||
**kw)
|
||||
|
||||
def visit_column(self, column, result_map=None, **kwargs):
|
||||
@@ -393,8 +414,8 @@ class SQLCompiler(engine.Compiled):
|
||||
name = self._truncated_identifier("colident", name)
|
||||
|
||||
if result_map is not None:
|
||||
result_map[name.lower()] = (orig_name,
|
||||
(column, name, column.key),
|
||||
result_map[name.lower()] = (orig_name,
|
||||
(column, name, column.key),
|
||||
column.type)
|
||||
|
||||
if is_literal:
|
||||
@@ -408,7 +429,7 @@ class SQLCompiler(engine.Compiled):
|
||||
else:
|
||||
if table.schema:
|
||||
schema_prefix = self.preparer.quote_schema(
|
||||
table.schema,
|
||||
table.schema,
|
||||
table.quote_schema) + '.'
|
||||
else:
|
||||
schema_prefix = ''
|
||||
@@ -448,7 +469,7 @@ class SQLCompiler(engine.Compiled):
|
||||
if name in textclause.bindparams:
|
||||
return self.process(textclause.bindparams[name])
|
||||
else:
|
||||
return self.bindparam_string(name)
|
||||
return self.bindparam_string(name, **kwargs)
|
||||
|
||||
# un-escape any \:params
|
||||
return BIND_PARAMS_ESC.sub(lambda m: m.group(1),
|
||||
@@ -472,8 +493,8 @@ class SQLCompiler(engine.Compiled):
|
||||
else:
|
||||
sep = OPERATORS[clauselist.operator]
|
||||
return sep.join(
|
||||
s for s in
|
||||
(c._compiler_dispatch(self, **kwargs)
|
||||
s for s in
|
||||
(c._compiler_dispatch(self, **kwargs)
|
||||
for c in clauselist.clauses)
|
||||
if s)
|
||||
|
||||
@@ -499,21 +520,21 @@ class SQLCompiler(engine.Compiled):
|
||||
cast.typeclause._compiler_dispatch(self, **kwargs))
|
||||
|
||||
def visit_over(self, over, **kwargs):
|
||||
x ="%s OVER (" % over.func._compiler_dispatch(self, **kwargs)
|
||||
if over.partition_by is not None:
|
||||
x += "PARTITION BY %s" % \
|
||||
over.partition_by._compiler_dispatch(self, **kwargs)
|
||||
if over.order_by is not None:
|
||||
x += " "
|
||||
if over.order_by is not None:
|
||||
x += "ORDER BY %s" % \
|
||||
over.order_by._compiler_dispatch(self, **kwargs)
|
||||
x += ")"
|
||||
return x
|
||||
return "%s OVER (%s)" % (
|
||||
over.func._compiler_dispatch(self, **kwargs),
|
||||
' '.join(
|
||||
'%s BY %s' % (word, clause._compiler_dispatch(self, **kwargs))
|
||||
for word, clause in (
|
||||
('PARTITION', over.partition_by),
|
||||
('ORDER', over.order_by)
|
||||
)
|
||||
if clause is not None and len(clause)
|
||||
)
|
||||
)
|
||||
|
||||
def visit_extract(self, extract, **kwargs):
|
||||
field = self.extract_map.get(extract.field, extract.field)
|
||||
return "EXTRACT(%s FROM %s)" % (field,
|
||||
return "EXTRACT(%s FROM %s)" % (field,
|
||||
extract.expr._compiler_dispatch(self, **kwargs))
|
||||
|
||||
def visit_function(self, func, result_map=None, **kwargs):
|
||||
@@ -526,7 +547,7 @@ class SQLCompiler(engine.Compiled):
|
||||
else:
|
||||
name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s")
|
||||
return ".".join(list(func.packagenames) + [name]) % \
|
||||
{'expr':self.function_argspec(func, **kwargs)}
|
||||
{'expr': self.function_argspec(func, **kwargs)}
|
||||
|
||||
def visit_next_value_func(self, next_value, **kw):
|
||||
return self.visit_sequence(next_value.sequence)
|
||||
@@ -539,16 +560,17 @@ class SQLCompiler(engine.Compiled):
|
||||
def function_argspec(self, func, **kwargs):
|
||||
return func.clause_expr._compiler_dispatch(self, **kwargs)
|
||||
|
||||
def visit_compound_select(self, cs, asfrom=False,
|
||||
parens=True, compound_index=1, **kwargs):
|
||||
def visit_compound_select(self, cs, asfrom=False,
|
||||
parens=True, compound_index=0, **kwargs):
|
||||
entry = self.stack and self.stack[-1] or {}
|
||||
self.stack.append({'from':entry.get('from', None), 'iswrapper':True})
|
||||
self.stack.append({'from': entry.get('from', None),
|
||||
'iswrapper': not entry})
|
||||
|
||||
keyword = self.compound_keywords.get(cs.keyword)
|
||||
|
||||
text = (" " + keyword + " ").join(
|
||||
(c._compiler_dispatch(self,
|
||||
asfrom=asfrom, parens=False,
|
||||
(c._compiler_dispatch(self,
|
||||
asfrom=asfrom, parens=False,
|
||||
compound_index=i, **kwargs)
|
||||
for i, c in enumerate(cs.selects))
|
||||
)
|
||||
@@ -562,6 +584,10 @@ class SQLCompiler(engine.Compiled):
|
||||
text += (cs._limit is not None or cs._offset is not None) and \
|
||||
self.limit_clause(cs) or ""
|
||||
|
||||
if self.ctes and \
|
||||
compound_index == 0 and not entry:
|
||||
text = self._render_cte_clause() + text
|
||||
|
||||
self.stack.pop(-1)
|
||||
if asfrom and parens:
|
||||
return "(" + text + ")"
|
||||
@@ -585,8 +611,8 @@ class SQLCompiler(engine.Compiled):
|
||||
|
||||
return self._operator_dispatch(binary.operator,
|
||||
binary,
|
||||
lambda opstr: binary.left._compiler_dispatch(self, **kw) +
|
||||
opstr +
|
||||
lambda opstr: binary.left._compiler_dispatch(self, **kw) +
|
||||
opstr +
|
||||
binary.right._compiler_dispatch(
|
||||
self, **kw),
|
||||
**kw
|
||||
@@ -595,36 +621,36 @@ class SQLCompiler(engine.Compiled):
|
||||
def visit_like_op(self, binary, **kw):
|
||||
escape = binary.modifiers.get("escape", None)
|
||||
return '%s LIKE %s' % (
|
||||
binary.left._compiler_dispatch(self, **kw),
|
||||
binary.left._compiler_dispatch(self, **kw),
|
||||
binary.right._compiler_dispatch(self, **kw)) \
|
||||
+ (escape and
|
||||
+ (escape and
|
||||
(' ESCAPE ' + self.render_literal_value(escape, None))
|
||||
or '')
|
||||
|
||||
def visit_notlike_op(self, binary, **kw):
|
||||
escape = binary.modifiers.get("escape", None)
|
||||
return '%s NOT LIKE %s' % (
|
||||
binary.left._compiler_dispatch(self, **kw),
|
||||
binary.left._compiler_dispatch(self, **kw),
|
||||
binary.right._compiler_dispatch(self, **kw)) \
|
||||
+ (escape and
|
||||
+ (escape and
|
||||
(' ESCAPE ' + self.render_literal_value(escape, None))
|
||||
or '')
|
||||
|
||||
def visit_ilike_op(self, binary, **kw):
|
||||
escape = binary.modifiers.get("escape", None)
|
||||
return 'lower(%s) LIKE lower(%s)' % (
|
||||
binary.left._compiler_dispatch(self, **kw),
|
||||
binary.left._compiler_dispatch(self, **kw),
|
||||
binary.right._compiler_dispatch(self, **kw)) \
|
||||
+ (escape and
|
||||
+ (escape and
|
||||
(' ESCAPE ' + self.render_literal_value(escape, None))
|
||||
or '')
|
||||
|
||||
def visit_notilike_op(self, binary, **kw):
|
||||
escape = binary.modifiers.get("escape", None)
|
||||
return 'lower(%s) NOT LIKE lower(%s)' % (
|
||||
binary.left._compiler_dispatch(self, **kw),
|
||||
binary.left._compiler_dispatch(self, **kw),
|
||||
binary.right._compiler_dispatch(self, **kw)) \
|
||||
+ (escape and
|
||||
+ (escape and
|
||||
(' ESCAPE ' + self.render_literal_value(escape, None))
|
||||
or '')
|
||||
|
||||
@@ -668,7 +694,7 @@ class SQLCompiler(engine.Compiled):
|
||||
"bindparam() name '%s' is reserved "
|
||||
"for automatic usage in the VALUES or SET "
|
||||
"clause of this "
|
||||
"insert/update statement. Please use a "
|
||||
"insert/update statement. Please use a "
|
||||
"name other than column name when using bindparam() "
|
||||
"with insert() or update() (for example, 'b_%s')."
|
||||
% (bindparam.key, bindparam.key)
|
||||
@@ -676,7 +702,7 @@ class SQLCompiler(engine.Compiled):
|
||||
|
||||
self.binds[bindparam.key] = self.binds[name] = bindparam
|
||||
|
||||
return self.bindparam_string(name)
|
||||
return self.bindparam_string(name, **kwargs)
|
||||
|
||||
def render_literal_bindparam(self, bindparam, **kw):
|
||||
value = bindparam.value
|
||||
@@ -688,7 +714,7 @@ class SQLCompiler(engine.Compiled):
|
||||
def render_literal_value(self, value, type_):
|
||||
"""Render the value of a bind parameter as a quoted literal.
|
||||
|
||||
This is used for statement sections that do not accept bind paramters
|
||||
This is used for statement sections that do not accept bind parameters
|
||||
on the target driver/database.
|
||||
|
||||
This should be implemented by subclasses using the quoting services
|
||||
@@ -746,20 +772,45 @@ class SQLCompiler(engine.Compiled):
|
||||
self.anon_map[derived] = anonymous_counter + 1
|
||||
return derived + "_" + str(anonymous_counter)
|
||||
|
||||
def bindparam_string(self, name):
|
||||
def bindparam_string(self, name, positional_names=None, **kw):
|
||||
if self.positional:
|
||||
self.positiontup.append(name)
|
||||
return self.bindtemplate % {
|
||||
'name':name, 'position':len(self.positiontup)}
|
||||
else:
|
||||
return self.bindtemplate % {'name':name}
|
||||
if positional_names is not None:
|
||||
positional_names.append(name)
|
||||
else:
|
||||
self.positiontup.append(name)
|
||||
return self.bindtemplate % {'name':name}
|
||||
|
||||
def visit_cte(self, cte, asfrom=False, ashint=False,
|
||||
fromhints=None,
|
||||
**kwargs):
|
||||
self._init_cte_state()
|
||||
if self.positional:
|
||||
kwargs['positional_names'] = self.cte_positional
|
||||
|
||||
def visit_cte(self, cte, asfrom=False, ashint=False,
|
||||
fromhints=None, **kwargs):
|
||||
if isinstance(cte.name, sql._truncated_label):
|
||||
cte_name = self._truncated_identifier("alias", cte.name)
|
||||
else:
|
||||
cte_name = cte.name
|
||||
|
||||
if cte_name in self.ctes_by_name:
|
||||
existing_cte = self.ctes_by_name[cte_name]
|
||||
# we've generated a same-named CTE that we are enclosed in,
|
||||
# or this is the same CTE. just return the name.
|
||||
if cte in existing_cte._restates or cte is existing_cte:
|
||||
return cte_name
|
||||
elif existing_cte in cte._restates:
|
||||
# we've generated a same-named CTE that is
|
||||
# enclosed in us - we take precedence, so
|
||||
# discard the text for the "inner".
|
||||
del self.ctes[existing_cte]
|
||||
else:
|
||||
raise exc.CompileError(
|
||||
"Multiple, unrelated CTEs found with "
|
||||
"the same name: %r" %
|
||||
cte_name)
|
||||
|
||||
self.ctes_by_name[cte_name] = cte
|
||||
|
||||
if cte.cte_alias:
|
||||
if isinstance(cte.cte_alias, sql._truncated_label):
|
||||
cte_alias = self._truncated_identifier("alias", cte.cte_alias)
|
||||
@@ -776,10 +827,13 @@ class SQLCompiler(engine.Compiled):
|
||||
col_source = cte.original.selects[0]
|
||||
else:
|
||||
assert False
|
||||
recur_cols = [c.key for c in util.unique_list(col_source.inner_columns)
|
||||
recur_cols = [c for c in
|
||||
util.unique_list(col_source.inner_columns)
|
||||
if c is not None]
|
||||
|
||||
text += "(%s)" % (", ".join(recur_cols))
|
||||
text += "(%s)" % (", ".join(
|
||||
self.preparer.format_column(ident)
|
||||
for ident in recur_cols))
|
||||
text += " AS \n" + \
|
||||
cte.original._compiler_dispatch(
|
||||
self, asfrom=True, **kwargs
|
||||
@@ -793,7 +847,7 @@ class SQLCompiler(engine.Compiled):
|
||||
return self.preparer.format_alias(cte, cte_name)
|
||||
return text
|
||||
|
||||
def visit_alias(self, alias, asfrom=False, ashint=False,
|
||||
def visit_alias(self, alias, asfrom=False, ashint=False,
|
||||
fromhints=None, **kwargs):
|
||||
if asfrom or ashint:
|
||||
if isinstance(alias.name, sql._truncated_label):
|
||||
@@ -804,7 +858,7 @@ class SQLCompiler(engine.Compiled):
|
||||
if ashint:
|
||||
return self.preparer.format_alias(alias, alias_name)
|
||||
elif asfrom:
|
||||
ret = alias.original._compiler_dispatch(self,
|
||||
ret = alias.original._compiler_dispatch(self,
|
||||
asfrom=True, **kwargs) + \
|
||||
" AS " + \
|
||||
self.preparer.format_alias(alias, alias_name)
|
||||
@@ -828,8 +882,8 @@ class SQLCompiler(engine.Compiled):
|
||||
select.use_labels and \
|
||||
column._label:
|
||||
return _CompileLabel(
|
||||
column,
|
||||
column._label,
|
||||
column,
|
||||
column._label,
|
||||
alt_names=(column._key_label, )
|
||||
)
|
||||
|
||||
@@ -839,9 +893,9 @@ class SQLCompiler(engine.Compiled):
|
||||
not column.is_literal and \
|
||||
column.table is not None and \
|
||||
not isinstance(column.table, sql.Select):
|
||||
return _CompileLabel(column, sql._as_truncated(column.name),
|
||||
return _CompileLabel(column, sql._as_truncated(column.name),
|
||||
alt_names=(column.key,))
|
||||
elif not isinstance(column,
|
||||
elif not isinstance(column,
|
||||
(sql._UnaryExpression, sql._TextClause)) \
|
||||
and (not hasattr(column, 'name') or \
|
||||
isinstance(column, sql.Function)):
|
||||
@@ -858,9 +912,10 @@ class SQLCompiler(engine.Compiled):
|
||||
def get_crud_hint_text(self, table, text):
|
||||
return None
|
||||
|
||||
def visit_select(self, select, asfrom=False, parens=True,
|
||||
iswrapper=False, fromhints=None,
|
||||
compound_index=1, **kwargs):
|
||||
def visit_select(self, select, asfrom=False, parens=True,
|
||||
iswrapper=False, fromhints=None,
|
||||
compound_index=0,
|
||||
positional_names=None, **kwargs):
|
||||
|
||||
entry = self.stack and self.stack[-1] or {}
|
||||
|
||||
@@ -875,13 +930,18 @@ class SQLCompiler(engine.Compiled):
|
||||
# to outermost if existingfroms: correlate_froms =
|
||||
# correlate_froms.union(existingfroms)
|
||||
|
||||
self.stack.append({'from': correlate_froms, 'iswrapper'
|
||||
: iswrapper})
|
||||
populate_result_map = compound_index == 0 and (
|
||||
not entry or \
|
||||
entry.get('iswrapper', False)
|
||||
)
|
||||
|
||||
if compound_index==1 and not entry or entry.get('iswrapper', False):
|
||||
column_clause_args = {'result_map':self.result_map}
|
||||
self.stack.append({'from': correlate_froms, 'iswrapper': iswrapper})
|
||||
|
||||
if populate_result_map:
|
||||
column_clause_args = {'result_map': self.result_map,
|
||||
'positional_names': positional_names}
|
||||
else:
|
||||
column_clause_args = {}
|
||||
column_clause_args = {'positional_names': positional_names}
|
||||
|
||||
# the actual list of columns to print in the SELECT column list.
|
||||
inner_columns = [
|
||||
@@ -889,7 +949,7 @@ class SQLCompiler(engine.Compiled):
|
||||
self.label_select_column(select, co, asfrom=asfrom).\
|
||||
_compiler_dispatch(self,
|
||||
within_columns_clause=True,
|
||||
**column_clause_args)
|
||||
**column_clause_args)
|
||||
for co in util.unique_list(select.inner_columns)
|
||||
]
|
||||
if c is not None
|
||||
@@ -902,9 +962,9 @@ class SQLCompiler(engine.Compiled):
|
||||
(from_, hinttext % {
|
||||
'name':from_._compiler_dispatch(
|
||||
self, ashint=True)
|
||||
})
|
||||
for (from_, dialect), hinttext in
|
||||
select._hints.iteritems()
|
||||
})
|
||||
for (from_, dialect), hinttext in
|
||||
select._hints.iteritems()
|
||||
if dialect in ('*', self.dialect.name)
|
||||
])
|
||||
hint_text = self.get_select_hint_text(byfrom)
|
||||
@@ -913,7 +973,7 @@ class SQLCompiler(engine.Compiled):
|
||||
|
||||
if select._prefixes:
|
||||
text += " ".join(
|
||||
x._compiler_dispatch(self, **kwargs)
|
||||
x._compiler_dispatch(self, **kwargs)
|
||||
for x in select._prefixes) + " "
|
||||
text += self.get_select_precolumns(select)
|
||||
text += ', '.join(inner_columns)
|
||||
@@ -922,13 +982,13 @@ class SQLCompiler(engine.Compiled):
|
||||
text += " \nFROM "
|
||||
|
||||
if select._hints:
|
||||
text += ', '.join([f._compiler_dispatch(self,
|
||||
asfrom=True, fromhints=byfrom,
|
||||
**kwargs)
|
||||
text += ', '.join([f._compiler_dispatch(self,
|
||||
asfrom=True, fromhints=byfrom,
|
||||
**kwargs)
|
||||
for f in froms])
|
||||
else:
|
||||
text += ', '.join([f._compiler_dispatch(self,
|
||||
asfrom=True, **kwargs)
|
||||
text += ', '.join([f._compiler_dispatch(self,
|
||||
asfrom=True, **kwargs)
|
||||
for f in froms])
|
||||
else:
|
||||
text += self.default_from()
|
||||
@@ -957,13 +1017,8 @@ class SQLCompiler(engine.Compiled):
|
||||
text += self.for_update_clause(select)
|
||||
|
||||
if self.ctes and \
|
||||
compound_index==1 and not entry:
|
||||
cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
|
||||
cte_text += ", \n".join(
|
||||
[txt for txt in self.ctes.values()]
|
||||
)
|
||||
cte_text += "\n "
|
||||
text = cte_text + text
|
||||
compound_index == 0 and not entry:
|
||||
text = self._render_cte_clause() + text
|
||||
|
||||
self.stack.pop(-1)
|
||||
|
||||
@@ -972,6 +1027,16 @@ class SQLCompiler(engine.Compiled):
|
||||
else:
|
||||
return text
|
||||
|
||||
def _render_cte_clause(self):
|
||||
if self.positional:
|
||||
self.positiontup = self.cte_positional + self.positiontup
|
||||
cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
|
||||
cte_text += ", \n".join(
|
||||
[txt for txt in self.ctes.values()]
|
||||
)
|
||||
cte_text += "\n "
|
||||
return cte_text
|
||||
|
||||
def get_cte_preamble(self, recursive):
|
||||
if recursive:
|
||||
return "WITH RECURSIVE"
|
||||
@@ -1008,7 +1073,7 @@ class SQLCompiler(engine.Compiled):
|
||||
text += " OFFSET " + self.process(sql.literal(select._offset))
|
||||
return text
|
||||
|
||||
def visit_table(self, table, asfrom=False, ashint=False,
|
||||
def visit_table(self, table, asfrom=False, ashint=False,
|
||||
fromhints=None, **kwargs):
|
||||
if asfrom or ashint:
|
||||
if getattr(table, "schema", None):
|
||||
@@ -1028,10 +1093,10 @@ class SQLCompiler(engine.Compiled):
|
||||
|
||||
def visit_join(self, join, asfrom=False, **kwargs):
|
||||
return (
|
||||
join.left._compiler_dispatch(self, asfrom=True, **kwargs) +
|
||||
(join.isouter and " LEFT OUTER JOIN " or " JOIN ") +
|
||||
join.right._compiler_dispatch(self, asfrom=True, **kwargs) +
|
||||
" ON " +
|
||||
join.left._compiler_dispatch(self, asfrom=True, **kwargs) +
|
||||
(join.isouter and " LEFT OUTER JOIN " or " JOIN ") +
|
||||
join.right._compiler_dispatch(self, asfrom=True, **kwargs) +
|
||||
" ON " +
|
||||
join.onclause._compiler_dispatch(self, **kwargs)
|
||||
)
|
||||
|
||||
@@ -1043,7 +1108,7 @@ class SQLCompiler(engine.Compiled):
|
||||
not self.dialect.supports_default_values and \
|
||||
not self.dialect.supports_empty_insert:
|
||||
raise exc.CompileError("The version of %s you are using does "
|
||||
"not support empty inserts." %
|
||||
"not support empty inserts." %
|
||||
self.dialect.name)
|
||||
|
||||
preparer = self.preparer
|
||||
@@ -1061,13 +1126,13 @@ class SQLCompiler(engine.Compiled):
|
||||
if insert_stmt._hints:
|
||||
dialect_hints = dict([
|
||||
(table, hint_text)
|
||||
for (table, dialect), hint_text in
|
||||
for (table, dialect), hint_text in
|
||||
insert_stmt._hints.items()
|
||||
if dialect in ('*', self.dialect.name)
|
||||
])
|
||||
if insert_stmt.table in dialect_hints:
|
||||
text += " " + self.get_crud_hint_text(
|
||||
insert_stmt.table,
|
||||
insert_stmt.table,
|
||||
dialect_hints[insert_stmt.table]
|
||||
)
|
||||
|
||||
@@ -1098,7 +1163,7 @@ class SQLCompiler(engine.Compiled):
|
||||
"""Provide a hook for MySQL to add LIMIT to the UPDATE"""
|
||||
return None
|
||||
|
||||
def update_tables_clause(self, update_stmt, from_table,
|
||||
def update_tables_clause(self, update_stmt, from_table,
|
||||
extra_froms, **kw):
|
||||
"""Provide a hook to override the initial table clause
|
||||
in an UPDATE statement.
|
||||
@@ -1108,19 +1173,19 @@ class SQLCompiler(engine.Compiled):
|
||||
"""
|
||||
return self.preparer.format_table(from_table)
|
||||
|
||||
def update_from_clause(self, update_stmt,
|
||||
from_table, extra_froms,
|
||||
def update_from_clause(self, update_stmt,
|
||||
from_table, extra_froms,
|
||||
from_hints,
|
||||
**kw):
|
||||
"""Provide a hook to override the generation of an
|
||||
"""Provide a hook to override the generation of an
|
||||
UPDATE..FROM clause.
|
||||
|
||||
MySQL overrides this.
|
||||
MySQL and MSSQL override this.
|
||||
|
||||
"""
|
||||
return "FROM " + ', '.join(
|
||||
t._compiler_dispatch(self, asfrom=True,
|
||||
fromhints=from_hints, **kw)
|
||||
t._compiler_dispatch(self, asfrom=True,
|
||||
fromhints=from_hints, **kw)
|
||||
for t in extra_froms)
|
||||
|
||||
def visit_update(self, update_stmt, **kw):
|
||||
@@ -1133,20 +1198,20 @@ class SQLCompiler(engine.Compiled):
|
||||
colparams = self._get_colparams(update_stmt, extra_froms)
|
||||
|
||||
text = "UPDATE " + self.update_tables_clause(
|
||||
update_stmt,
|
||||
update_stmt.table,
|
||||
update_stmt,
|
||||
update_stmt.table,
|
||||
extra_froms, **kw)
|
||||
|
||||
if update_stmt._hints:
|
||||
dialect_hints = dict([
|
||||
(table, hint_text)
|
||||
for (table, dialect), hint_text in
|
||||
for (table, dialect), hint_text in
|
||||
update_stmt._hints.items()
|
||||
if dialect in ('*', self.dialect.name)
|
||||
])
|
||||
if update_stmt.table in dialect_hints:
|
||||
text += " " + self.get_crud_hint_text(
|
||||
update_stmt.table,
|
||||
update_stmt.table,
|
||||
dialect_hints[update_stmt.table]
|
||||
)
|
||||
else:
|
||||
@@ -1155,12 +1220,12 @@ class SQLCompiler(engine.Compiled):
|
||||
text += ' SET '
|
||||
if extra_froms and self.render_table_with_column_in_update_from:
|
||||
text += ', '.join(
|
||||
self.visit_column(c[0]) +
|
||||
self.visit_column(c[0]) +
|
||||
'=' + c[1] for c in colparams
|
||||
)
|
||||
else:
|
||||
text += ', '.join(
|
||||
self.preparer.quote(c[0].name, c[0].quote) +
|
||||
self.preparer.quote(c[0].name, c[0].quote) +
|
||||
'=' + c[1] for c in colparams
|
||||
)
|
||||
|
||||
@@ -1172,9 +1237,9 @@ class SQLCompiler(engine.Compiled):
|
||||
|
||||
if extra_froms:
|
||||
extra_from_text = self.update_from_clause(
|
||||
update_stmt,
|
||||
update_stmt.table,
|
||||
extra_froms,
|
||||
update_stmt,
|
||||
update_stmt.table,
|
||||
extra_froms,
|
||||
dialect_hints, **kw)
|
||||
if extra_from_text:
|
||||
text += " " + extra_from_text
|
||||
@@ -1195,7 +1260,7 @@ class SQLCompiler(engine.Compiled):
|
||||
return text
|
||||
|
||||
def _create_crud_bind_param(self, col, value, required=False):
|
||||
bindparam = sql.bindparam(col.key, value,
|
||||
bindparam = sql.bindparam(col.key, value,
|
||||
type_=col.type, required=required)
|
||||
bindparam._is_crud = True
|
||||
return bindparam._compiler_dispatch(self)
|
||||
@@ -1220,8 +1285,8 @@ class SQLCompiler(engine.Compiled):
|
||||
# compiled params - return binds for all columns
|
||||
if self.column_keys is None and stmt.parameters is None:
|
||||
return [
|
||||
(c, self._create_crud_bind_param(c,
|
||||
None, required=True))
|
||||
(c, self._create_crud_bind_param(c,
|
||||
None, required=True))
|
||||
for c in stmt.table.columns
|
||||
]
|
||||
|
||||
@@ -1233,8 +1298,8 @@ class SQLCompiler(engine.Compiled):
|
||||
parameters = {}
|
||||
else:
|
||||
parameters = dict((sql._column_as_key(key), required)
|
||||
for key in self.column_keys
|
||||
if not stmt.parameters or
|
||||
for key in self.column_keys
|
||||
if not stmt.parameters or
|
||||
key not in stmt.parameters)
|
||||
|
||||
if stmt.parameters is not None:
|
||||
@@ -1255,7 +1320,7 @@ class SQLCompiler(engine.Compiled):
|
||||
postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
|
||||
|
||||
check_columns = {}
|
||||
# special logic that only occurs for multi-table UPDATE
|
||||
# special logic that only occurs for multi-table UPDATE
|
||||
# statements
|
||||
if extra_tables and stmt.parameters:
|
||||
assert self.isupdate
|
||||
@@ -1274,7 +1339,7 @@ class SQLCompiler(engine.Compiled):
|
||||
value = self.process(value.self_group())
|
||||
values.append((c, value))
|
||||
# determine tables which are actually
|
||||
# to be updated - process onupdate and
|
||||
# to be updated - process onupdate and
|
||||
# server_onupdate for these
|
||||
for t in affected_tables:
|
||||
for c in t.c:
|
||||
@@ -1295,7 +1360,7 @@ class SQLCompiler(engine.Compiled):
|
||||
self.postfetch.append(c)
|
||||
|
||||
# iterating through columns at the top to maintain ordering.
|
||||
# otherwise we might iterate through individual sets of
|
||||
# otherwise we might iterate through individual sets of
|
||||
# "defaults", "primary key cols", etc.
|
||||
for c in stmt.table.columns:
|
||||
if c.key in parameters and c.key not in check_columns:
|
||||
@@ -1315,8 +1380,8 @@ class SQLCompiler(engine.Compiled):
|
||||
if c.primary_key and \
|
||||
need_pks and \
|
||||
(
|
||||
implicit_returning or
|
||||
not postfetch_lastrowid or
|
||||
implicit_returning or
|
||||
not postfetch_lastrowid or
|
||||
c is not stmt.table._autoincrement_column
|
||||
):
|
||||
|
||||
@@ -1402,7 +1467,7 @@ class SQLCompiler(engine.Compiled):
|
||||
).difference(check_columns)
|
||||
if check:
|
||||
util.warn(
|
||||
"Unconsumed column names: %s" %
|
||||
"Unconsumed column names: %s" %
|
||||
(", ".join(check))
|
||||
)
|
||||
|
||||
@@ -1417,13 +1482,13 @@ class SQLCompiler(engine.Compiled):
|
||||
if delete_stmt._hints:
|
||||
dialect_hints = dict([
|
||||
(table, hint_text)
|
||||
for (table, dialect), hint_text in
|
||||
for (table, dialect), hint_text in
|
||||
delete_stmt._hints.items()
|
||||
if dialect in ('*', self.dialect.name)
|
||||
])
|
||||
if delete_stmt.table in dialect_hints:
|
||||
text += " " + self.get_crud_hint_text(
|
||||
delete_stmt.table,
|
||||
delete_stmt.table,
|
||||
dialect_hints[delete_stmt.table]
|
||||
)
|
||||
else:
|
||||
@@ -1517,7 +1582,7 @@ class DDLCompiler(engine.Compiled):
|
||||
text += separator
|
||||
separator = ", \n"
|
||||
text += "\t" + self.get_column_specification(
|
||||
column,
|
||||
column,
|
||||
first_pk=column.primary_key and \
|
||||
not first_pk
|
||||
)
|
||||
@@ -1529,16 +1594,16 @@ class DDLCompiler(engine.Compiled):
|
||||
text += " " + const
|
||||
except exc.CompileError, ce:
|
||||
# Py3K
|
||||
#raise exc.CompileError("(in table '%s', column '%s'): %s"
|
||||
#raise exc.CompileError("(in table '%s', column '%s'): %s"
|
||||
# % (
|
||||
# table.description,
|
||||
# column.name,
|
||||
# table.description,
|
||||
# column.name,
|
||||
# ce.args[0]
|
||||
# )) from ce
|
||||
# Py2K
|
||||
raise exc.CompileError("(in table '%s', column '%s'): %s"
|
||||
raise exc.CompileError("(in table '%s', column '%s'): %s"
|
||||
% (
|
||||
table.description,
|
||||
table.description,
|
||||
column.name,
|
||||
ce.args[0]
|
||||
)), None, sys.exc_info()[2]
|
||||
@@ -1559,17 +1624,17 @@ class DDLCompiler(engine.Compiled):
|
||||
if table.primary_key:
|
||||
constraints.append(table.primary_key)
|
||||
|
||||
constraints.extend([c for c in table._sorted_constraints
|
||||
constraints.extend([c for c in table._sorted_constraints
|
||||
if c is not table.primary_key])
|
||||
|
||||
return ", \n\t".join(p for p in
|
||||
(self.process(constraint)
|
||||
for constraint in constraints
|
||||
(self.process(constraint)
|
||||
for constraint in constraints
|
||||
if (
|
||||
constraint._create_rule is None or
|
||||
constraint._create_rule(self))
|
||||
and (
|
||||
not self.dialect.supports_alter or
|
||||
not self.dialect.supports_alter or
|
||||
not getattr(constraint, 'use_alter', False)
|
||||
)) if p is not None
|
||||
)
|
||||
@@ -1582,13 +1647,12 @@ class DDLCompiler(engine.Compiled):
|
||||
max = self.dialect.max_index_name_length or \
|
||||
self.dialect.max_identifier_length
|
||||
if len(ident) > max:
|
||||
return ident[0:max - 8] + \
|
||||
ident = ident[0:max - 8] + \
|
||||
"_" + util.md5_hex(ident)[-4:]
|
||||
else:
|
||||
return ident
|
||||
else:
|
||||
self.dialect.validate_identifier(ident)
|
||||
return ident
|
||||
|
||||
return ident
|
||||
|
||||
def visit_create_index(self, create):
|
||||
index = create.element
|
||||
@@ -1597,7 +1661,7 @@ class DDLCompiler(engine.Compiled):
|
||||
if index.unique:
|
||||
text += "UNIQUE "
|
||||
text += "INDEX %s ON %s (%s)" \
|
||||
% (preparer.quote(self._index_identifier(index.name),
|
||||
% (preparer.quote(self._index_identifier(index.name),
|
||||
index.quote),
|
||||
preparer.format_table(index.table),
|
||||
', '.join(preparer.quote(c.name, c.quote)
|
||||
@@ -1606,9 +1670,20 @@ class DDLCompiler(engine.Compiled):
|
||||
|
||||
def visit_drop_index(self, drop):
|
||||
index = drop.element
|
||||
return "\nDROP INDEX " + \
|
||||
self.preparer.quote(
|
||||
self._index_identifier(index.name), index.quote)
|
||||
if index.table is not None and index.table.schema:
|
||||
schema = index.table.schema
|
||||
schema_name = self.preparer.quote_schema(schema,
|
||||
index.table.quote_schema)
|
||||
else:
|
||||
schema_name = None
|
||||
|
||||
index_name = self.preparer.quote(
|
||||
self._index_identifier(index.name),
|
||||
index.quote)
|
||||
|
||||
if schema_name:
|
||||
index_name = schema_name + "." + index_name
|
||||
return "\nDROP INDEX " + index_name
|
||||
|
||||
def visit_add_constraint(self, create):
|
||||
preparer = self.preparer
|
||||
@@ -1723,7 +1798,7 @@ class DDLCompiler(engine.Compiled):
|
||||
text += "CONSTRAINT %s " % \
|
||||
self.preparer.format_constraint(constraint)
|
||||
text += "UNIQUE (%s)" % (
|
||||
', '.join(self.preparer.quote(c.name, c.quote)
|
||||
', '.join(self.preparer.quote(c.name, c.quote)
|
||||
for c in constraint))
|
||||
text += self.define_constraint_deferrability(constraint)
|
||||
return text
|
||||
@@ -1769,7 +1844,7 @@ class GenericTypeCompiler(engine.TypeCompiler):
|
||||
{'precision': type_.precision}
|
||||
else:
|
||||
return "NUMERIC(%(precision)s, %(scale)s)" % \
|
||||
{'precision': type_.precision,
|
||||
{'precision': type_.precision,
|
||||
'scale' : type_.scale}
|
||||
|
||||
def visit_DECIMAL(self, type_):
|
||||
@@ -1826,25 +1901,25 @@ class GenericTypeCompiler(engine.TypeCompiler):
|
||||
def visit_large_binary(self, type_):
|
||||
return self.visit_BLOB(type_)
|
||||
|
||||
def visit_boolean(self, type_):
|
||||
def visit_boolean(self, type_):
|
||||
return self.visit_BOOLEAN(type_)
|
||||
|
||||
def visit_time(self, type_):
|
||||
def visit_time(self, type_):
|
||||
return self.visit_TIME(type_)
|
||||
|
||||
def visit_datetime(self, type_):
|
||||
def visit_datetime(self, type_):
|
||||
return self.visit_DATETIME(type_)
|
||||
|
||||
def visit_date(self, type_):
|
||||
def visit_date(self, type_):
|
||||
return self.visit_DATE(type_)
|
||||
|
||||
def visit_big_integer(self, type_):
|
||||
def visit_big_integer(self, type_):
|
||||
return self.visit_BIGINT(type_)
|
||||
|
||||
def visit_small_integer(self, type_):
|
||||
def visit_small_integer(self, type_):
|
||||
return self.visit_SMALLINT(type_)
|
||||
|
||||
def visit_integer(self, type_):
|
||||
def visit_integer(self, type_):
|
||||
return self.visit_INTEGER(type_)
|
||||
|
||||
def visit_real(self, type_):
|
||||
@@ -1853,19 +1928,19 @@ class GenericTypeCompiler(engine.TypeCompiler):
|
||||
def visit_float(self, type_):
|
||||
return self.visit_FLOAT(type_)
|
||||
|
||||
def visit_numeric(self, type_):
|
||||
def visit_numeric(self, type_):
|
||||
return self.visit_NUMERIC(type_)
|
||||
|
||||
def visit_string(self, type_):
|
||||
def visit_string(self, type_):
|
||||
return self.visit_VARCHAR(type_)
|
||||
|
||||
def visit_unicode(self, type_):
|
||||
def visit_unicode(self, type_):
|
||||
return self.visit_VARCHAR(type_)
|
||||
|
||||
def visit_text(self, type_):
|
||||
def visit_text(self, type_):
|
||||
return self.visit_TEXT(type_)
|
||||
|
||||
def visit_unicode_text(self, type_):
|
||||
def visit_unicode_text(self, type_):
|
||||
return self.visit_TEXT(type_)
|
||||
|
||||
def visit_enum(self, type_):
|
||||
@@ -1889,7 +1964,7 @@ class IdentifierPreparer(object):
|
||||
|
||||
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
|
||||
|
||||
def __init__(self, dialect, initial_quote='"',
|
||||
def __init__(self, dialect, initial_quote='"',
|
||||
final_quote=None, escape_quote='"', omit_schema=False):
|
||||
"""Construct a new ``IdentifierPreparer`` object.
|
||||
|
||||
@@ -1953,7 +2028,7 @@ class IdentifierPreparer(object):
|
||||
def quote_schema(self, schema, force):
|
||||
"""Quote a schema.
|
||||
|
||||
Subclasses should override this to provide database-dependent
|
||||
Subclasses should override this to provide database-dependent
|
||||
quoting behavior.
|
||||
"""
|
||||
return self.quote(schema, force)
|
||||
@@ -2010,7 +2085,7 @@ class IdentifierPreparer(object):
|
||||
|
||||
return self.quote(name, quote)
|
||||
|
||||
def format_column(self, column, use_table=False,
|
||||
def format_column(self, column, use_table=False,
|
||||
name=None, table_name=None):
|
||||
"""Prepare a quoted column name."""
|
||||
|
||||
@@ -2019,14 +2094,14 @@ class IdentifierPreparer(object):
|
||||
if not getattr(column, 'is_literal', False):
|
||||
if use_table:
|
||||
return self.format_table(
|
||||
column.table, use_schema=False,
|
||||
column.table, use_schema=False,
|
||||
name=table_name) + "." + \
|
||||
self.quote(name, column.quote)
|
||||
else:
|
||||
return self.quote(name, column.quote)
|
||||
else:
|
||||
# literal textual elements get stuck into ColumnClause alot,
|
||||
# which shouldnt get quoted
|
||||
# literal textual elements get stuck into ColumnClause a lot,
|
||||
# which shouldn't get quoted
|
||||
|
||||
if use_table:
|
||||
return self.format_table(column.table,
|
||||
|
||||
Reference in New Issue
Block a user