Update tornado

This commit is contained in:
Ruud
2012-07-07 09:29:32 +02:00
parent 57547fbd7c
commit 1018a7dd32
33 changed files with 1505 additions and 802 deletions
Regular → Executable
+4 -2
View File
@@ -16,6 +16,8 @@
"""The Tornado web server and tools."""
from __future__ import absolute_import, division, with_statement
# version is a human-readable version number.
# version_info is a four-tuple for programmatic comparison. The first
@@ -23,5 +25,5 @@
# is zero for an official release, positive for a development branch,
# or negative for a release candidate (after the base version number
# has been incremented)
version = "2.2.1"
version_info = (2, 2, 1, 0)
version = "2.3.post1"
version_info = (2, 3, 0, 1)
Regular → Executable
+99 -80
View File
@@ -44,6 +44,8 @@ Example usage for Google OpenID::
# Save the user with, e.g., set_secure_cookie()
"""
from __future__ import absolute_import, division, with_statement
import base64
import binascii
import hashlib
@@ -59,13 +61,14 @@ from tornado import escape
from tornado.httputil import url_concat
from tornado.util import bytes_type, b
class OpenIdMixin(object):
"""Abstract implementation of OpenID and Attribute Exchange.
See GoogleMixin below for example implementations.
"""
def authenticate_redirect(self, callback_uri=None,
ax_attrs=["name","email","language","username"]):
ax_attrs=["name", "email", "language", "username"]):
"""Returns the authentication URL for this service.
After authentication, the service will redirect back to the given
@@ -91,7 +94,8 @@ class OpenIdMixin(object):
args = dict((k, v[-1]) for k, v in self.request.arguments.iteritems())
args["openid.mode"] = u"check_authentication"
url = self._OPENID_ENDPOINT
if http_client is None: http_client = httpclient.AsyncHTTPClient()
if http_client is None:
http_client = httpclient.AsyncHTTPClient()
http_client.fetch(url, self.async_callback(
self._on_authentication_verified, callback),
method="POST", body=urllib.urlencode(args))
@@ -158,8 +162,10 @@ class OpenIdMixin(object):
self.get_argument(name) == u"http://openid.net/srv/ax/1.0":
ax_ns = name[10:]
break
def get_ax_arg(uri):
if not ax_ns: return u""
if not ax_ns:
return u""
prefix = "openid." + ax_ns + ".type."
ax_name = None
for name in self.request.arguments.iterkeys():
@@ -167,7 +173,8 @@ class OpenIdMixin(object):
part = name[len(prefix):]
ax_name = "openid." + ax_ns + ".value." + part
break
if not ax_name: return u""
if not ax_name:
return u""
return self.get_argument(ax_name, u"")
email = get_ax_arg("http://axschema.org/contact/email")
@@ -190,9 +197,12 @@ class OpenIdMixin(object):
user["name"] = u" ".join(name_parts)
elif email:
user["name"] = email.split("@")[0]
if email: user["email"] = email
if locale: user["locale"] = locale
if username: user["username"] = username
if email:
user["email"] = email
if locale:
user["locale"] = locale
if username:
user["username"] = username
callback(user)
@@ -235,7 +245,6 @@ class OAuthMixin(object):
self._on_request_token, self._OAUTH_AUTHORIZE_URL,
callback_uri))
def get_authenticated_user(self, callback, http_client=None):
"""Gets the OAuth authorized user and access token on callback.
@@ -269,7 +278,7 @@ class OAuthMixin(object):
http_client.fetch(self._oauth_access_token_url(token),
self.async_callback(self._on_access_token, callback))
def _oauth_request_token_url(self, callback_uri= None, extra_params=None):
def _oauth_request_token_url(self, callback_uri=None, extra_params=None):
consumer_token = self._oauth_consumer_token()
url = self._OAUTH_REQUEST_TOKEN_URL
args = dict(
@@ -283,7 +292,8 @@ class OAuthMixin(object):
if callback_uri:
args["oauth_callback"] = urlparse.urljoin(
self.request.full_url(), callback_uri)
if extra_params: args.update(extra_params)
if extra_params:
args.update(extra_params)
signature = _oauth10a_signature(consumer_token, "GET", url, args)
else:
signature = _oauth_signature(consumer_token, "GET", url, args)
@@ -316,7 +326,7 @@ class OAuthMixin(object):
oauth_version=getattr(self, "_OAUTH_VERSION", "1.0a"),
)
if "verifier" in request_token:
args["oauth_verifier"]=request_token["verifier"]
args["oauth_verifier"] = request_token["verifier"]
if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a":
signature = _oauth10a_signature(consumer_token, "GET", url, args,
@@ -376,11 +386,12 @@ class OAuthMixin(object):
base_args["oauth_signature"] = signature
return base_args
class OAuth2Mixin(object):
"""Abstract implementation of OAuth v 2."""
def authorize_redirect(self, redirect_uri=None, client_id=None,
client_secret=None, extra_params=None ):
client_secret=None, extra_params=None):
"""Redirects the user to obtain OAuth authorization for this service.
Some providers require that you register a Callback
@@ -393,11 +404,12 @@ class OAuth2Mixin(object):
"redirect_uri": redirect_uri,
"client_id": client_id
}
if extra_params: args.update(extra_params)
if extra_params:
args.update(extra_params)
self.redirect(
url_concat(self._OAUTH_AUTHORIZE_URL, args))
def _oauth_request_token_url(self, redirect_uri= None, client_id = None,
def _oauth_request_token_url(self, redirect_uri=None, client_id=None,
client_secret=None, code=None,
extra_params=None):
url = self._OAUTH_ACCESS_TOKEN_URL
@@ -407,9 +419,11 @@ class OAuth2Mixin(object):
client_id=client_id,
client_secret=client_secret,
)
if extra_params: args.update(extra_params)
if extra_params:
args.update(extra_params)
return url_concat(url, args)
class TwitterMixin(OAuthMixin):
"""Twitter OAuth authentication.
@@ -450,15 +464,14 @@ class TwitterMixin(OAuthMixin):
_OAUTH_AUTHENTICATE_URL = "http://api.twitter.com/oauth/authenticate"
_OAUTH_NO_CALLBACKS = False
def authenticate_redirect(self, callback_uri = None):
def authenticate_redirect(self, callback_uri=None):
"""Just like authorize_redirect(), but auto-redirects if authorized.
This is generally the right interface to use if you are using
Twitter for single-sign on.
"""
http = httpclient.AsyncHTTPClient()
http.fetch(self._oauth_request_token_url(callback_uri = callback_uri), self.async_callback(
http.fetch(self._oauth_request_token_url(callback_uri=callback_uri), self.async_callback(
self._on_request_token, self._OAUTH_AUTHENTICATE_URL, None))
def twitter_request(self, path, callback, access_token=None,
@@ -514,7 +527,8 @@ class TwitterMixin(OAuthMixin):
oauth = self._oauth_request_parameters(
url, access_token, all_args, method=method)
args.update(oauth)
if args: url += "?" + urllib.urlencode(args)
if args:
url += "?" + urllib.urlencode(args)
callback = self.async_callback(self._on_twitter_request, callback)
http = httpclient.AsyncHTTPClient()
if post_args is not None:
@@ -590,7 +604,6 @@ class FriendFeedMixin(OAuthMixin):
_OAUTH_NO_CALLBACKS = True
_OAUTH_VERSION = "1.0"
def friendfeed_request(self, path, callback, access_token=None,
post_args=None, **args):
"""Fetches the given relative API path, e.g., "/bret/friends"
@@ -636,7 +649,8 @@ class FriendFeedMixin(OAuthMixin):
oauth = self._oauth_request_parameters(
url, access_token, all_args, method=method)
args.update(oauth)
if args: url += "?" + urllib.urlencode(args)
if args:
url += "?" + urllib.urlencode(args)
callback = self.async_callback(self._on_friendfeed_request, callback)
http = httpclient.AsyncHTTPClient()
if post_args is not None:
@@ -701,7 +715,7 @@ class GoogleMixin(OpenIdMixin, OAuthMixin):
_OAUTH_ACCESS_TOKEN_URL = "https://www.google.com/accounts/OAuthGetAccessToken"
def authorize_redirect(self, oauth_scope, callback_uri=None,
ax_attrs=["name","email","language","username"]):
ax_attrs=["name", "email", "language", "username"]):
"""Authenticates and authorizes for the given Google resource.
Some of the available resources are:
@@ -746,6 +760,7 @@ class GoogleMixin(OpenIdMixin, OAuthMixin):
def _oauth_get_user(self, access_token, callback):
OpenIdMixin.get_authenticated_user(self, callback)
class FacebookMixin(object):
"""Facebook Connect authentication.
@@ -926,9 +941,11 @@ class FacebookMixin(object):
def _signature(self, args):
parts = ["%s=%s" % (n, args[n]) for n in sorted(args.keys())]
body = "".join(parts) + self.settings["facebook_secret"]
if isinstance(body, unicode): body = body.encode("utf-8")
if isinstance(body, unicode):
body = body.encode("utf-8")
return hashlib.md5(body).hexdigest()
class FacebookGraphMixin(OAuth2Mixin):
"""Facebook authentication using the new Graph API and OAuth2."""
_OAUTH_ACCESS_TOKEN_URL = "https://graph.facebook.com/oauth/access_token?"
@@ -937,68 +954,68 @@ class FacebookGraphMixin(OAuth2Mixin):
def get_authenticated_user(self, redirect_uri, client_id, client_secret,
code, callback, extra_fields=None):
"""Handles the login for the Facebook user, returning a user object.
"""Handles the login for the Facebook user, returning a user object.
Example usage::
Example usage::
class FacebookGraphLoginHandler(LoginHandler, tornado.auth.FacebookGraphMixin):
@tornado.web.asynchronous
def get(self):
if self.get_argument("code", False):
self.get_authenticated_user(
redirect_uri='/auth/facebookgraph/',
client_id=self.settings["facebook_api_key"],
client_secret=self.settings["facebook_secret"],
code=self.get_argument("code"),
callback=self.async_callback(
self._on_login))
return
self.authorize_redirect(redirect_uri='/auth/facebookgraph/',
client_id=self.settings["facebook_api_key"],
extra_params={"scope": "read_stream,offline_access"})
class FacebookGraphLoginHandler(LoginHandler, tornado.auth.FacebookGraphMixin):
@tornado.web.asynchronous
def get(self):
if self.get_argument("code", False):
self.get_authenticated_user(
redirect_uri='/auth/facebookgraph/',
client_id=self.settings["facebook_api_key"],
client_secret=self.settings["facebook_secret"],
code=self.get_argument("code"),
callback=self.async_callback(
self._on_login))
return
self.authorize_redirect(redirect_uri='/auth/facebookgraph/',
client_id=self.settings["facebook_api_key"],
extra_params={"scope": "read_stream,offline_access"})
def _on_login(self, user):
logging.error(user)
self.finish()
def _on_login(self, user):
logging.error(user)
self.finish()
"""
http = httpclient.AsyncHTTPClient()
args = {
"redirect_uri": redirect_uri,
"code": code,
"client_id": client_id,
"client_secret": client_secret,
}
"""
http = httpclient.AsyncHTTPClient()
args = {
"redirect_uri": redirect_uri,
"code": code,
"client_id": client_id,
"client_secret": client_secret,
}
fields = set(['id', 'name', 'first_name', 'last_name',
'locale', 'picture', 'link'])
if extra_fields: fields.update(extra_fields)
fields = set(['id', 'name', 'first_name', 'last_name',
'locale', 'picture', 'link'])
if extra_fields:
fields.update(extra_fields)
http.fetch(self._oauth_request_token_url(**args),
self.async_callback(self._on_access_token, redirect_uri, client_id,
client_secret, callback, fields))
http.fetch(self._oauth_request_token_url(**args),
self.async_callback(self._on_access_token, redirect_uri, client_id,
client_secret, callback, fields))
def _on_access_token(self, redirect_uri, client_id, client_secret,
callback, fields, response):
if response.error:
logging.warning('Facebook auth error: %s' % str(response))
callback(None)
return
if response.error:
logging.warning('Facebook auth error: %s' % str(response))
callback(None)
return
args = escape.parse_qs_bytes(escape.native_str(response.body))
session = {
"access_token": args["access_token"][-1],
"expires": args.get("expires")
}
self.facebook_request(
path="/me",
callback=self.async_callback(
self._on_get_user_info, callback, session, fields),
access_token=session["access_token"],
fields=",".join(fields)
)
args = escape.parse_qs_bytes(escape.native_str(response.body))
session = {
"access_token": args["access_token"][-1],
"expires": args.get("expires")
}
self.facebook_request(
path="/me",
callback=self.async_callback(
self._on_get_user_info, callback, session, fields),
access_token=session["access_token"],
fields=",".join(fields)
)
def _on_get_user_info(self, callback, session, fields, user):
if user is None:
@@ -1052,8 +1069,9 @@ class FacebookGraphMixin(OAuth2Mixin):
if access_token:
all_args["access_token"] = access_token
all_args.update(args)
all_args.update(post_args or {})
if all_args: url += "?" + urllib.urlencode(all_args)
if all_args:
url += "?" + urllib.urlencode(all_args)
callback = self.async_callback(self._on_facebook_request, callback)
http = httpclient.AsyncHTTPClient()
if post_args is not None:
@@ -1070,6 +1088,7 @@ class FacebookGraphMixin(OAuth2Mixin):
return
callback(escape.json_decode(response.body))
def _oauth_signature(consumer_token, method, url, parameters={}, token=None):
"""Calculates the HMAC-SHA1 OAuth signature for the given request.
@@ -1084,7 +1103,7 @@ def _oauth_signature(consumer_token, method, url, parameters={}, token=None):
base_elems.append(normalized_url)
base_elems.append("&".join("%s=%s" % (k, _oauth_escape(str(v)))
for k, v in sorted(parameters.items())))
base_string = "&".join(_oauth_escape(e) for e in base_elems)
base_string = "&".join(_oauth_escape(e) for e in base_elems)
key_elems = [escape.utf8(consumer_token["secret"])]
key_elems.append(escape.utf8(token["secret"] if token else ""))
@@ -1093,6 +1112,7 @@ def _oauth_signature(consumer_token, method, url, parameters={}, token=None):
hash = hmac.new(key, escape.utf8(base_string), hashlib.sha1)
return binascii.b2a_base64(hash.digest())[:-1]
def _oauth10a_signature(consumer_token, method, url, parameters={}, token=None):
"""Calculates the HMAC-SHA1 OAuth 1.0a signature for the given request.
@@ -1108,7 +1128,7 @@ def _oauth10a_signature(consumer_token, method, url, parameters={}, token=None):
base_elems.append("&".join("%s=%s" % (k, _oauth_escape(str(v)))
for k, v in sorted(parameters.items())))
base_string = "&".join(_oauth_escape(e) for e in base_elems)
base_string = "&".join(_oauth_escape(e) for e in base_elems)
key_elems = [escape.utf8(urllib.quote(consumer_token["secret"], safe='~'))]
key_elems.append(escape.utf8(urllib.quote(token["secret"], safe='~') if token else ""))
key = b("&").join(key_elems)
@@ -1116,6 +1136,7 @@ def _oauth10a_signature(consumer_token, method, url, parameters={}, token=None):
hash = hmac.new(key, escape.utf8(base_string), hashlib.sha1)
return binascii.b2a_base64(hash.digest())[:-1]
def _oauth_escape(val):
if isinstance(val, unicode):
val = val.encode("utf-8")
@@ -1130,5 +1151,3 @@ def _oauth_parse_response(body):
special = (b("oauth_token"), b("oauth_token_secret"))
token.update((k, p[k][0]) for k in p if k not in special)
return token
Regular → Executable
+69 -21
View File
@@ -24,9 +24,47 @@ and static resources.
This module depends on IOLoop, so it will not work in WSGI applications
and Google AppEngine. It also will not work correctly when HTTPServer's
multi-process mode is used.
Reloading loses any Python interpreter command-line arguments (e.g. ``-u``)
because it re-executes Python using ``sys.executable`` and ``sys.argv``.
Additionally, modifying these variables will cause reloading to behave
incorrectly.
"""
from __future__ import with_statement
from __future__ import absolute_import, division, with_statement
import os
import sys
# sys.path handling
# -----------------
#
# If a module is run with "python -m", the current directory (i.e. "")
# is automatically prepended to sys.path, but not if it is run as
# "path/to/file.py". The processing for "-m" rewrites the former to
# the latter, so subsequent executions won't have the same path as the
# original.
#
# Conversely, when run as path/to/file.py, the directory containing
# file.py gets added to the path, which can cause confusion as imports
# may become relative in spite of the future import.
#
# We address the former problem by setting the $PYTHONPATH environment
# variable before re-execution so the new process will see the correct
# path. We attempt to address the latter problem when tornado.autoreload
# is run as __main__, although we can't fix the general case because
# we cannot reliably reconstruct the original command line
# (http://bugs.python.org/issue14208).
if __name__ == "__main__":
# This sys.path manipulation must come before our imports (as much
# as possible - if we introduced a tornado.sys or tornado.os
# module we'd be in trouble), or else our imports would become
# relative again despite the future import.
#
# There is a separate __main__ block at the end of the file to call main().
if sys.path[0] == os.path.dirname(__file__):
del sys.path[0]
import functools
import logging
@@ -44,6 +82,7 @@ try:
except ImportError:
signal = None
def start(io_loop=None, check_time=500):
"""Restarts the process automatically when a module is modified.
@@ -57,6 +96,7 @@ def start(io_loop=None, check_time=500):
scheduler = ioloop.PeriodicCallback(callback, check_time, io_loop=io_loop)
scheduler.start()
def wait():
"""Wait for a watched file to change, then restart the process.
@@ -70,6 +110,7 @@ def wait():
_watched_files = set()
def watch(filename):
"""Add a file to the watch list.
@@ -79,6 +120,7 @@ def watch(filename):
_reload_hooks = []
def add_reload_hook(fn):
"""Add a function to be called before reloading the process.
@@ -89,6 +131,7 @@ def add_reload_hook(fn):
"""
_reload_hooks.append(fn)
def _close_all_fds(io_loop):
for fd in io_loop._handlers.keys():
try:
@@ -98,6 +141,7 @@ def _close_all_fds(io_loop):
_reload_attempted = False
def _reload_on_update(modify_times):
if _reload_attempted:
# We already tried to reload and it didn't work, so don't try again.
@@ -112,15 +156,18 @@ def _reload_on_update(modify_times):
# in the standard library), and occasionally this can cause strange
# failures in getattr. Just ignore anything that's not an ordinary
# module.
if not isinstance(module, types.ModuleType): continue
if not isinstance(module, types.ModuleType):
continue
path = getattr(module, "__file__", None)
if not path: continue
if not path:
continue
if path.endswith(".pyc") or path.endswith(".pyo"):
path = path[:-1]
_check_file(modify_times, path)
for path in _watched_files:
_check_file(modify_times, path)
def _check_file(modify_times, path):
try:
modified = os.stat(path).st_mtime
@@ -133,6 +180,7 @@ def _check_file(modify_times, path):
logging.info("%s modified; restarting server", path)
_reload()
def _reload():
global _reload_attempted
_reload_attempted = True
@@ -143,6 +191,15 @@ def _reload():
# ioloop.set_blocking_log_threshold so it doesn't fire
# after the exec.
signal.setitimer(signal.ITIMER_REAL, 0, 0)
# sys.path fixes: see comments at top of file. If sys.path[0] is an empty
# string, we were (probably) invoked with -m and the effective path
# is about to change on re-exec. Add the current directory to $PYTHONPATH
# to ensure that the new process sees the same path we did.
path_prefix = '.' + os.pathsep
if (sys.path[0] == '' and
not os.environ.get("PYTHONPATH", "").startswith(path_prefix)):
os.environ["PYTHONPATH"] = (path_prefix +
os.environ.get("PYTHONPATH", ""))
if sys.platform == 'win32':
# os.execv is broken on Windows and can't properly parse command line
# arguments and executable name if they contain whitespaces. subprocess
@@ -173,9 +230,11 @@ Usage:
python -m tornado.autoreload -m module.to.run [args...]
python -m tornado.autoreload path/to/script.py [args...]
"""
def main():
"""Command-line wrapper to re-run a script whenever its source changes.
Scripts may be specified by filename or module name::
python -m tornado.autoreload -m tornado.test.runtests
@@ -226,25 +285,14 @@ def main():
if mode == 'module':
# runpy did a fake import of the module as __main__, but now it's
# no longer in sys.modules. Figure out where it is and watch it.
watch(pkgutil.get_loader(module).get_filename())
loader = pkgutil.get_loader(module)
if loader is not None:
watch(loader.get_filename())
wait()
if __name__ == "__main__":
# If this module is run with "python -m tornado.autoreload", the current
# directory is automatically prepended to sys.path, but not if it is
# run as "path/to/tornado/autoreload.py". The processing for "-m" rewrites
# the former to the latter, so subsequent executions won't have the same
# path as the original. Modify os.environ here to ensure that the
# re-executed process will have the same path.
# Conversely, when run as path/to/tornado/autoreload.py, the directory
# containing autoreload.py gets added to the path, but we don't want
# tornado modules importable at top level, so remove it.
path_prefix = '.' + os.pathsep
if (sys.path[0] == '' and
not os.environ.get("PYTHONPATH", "").startswith(path_prefix)):
os.environ["PYTHONPATH"] = path_prefix + os.environ.get("PYTHONPATH", "")
elif sys.path[0] == os.path.dirname(__file__):
del sys.path[0]
# See also the other __main__ block at the top of the file, which modifies
# sys.path before our imports
main()
Regular → Executable
View File
Regular → Executable
+14 -8
View File
@@ -16,7 +16,7 @@
"""Blocking and non-blocking HTTP client implementations using pycurl."""
from __future__ import with_statement
from __future__ import absolute_import, division, with_statement
import cStringIO
import collections
@@ -32,6 +32,7 @@ from tornado import stack_context
from tornado.escape import utf8
from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError, AsyncHTTPClient, main
class CurlAsyncHTTPClient(AsyncHTTPClient):
def initialize(self, io_loop=None, max_clients=10,
max_simultaneous_connections=None):
@@ -109,15 +110,17 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = self.io_loop.add_timeout(
time.time() + msecs/1000.0, self._handle_timeout)
time.time() + msecs / 1000.0, self._handle_timeout)
def _handle_events(self, fd, events):
"""Called by IOLoop when there is activity on one of our
file descriptors.
"""
action = 0
if events & ioloop.IOLoop.READ: action |= pycurl.CSELECT_IN
if events & ioloop.IOLoop.WRITE: action |= pycurl.CSELECT_OUT
if events & ioloop.IOLoop.READ:
action |= pycurl.CSELECT_IN
if events & ioloop.IOLoop.WRITE:
action |= pycurl.CSELECT_OUT
while True:
try:
ret, num_handles = self._socket_action(fd, action)
@@ -250,7 +253,6 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
except Exception:
self.handle_callback_exception(info["callback"])
def handle_callback_exception(self, callback):
self.io_loop.handle_callback_exception(callback)
@@ -372,7 +374,7 @@ def _curl_setup_request(curl, request, buffer, headers):
# Handle curl's cryptic options for every individual HTTP method
if request.method in ("POST", "PUT"):
request_buffer = cStringIO.StringIO(utf8(request.body))
request_buffer = cStringIO.StringIO(utf8(request.body))
curl.setopt(pycurl.READFUNCTION, request_buffer.read)
if request.method == "POST":
def ioctl(cmd):
@@ -393,8 +395,11 @@ def _curl_setup_request(curl, request, buffer, headers):
curl.unsetopt(pycurl.USERPWD)
logging.debug("%s %s", request.method, request.url)
if request.client_key is not None or request.client_cert is not None:
raise ValueError("Client certificate not supported with curl_httpclient")
if request.client_cert is not None:
curl.setopt(pycurl.SSLCERT, request.client_cert)
if request.client_key is not None:
curl.setopt(pycurl.SSLKEY, request.client_key)
if threading.activeCount() > 1:
# libcurl/pycurl is not thread-safe by default. When multiple threads
@@ -420,6 +425,7 @@ def _curl_header_callback(headers, header_line):
return
headers.parse_line(header_line)
def _curl_debug(debug_type, debug_msg):
debug_types = ('I', '<', '>', '<', '>')
if debug_type == 0:
Regular → Executable
+27 -18
View File
@@ -16,14 +16,24 @@
"""A lightweight wrapper around MySQLdb."""
from __future__ import absolute_import, division, with_statement
import copy
import MySQLdb.constants
import MySQLdb.converters
import MySQLdb.cursors
import itertools
import logging
import time
try:
import MySQLdb.constants
import MySQLdb.converters
import MySQLdb.cursors
except ImportError:
# If MySQLdb isn't available this module won't actually be useable,
# but we want it to at least be importable (mainly for readthedocs.org,
# which has limitations on third-party modules)
MySQLdb = None
class Connection(object):
"""A lightweight wrapper around MySQLdb DB-API connections.
@@ -41,7 +51,7 @@ class Connection(object):
UTF-8 on all connections to avoid time zone and encoding errors.
"""
def __init__(self, host, database, user=None, password=None,
max_idle_time=7*3600):
max_idle_time=7 * 3600):
self.host = host
self.database = database
self.max_idle_time = max_idle_time
@@ -210,20 +220,19 @@ class Row(dict):
except KeyError:
raise AttributeError(name)
if MySQLdb is not None:
# Fix the access conversions to properly recognize unicode/binary
FIELD_TYPE = MySQLdb.constants.FIELD_TYPE
FLAG = MySQLdb.constants.FLAG
CONVERSIONS = copy.copy(MySQLdb.converters.conversions)
# Fix the access conversions to properly recognize unicode/binary
FIELD_TYPE = MySQLdb.constants.FIELD_TYPE
FLAG = MySQLdb.constants.FLAG
CONVERSIONS = copy.copy(MySQLdb.converters.conversions)
field_types = [FIELD_TYPE.BLOB, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING]
if 'VARCHAR' in vars(FIELD_TYPE):
field_types.append(FIELD_TYPE.VARCHAR)
field_types = [FIELD_TYPE.BLOB, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING]
if 'VARCHAR' in vars(FIELD_TYPE):
field_types.append(FIELD_TYPE.VARCHAR)
for field_type in field_types:
CONVERSIONS[field_type] = [(FLAG.BINARY, str)] + CONVERSIONS[field_type]
for field_type in field_types:
CONVERSIONS[field_type] = [(FLAG.BINARY, str)] + CONVERSIONS[field_type]
# Alias some common MySQL exceptions
IntegrityError = MySQLdb.IntegrityError
OperationalError = MySQLdb.OperationalError
# Alias some common MySQL exceptions
IntegrityError = MySQLdb.IntegrityError
OperationalError = MySQLdb.OperationalError
Regular → Executable
View File
Regular → Executable
+34 -10
View File
@@ -20,14 +20,18 @@ Also includes a few other miscellaneous string manipulation functions that
have crept in over time.
"""
from __future__ import absolute_import, division, with_statement
import htmlentitydefs
import re
import sys
import urllib
# Python3 compatibility: On python2.5, introduce the bytes alias from 2.6
try: bytes
except Exception: bytes = str
try:
bytes
except Exception:
bytes = str
try:
from urlparse import parse_qs # Python 2.6+
@@ -62,6 +66,8 @@ except Exception:
_XHTML_ESCAPE_RE = re.compile('[&<>"]')
_XHTML_ESCAPE_DICT = {'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;'}
def xhtml_escape(value):
"""Escapes a string so it is valid within XML or XHTML."""
return _XHTML_ESCAPE_RE.sub(lambda match: _XHTML_ESCAPE_DICT[match.group(0)],
@@ -143,13 +149,14 @@ else:
result = parse_qs(qs, keep_blank_values, strict_parsing,
encoding='latin1', errors='strict')
encoded = {}
for k,v in result.iteritems():
for k, v in result.iteritems():
encoded[k] = [i.encode('latin1') for i in v]
return encoded
_UTF8_TYPES = (bytes, type(None))
def utf8(value):
"""Converts a string argument to a byte string.
@@ -162,6 +169,8 @@ def utf8(value):
return value.encode("utf-8")
_TO_UNICODE_TYPES = (unicode, type(None))
def to_unicode(value):
"""Converts a string argument to a unicode string.
@@ -185,6 +194,8 @@ else:
native_str = utf8
_BASESTRING_TYPES = (basestring, type(None))
def to_basestring(value):
"""Converts a string argument to a subclass of basestring.
@@ -199,13 +210,14 @@ def to_basestring(value):
assert isinstance(value, bytes)
return value.decode("utf-8")
def recursive_unicode(obj):
"""Walks a simple data structure, converting byte strings to unicode.
Supports lists, tuples, and dictionaries.
"""
if isinstance(obj, dict):
return dict((recursive_unicode(k), recursive_unicode(v)) for (k,v) in obj.iteritems())
return dict((recursive_unicode(k), recursive_unicode(v)) for (k, v) in obj.iteritems())
elif isinstance(obj, list):
return list(recursive_unicode(i) for i in obj)
elif isinstance(obj, tuple):
@@ -215,7 +227,7 @@ def recursive_unicode(obj):
else:
return obj
# I originally used the regex from
# I originally used the regex from
# http://daringfireball.net/2010/07/improved_regex_for_matching_urls
# but it gets all exponential on certain patterns (such as too many trailing
# dots), causing the regex matcher to never return.
@@ -234,8 +246,17 @@ def linkify(text, shorten=False, extra_params="",
shorten: Long urls will be shortened for display.
extra_params: Extra text to include in the link tag,
e.g. linkify(text, extra_params='rel="nofollow" class="external"')
extra_params: Extra text to include in the link tag, or a callable
taking the link as an argument and returning the extra text
e.g. ``linkify(text, extra_params='rel="nofollow" class="external"')``,
or::
def extra_params_cb(url):
if url.startswith("http://example.com"):
return 'class="internal"'
else:
return 'class="external" rel="nofollow"'
linkify(text, extra_params=extra_params_cb)
require_protocol: Only linkify urls which include a protocol. If this is
False, urls such as www.facebook.com will also be linkified.
@@ -244,7 +265,7 @@ def linkify(text, shorten=False, extra_params="",
e.g. linkify(text, permitted_protocols=["http", "ftp", "mailto"]).
It is very unsafe to include protocols such as "javascript".
"""
if extra_params:
if extra_params and not callable(extra_params):
extra_params = " " + extra_params.strip()
def make_link(m):
@@ -260,7 +281,10 @@ def linkify(text, shorten=False, extra_params="",
if not proto:
href = "http://" + href # no proto specified, use http
params = extra_params
if callable(extra_params):
params = " " + extra_params(href).strip()
else:
params = extra_params
# clip long urls. max_len is just an approximation
max_len = 30
Regular → Executable
+43 -16
View File
@@ -62,7 +62,7 @@ it was called with one argument, the result is that argument. If it was
called with more than one argument or any keyword arguments, the result
is an `Arguments` object, which is a named tuple ``(args, kwargs)``.
"""
from __future__ import with_statement
from __future__ import absolute_import, division, with_statement
import functools
import operator
@@ -71,10 +71,22 @@ import types
from tornado.stack_context import ExceptionStackContext
class KeyReuseError(Exception): pass
class UnknownKeyError(Exception): pass
class LeakedCallbackError(Exception): pass
class BadYieldError(Exception): pass
class KeyReuseError(Exception):
pass
class UnknownKeyError(Exception):
pass
class LeakedCallbackError(Exception):
pass
class BadYieldError(Exception):
pass
def engine(func):
"""Decorator for asynchronous generators.
@@ -92,6 +104,7 @@ def engine(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
runner = None
def handle_exception(typ, value, tb):
# if the function throws an exception before its first "yield"
# (or is not a generator at all), the Runner won't exist yet.
@@ -100,21 +113,23 @@ def engine(func):
if runner is not None:
return runner.handle_exception(typ, value, tb)
return False
with ExceptionStackContext(handle_exception):
with ExceptionStackContext(handle_exception) as deactivate:
gen = func(*args, **kwargs)
if isinstance(gen, types.GeneratorType):
runner = Runner(gen)
runner = Runner(gen, deactivate)
runner.run()
return
assert gen is None, gen
deactivate()
# no yield, so we're done
return wrapper
class YieldPoint(object):
"""Base class for objects that may be yielded from the generator."""
def start(self, runner):
"""Called by the runner after the generator has yielded.
No other methods will be called on this object before ``start``.
"""
raise NotImplementedError()
@@ -128,12 +143,13 @@ class YieldPoint(object):
def get_result(self):
"""Returns the value to use as the result of the yield expression.
This method will only be called once, and only after `is_ready`
has returned true.
"""
raise NotImplementedError()
class Callback(YieldPoint):
"""Returns a callable object that will allow a matching `Wait` to proceed.
@@ -159,6 +175,7 @@ class Callback(YieldPoint):
def get_result(self):
return self.runner.result_callback(self.key)
class Wait(YieldPoint):
"""Returns the argument passed to the result of a previous `Callback`."""
def __init__(self, key):
@@ -173,6 +190,7 @@ class Wait(YieldPoint):
def get_result(self):
return self.runner.pop_result(self.key)
class WaitAll(YieldPoint):
"""Returns the results of multiple previous `Callbacks`.
@@ -189,10 +207,10 @@ class WaitAll(YieldPoint):
def is_ready(self):
return all(self.runner.is_ready(key) for key in self.keys)
def get_result(self):
return [self.runner.pop_result(key) for key in self.keys]
class Task(YieldPoint):
"""Runs a single asynchronous operation.
@@ -203,9 +221,9 @@ class Task(YieldPoint):
A `Task` is equivalent to a `Callback`/`Wait` pair (with a unique
key generated automatically)::
result = yield gen.Task(func, args)
func(args, callback=(yield gen.Callback(key)))
result = yield gen.Wait(key)
"""
@@ -221,13 +239,14 @@ class Task(YieldPoint):
runner.register_callback(self.key)
self.kwargs["callback"] = runner.result_callback(self.key)
self.func(*self.args, **self.kwargs)
def is_ready(self):
return self.runner.is_ready(self.key)
def get_result(self):
return self.runner.pop_result(self.key)
class Multi(YieldPoint):
"""Runs multiple asynchronous operations in parallel.
@@ -239,7 +258,7 @@ class Multi(YieldPoint):
def __init__(self, children):
assert all(isinstance(i, YieldPoint) for i in children)
self.children = children
def start(self, runner):
for i in self.children:
i.start(runner)
@@ -250,21 +269,26 @@ class Multi(YieldPoint):
def get_result(self):
return [i.get_result() for i in self.children]
class _NullYieldPoint(YieldPoint):
def start(self, runner):
pass
def is_ready(self):
return True
def get_result(self):
return None
class Runner(object):
"""Internal implementation of `tornado.gen.engine`.
Maintains information about pending callbacks and their results.
"""
def __init__(self, gen):
def __init__(self, gen, deactivate_stack_context):
self.gen = gen
self.deactivate_stack_context = deactivate_stack_context
self.yield_point = _NullYieldPoint()
self.pending_callbacks = set()
self.results = {}
@@ -329,6 +353,7 @@ class Runner(object):
raise LeakedCallbackError(
"finished without waiting for callbacks %r" %
self.pending_callbacks)
self.deactivate_stack_context()
return
except Exception:
self.finished = True
@@ -366,6 +391,8 @@ class Runner(object):
return False
# in python 2.6+ this could be a collections.namedtuple
class Arguments(tuple):
"""The result of a yield expression whose callback had more than one
argument (or keyword arguments).
Regular → Executable
+42 -17
View File
@@ -29,6 +29,8 @@ you use a recent version of ``libcurl`` and ``pycurl``. Currently the minimum
supported version is 7.18.2, and the recommended version is 7.21.1 or newer.
"""
from __future__ import absolute_import, division, with_statement
import calendar
import email.utils
import httplib
@@ -40,6 +42,7 @@ from tornado import httputil
from tornado.ioloop import IOLoop
from tornado.util import import_object, bytes_type
class HTTPClient(object):
"""A blocking HTTP client.
@@ -54,11 +57,11 @@ class HTTPClient(object):
except httpclient.HTTPError, e:
print "Error:", e
"""
def __init__(self, async_client_class=None):
def __init__(self, async_client_class=None, **kwargs):
self._io_loop = IOLoop()
if async_client_class is None:
async_client_class = AsyncHTTPClient
self._async_client = async_client_class(self._io_loop)
self._async_client = async_client_class(self._io_loop, **kwargs)
self._response = None
self._closed = False
@@ -74,7 +77,7 @@ class HTTPClient(object):
def fetch(self, request, **kwargs):
"""Executes a request, returning an `HTTPResponse`.
The request may be either a string URL or an `HTTPRequest` object.
If it is a string, we construct an `HTTPRequest` using any additional
kwargs: ``HTTPRequest(request, **kwargs)``
@@ -91,6 +94,7 @@ class HTTPClient(object):
response.rethrow()
return response
class AsyncHTTPClient(object):
"""An non-blocking HTTP client.
@@ -120,6 +124,8 @@ class AsyncHTTPClient(object):
_impl_class = None
_impl_kwargs = None
_DEFAULT_MAX_CLIENTS = 10
@classmethod
def _async_clients(cls):
assert cls is not AsyncHTTPClient, "should only be called on subclasses"
@@ -127,7 +133,7 @@ class AsyncHTTPClient(object):
cls._async_client_dict = weakref.WeakKeyDictionary()
return cls._async_client_dict
def __new__(cls, io_loop=None, max_clients=10, force_instance=False,
def __new__(cls, io_loop=None, max_clients=None, force_instance=False,
**kwargs):
io_loop = io_loop or IOLoop.instance()
if cls is AsyncHTTPClient:
@@ -145,7 +151,13 @@ class AsyncHTTPClient(object):
if cls._impl_kwargs:
args.update(cls._impl_kwargs)
args.update(kwargs)
instance.initialize(io_loop, max_clients, **args)
if max_clients is not None:
# max_clients is special because it may be passed
# positionally instead of by keyword
args["max_clients"] = max_clients
elif "max_clients" not in args:
args["max_clients"] = AsyncHTTPClient._DEFAULT_MAX_CLIENTS
instance.initialize(io_loop, **args)
if not force_instance:
impl._async_clients()[io_loop] = instance
return instance
@@ -200,6 +212,16 @@ class AsyncHTTPClient(object):
AsyncHTTPClient._impl_class = impl
AsyncHTTPClient._impl_kwargs = kwargs
@staticmethod
def _save_configuration():
return (AsyncHTTPClient._impl_class, AsyncHTTPClient._impl_kwargs)
@staticmethod
def _restore_configuration(saved):
AsyncHTTPClient._impl_class = saved[0]
AsyncHTTPClient._impl_kwargs = saved[1]
class HTTPRequest(object):
"""HTTP client request object."""
def __init__(self, url, method="GET", headers=None, body=None,
@@ -235,23 +257,23 @@ class HTTPRequest(object):
:arg bool use_gzip: Request gzip encoding from the server
:arg string network_interface: Network interface to use for request
:arg callable streaming_callback: If set, `streaming_callback` will
be run with each chunk of data as it is received, and
`~HTTPResponse.body` and `~HTTPResponse.buffer` will be empty in
be run with each chunk of data as it is received, and
`~HTTPResponse.body` and `~HTTPResponse.buffer` will be empty in
the final response.
:arg callable header_callback: If set, `header_callback` will
be run with each header line as it is received, and
be run with each header line as it is received, and
`~HTTPResponse.headers` will be empty in the final response.
:arg callable prepare_curl_callback: If set, will be called with
a `pycurl.Curl` object to allow the application to make additional
`setopt` calls.
:arg string proxy_host: HTTP proxy hostname. To use proxies,
`proxy_host` and `proxy_port` must be set; `proxy_username` and
`proxy_pass` are optional. Proxies are currently only support
:arg string proxy_host: HTTP proxy hostname. To use proxies,
`proxy_host` and `proxy_port` must be set; `proxy_username` and
`proxy_pass` are optional. Proxies are currently only support
with `curl_httpclient`.
:arg int proxy_port: HTTP proxy port
:arg string proxy_username: HTTP proxy username
:arg string proxy_password: HTTP proxy password
:arg bool allow_nonstandard_methods: Allow unknown values for `method`
:arg bool allow_nonstandard_methods: Allow unknown values for `method`
argument?
:arg bool validate_cert: For HTTPS requests, validate the server's
certificate?
@@ -260,7 +282,7 @@ class HTTPRequest(object):
any request uses a custom `ca_certs` file, they all must (they
don't have to all use the same `ca_certs`, but it's not possible
to mix requests with ca_certs and requests that use the defaults.
:arg bool allow_ipv6: Use IPv6 when available? Default is false in
:arg bool allow_ipv6: Use IPv6 when available? Default is false in
`simple_httpclient` and true in `curl_httpclient`
:arg string client_key: Filename for client SSL key, if any
:arg string client_cert: Filename for client SSL certificate, if any
@@ -325,12 +347,15 @@ class HTTPResponse(object):
plus 'queue', which is the delay (if any) introduced by waiting for
a slot under AsyncHTTPClient's max_clients setting.
"""
def __init__(self, request, code, headers={}, buffer=None,
def __init__(self, request, code, headers=None, buffer=None,
effective_url=None, error=None, request_time=None,
time_info={}):
time_info=None):
self.request = request
self.code = code
self.headers = headers
if headers is not None:
self.headers = headers
else:
self.headers = httputil.HTTPHeaders()
self.buffer = buffer
self._body = None
if effective_url is None:
@@ -345,7 +370,7 @@ class HTTPResponse(object):
else:
self.error = error
self.request_time = request_time
self.time_info = time_info
self.time_info = time_info or {}
def _get_body(self):
if self.buffer is None:
Regular → Executable
+32 -39
View File
@@ -24,13 +24,14 @@ This module also defines the `HTTPRequest` class which is exposed via
`tornado.web.RequestHandler.request`.
"""
from __future__ import absolute_import, division, with_statement
import Cookie
import logging
import socket
import time
import urlparse
from tornado.escape import utf8, native_str, parse_qs_bytes
from tornado.escape import native_str, parse_qs_bytes
from tornado import httputil
from tornado import iostream
from tornado.netutil import TCPServer
@@ -38,10 +39,11 @@ from tornado import stack_context
from tornado.util import b, bytes_type
try:
import ssl # Python 2.6+
import ssl # Python 2.6+
except ImportError:
ssl = None
class HTTPServer(TCPServer):
r"""A non-blocking, single-threaded HTTP server.
@@ -103,7 +105,7 @@ class HTTPServer(TCPServer):
In many cases, `tornado.web.Application.listen` can be used to avoid
the need to explicitly create the `HTTPServer`.
2. `~tornado.netutil.TCPServer.bind`/`~tornado.netutil.TCPServer.start`:
2. `~tornado.netutil.TCPServer.bind`/`~tornado.netutil.TCPServer.start`:
simple multi-process::
server = HTTPServer(app)
@@ -143,10 +145,12 @@ class HTTPServer(TCPServer):
HTTPConnection(stream, address, self.request_callback,
self.no_keep_alive, self.xheaders)
class _BadRequestException(Exception):
"""Exception class for malformed HTTP requests."""
pass
class HTTPConnection(object):
"""Handles a connection to an HTTP client, executing HTTP requests.
@@ -156,9 +160,6 @@ class HTTPConnection(object):
def __init__(self, stream, address, request_callback, no_keep_alive=False,
xheaders=False):
self.stream = stream
if self.stream.socket.family not in (socket.AF_INET, socket.AF_INET6):
# Unix (or other) socket; fake the remote address
address = ('0.0.0.0', 0)
self.address = address
self.request_callback = request_callback
self.no_keep_alive = no_keep_alive
@@ -189,7 +190,7 @@ class HTTPConnection(object):
if self._write_callback is not None:
callback = self._write_callback
self._write_callback = None
callback()
callback()
# _on_write_complete is enqueued on the IOLoop whenever the
# IOStream's write buffer becomes empty, but it's possible for
# another callback that runs on the IOLoop before it to
@@ -233,9 +234,20 @@ class HTTPConnection(object):
if not version.startswith("HTTP/"):
raise _BadRequestException("Malformed HTTP version in HTTP Request-Line")
headers = httputil.HTTPHeaders.parse(data[eol:])
# HTTPRequest wants an IP, not a full socket address
if getattr(self.stream.socket, 'family', socket.AF_INET) in (
socket.AF_INET, socket.AF_INET6):
# Jython 2.5.2 doesn't have the socket.family attribute,
# so just assume IP in that case.
remote_ip = self.address[0]
else:
# Unix (or other) socket; fake the remote address
remote_ip = '0.0.0.0'
self._request = HTTPRequest(
connection=self, method=method, uri=uri, version=version,
headers=headers, remote_ip=self.address[0])
headers=headers, remote_ip=remote_ip)
content_length = headers.get("Content-Length")
if content_length:
@@ -256,27 +268,10 @@ class HTTPConnection(object):
def _on_request_body(self, data):
self._request.body = data
content_type = self._request.headers.get("Content-Type", "")
if self._request.method in ("POST", "PUT"):
if content_type.startswith("application/x-www-form-urlencoded"):
arguments = parse_qs_bytes(native_str(self._request.body))
for name, values in arguments.iteritems():
values = [v for v in values if v]
if values:
self._request.arguments.setdefault(name, []).extend(
values)
elif content_type.startswith("multipart/form-data"):
fields = content_type.split(";")
for field in fields:
k, sep, v = field.strip().partition("=")
if k == "boundary" and v:
httputil.parse_multipart_form_data(
utf8(v), data,
self._request.arguments,
self._request.files)
break
else:
logging.warning("Invalid multipart/form-data")
if self._request.method in ("POST", "PATCH", "PUT"):
httputil.parse_body_arguments(
self._request.headers.get("Content-Type", ""), data,
self._request.arguments, self._request.files)
self.request_callback(self._request)
@@ -336,8 +331,8 @@ class HTTPRequest(object):
GET/POST arguments are available in the arguments property, which
maps arguments names to lists of values (to support multiple values
for individual names). Names are of type `str`, while arguments
are byte strings. Note that this is different from
`RequestHandler.get_argument`, which returns argument values as
are byte strings. Note that this is different from
`RequestHandler.get_argument`, which returns argument values as
unicode strings.
.. attribute:: files
@@ -375,7 +370,7 @@ class HTTPRequest(object):
self.remote_ip = remote_ip
if protocol:
self.protocol = protocol
elif connection and isinstance(connection.stream,
elif connection and isinstance(connection.stream,
iostream.SSLIOStream):
self.protocol = "https"
else:
@@ -386,14 +381,13 @@ class HTTPRequest(object):
self._start_time = time.time()
self._finish_time = None
scheme, netloc, path, query, fragment = urlparse.urlsplit(native_str(uri))
self.path = path
self.query = query
arguments = parse_qs_bytes(query)
self.path, sep, self.query = uri.partition('?')
arguments = parse_qs_bytes(self.query)
self.arguments = {}
for name, values in arguments.iteritems():
values = [v for v in values if v]
if values: self.arguments[name] = values
if values:
self.arguments[name] = values
def supports_http_1_1(self):
"""Returns True if this request supports HTTP/1.1 semantics"""
@@ -473,4 +467,3 @@ class HTTPRequest(object):
return False
raise
return True
Regular → Executable
+45 -13
View File
@@ -16,12 +16,16 @@
"""HTTP utility code shared by clients and servers."""
from __future__ import absolute_import, division, with_statement
import logging
import urllib
import re
from tornado.escape import native_str, parse_qs_bytes, utf8
from tornado.util import b, ObjectDict
class HTTPHeaders(dict):
"""A dictionary that maintains Http-Header-Case for all keys.
@@ -55,7 +59,14 @@ class HTTPHeaders(dict):
dict.__init__(self)
self._as_list = {}
self._last_key = None
self.update(*args, **kwargs)
if (len(args) == 1 and len(kwargs) == 0 and
isinstance(args[0], HTTPHeaders)):
# Copy constructor
for k, v in args[0].get_all():
self.add(k, v)
else:
# Dict-style initialization
self.update(*args, **kwargs)
# new public methods
@@ -144,6 +155,10 @@ class HTTPHeaders(dict):
for k, v in dict(*args, **kwargs).iteritems():
self[k] = v
def copy(self):
# default implementation returns dict(self), not the subclass
return HTTPHeaders(self)
_NORMALIZED_HEADER_RE = re.compile(r'^[A-Z0-9][a-z0-9]*(-[A-Z0-9][a-z0-9]*)*$')
_normalized_headers = {}
@@ -172,7 +187,8 @@ def url_concat(url, args):
>>> url_concat("http://example.com/foo?a=b", dict(c="d"))
'http://example.com/foo?a=b&c=d'
"""
if not args: return url
if not args:
return url
if url[-1] not in ('?', '&'):
url += '&' if ('?' in url) else '?'
return url + urllib.urlencode(args)
@@ -190,6 +206,24 @@ class HTTPFile(ObjectDict):
pass
def parse_body_arguments(content_type, body, arguments, files):
if content_type.startswith("application/x-www-form-urlencoded"):
uri_arguments = parse_qs_bytes(native_str(body))
for name, values in uri_arguments.iteritems():
values = [v for v in values if v]
if values:
arguments.setdefault(name, []).extend(values)
elif content_type.startswith("multipart/form-data"):
fields = content_type.split(";")
for field in fields:
k, sep, v = field.strip().partition("=")
if k == "boundary" and v:
parse_multipart_form_data(utf8(v), body, arguments, files)
break
else:
logging.warning("Invalid multipart/form-data")
def parse_multipart_form_data(boundary, data, arguments, files):
"""Parses a multipart/form-data body.
@@ -204,13 +238,14 @@ def parse_multipart_form_data(boundary, data, arguments, files):
# in the wild.
if boundary.startswith(b('"')) and boundary.endswith(b('"')):
boundary = boundary[1:-1]
if data.endswith(b("\r\n")):
footer_length = len(boundary) + 6
else:
footer_length = len(boundary) + 4
parts = data[:-footer_length].split(b("--") + boundary + b("\r\n"))
final_boundary_index = data.rfind(b("--") + boundary + b("--"))
if final_boundary_index == -1:
logging.warning("Invalid multipart/form-data: no final boundary")
return
parts = data[:final_boundary_index].split(b("--") + boundary + b("\r\n"))
for part in parts:
if not part: continue
if not part:
continue
eoh = part.find(b("\r\n\r\n"))
if eoh == -1:
logging.warning("multipart/form-data missing headers")
@@ -250,6 +285,7 @@ def _parseparam(s):
yield f.strip()
s = s[end:]
def _parse_header(line):
"""Parse a Content-type like header.
@@ -263,7 +299,7 @@ def _parse_header(line):
i = p.find('=')
if i >= 0:
name = p[:i].strip().lower()
value = p[i+1:].strip()
value = p[i + 1:].strip()
if len(value) >= 2 and value[0] == value[-1] == '"':
value = value[1:-1]
value = value.replace('\\\\', '\\').replace('\\"', '"')
@@ -274,7 +310,3 @@ def _parse_header(line):
def doctests():
import doctest
return doctest.DocTestSuite()
if __name__ == "__main__":
import doctest
doctest.testmod()
Regular → Executable
+37 -8
View File
@@ -26,7 +26,7 @@ In addition to I/O events, the `IOLoop` can also schedule time-based events.
`IOLoop.add_timeout` is a non-blocking alternative to `time.sleep`.
"""
from __future__ import with_statement
from __future__ import absolute_import, division, with_statement
import datetime
import errno
@@ -104,6 +104,9 @@ class IOLoop(object):
WRITE = _EPOLLOUT
ERROR = _EPOLLERR | _EPOLLHUP
# Global lock for creating global IOLoop instance
_instance_lock = threading.Lock()
def __init__(self, impl=None):
self._impl = impl or _poll()
if hasattr(self._impl, 'fileno'):
@@ -142,7 +145,10 @@ class IOLoop(object):
self.io_loop = io_loop or IOLoop.instance()
"""
if not hasattr(IOLoop, "_instance"):
IOLoop._instance = IOLoop()
with IOLoop._instance_lock:
if not hasattr(IOLoop, "_instance"):
# New instance after double check
IOLoop._instance = IOLoop()
return IOLoop._instance
@staticmethod
@@ -164,7 +170,20 @@ class IOLoop(object):
"""Closes the IOLoop, freeing any resources used.
If ``all_fds`` is true, all file descriptors registered on the
IOLoop will be closed (not just the ones created by the IOLoop itself.
IOLoop will be closed (not just the ones created by the IOLoop itself).
Many applications will only use a single IOLoop that runs for the
entire lifetime of the process. In that case closing the IOLoop
is not necessary since everything will be cleaned up when the
process exits. `IOLoop.close` is provided mainly for scenarios
such as unit tests, which create and destroy a large number of
IOLoops.
An IOLoop must be completely stopped before it can be closed. This
means that `IOLoop.stop()` must be called *and* `IOLoop.start()` must
be allowed to return before attempting to call `IOLoop.close()`.
Therefore the call to `close` will usually appear just after
the call to `start` rather than near the call to `stop`.
"""
self.remove_handler(self._waker.fileno())
if all_fds:
@@ -335,6 +354,9 @@ class IOLoop(object):
ioloop.start() will return after async_method has run its callback,
whether that callback was invoked before or after ioloop.start.
Note that even after `stop` has been called, the IOLoop is not
completely stopped until `IOLoop.start` has also returned.
"""
self._running = False
self._stopped = True
@@ -431,7 +453,7 @@ class _Timeout(object):
@staticmethod
def timedelta_to_seconds(td):
"""Equivalent to td.total_seconds() (introduced in python 2.7)."""
return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / float(10**6)
return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / float(10 ** 6)
# Comparison methods to sort by deadline, with object id as a tiebreaker
# to guarantee a consistent ordering. The heapq module uses __le__
@@ -474,7 +496,8 @@ class PeriodicCallback(object):
self._timeout = None
def _run(self):
if not self._running: return
if not self._running:
return
try:
self.callback()
except Exception:
@@ -530,6 +553,8 @@ class _KQueue(object):
self._kqueue.close()
def register(self, fd, events):
if fd in self._active:
raise IOError("fd %d already registered" % fd)
self._control(fd, events, select.KQ_EV_ADD)
self._active[fd] = events
@@ -591,8 +616,12 @@ class _Select(object):
pass
def register(self, fd, events):
if events & IOLoop.READ: self.read_fds.add(fd)
if events & IOLoop.WRITE: self.write_fds.add(fd)
if fd in self.read_fds or fd in self.write_fds or fd in self.error_fds:
raise IOError("fd %d already registered" % fd)
if events & IOLoop.READ:
self.read_fds.add(fd)
if events & IOLoop.WRITE:
self.write_fds.add(fd)
if events & IOLoop.ERROR:
self.error_fds.add(fd)
# Closed connections are reported as errors by epoll and kqueue,
@@ -633,7 +662,7 @@ elif hasattr(select, "kqueue"):
else:
try:
# Linux systems with our C module installed
import epoll
from tornado import epoll
_poll = _EPoll
except Exception:
# All other systems
Regular → Executable
+156 -120
View File
@@ -16,11 +16,12 @@
"""A utility class to write to and read from a non-blocking socket."""
from __future__ import with_statement
from __future__ import absolute_import, division, with_statement
import collections
import errno
import logging
import os
import socket
import sys
import re
@@ -30,16 +31,17 @@ from tornado import stack_context
from tornado.util import b, bytes_type
try:
import ssl # Python 2.6+
import ssl # Python 2.6+
except ImportError:
ssl = None
class IOStream(object):
r"""A utility class to write to and read from a non-blocking socket.
We support a non-blocking ``write()`` and a family of ``read_*()`` methods.
All of the methods take callbacks (since writing and reading are
non-blocking and asynchronous).
non-blocking and asynchronous).
The socket parameter may either be connected or unconnected. For
server operations the socket is the result of calling socket.accept().
@@ -47,6 +49,9 @@ class IOStream(object):
and may either be connected before passing it to the IOStream or
connected with IOStream.connect.
When a stream is closed due to an error, the IOStream's `error`
attribute contains the exception object.
A very simple (and broken) HTTP client using this class::
from tornado import ioloop
@@ -83,6 +88,7 @@ class IOStream(object):
self.io_loop = io_loop or ioloop.IOLoop.instance()
self.max_buffer_size = max_buffer_size
self.read_chunk_size = read_chunk_size
self.error = None
self._read_buffer = collections.deque()
self._write_buffer = collections.deque()
self._read_buffer_size = 0
@@ -136,31 +142,15 @@ class IOStream(object):
def read_until_regex(self, regex, callback):
"""Call callback when we read the given regex pattern."""
assert not self._read_callback, "Already reading"
self._set_read_callback(callback)
self._read_regex = re.compile(regex)
self._read_callback = stack_context.wrap(callback)
while True:
# See if we've already got the data from a previous read
if self._read_from_buffer():
return
self._check_closed()
if self._read_to_buffer() == 0:
break
self._add_io_state(self.io_loop.READ)
self._try_inline_read()
def read_until(self, delimiter, callback):
"""Call callback when we read the given delimiter."""
assert not self._read_callback, "Already reading"
self._set_read_callback(callback)
self._read_delimiter = delimiter
self._read_callback = stack_context.wrap(callback)
while True:
# See if we've already got the data from a previous read
if self._read_from_buffer():
return
self._check_closed()
if self._read_to_buffer() == 0:
break
self._add_io_state(self.io_loop.READ)
self._try_inline_read()
def read_bytes(self, num_bytes, callback, streaming_callback=None):
"""Call callback when we read the given number of bytes.
@@ -169,18 +159,11 @@ class IOStream(object):
of data as they become available, and the argument to the final
``callback`` will be empty.
"""
assert not self._read_callback, "Already reading"
self._set_read_callback(callback)
assert isinstance(num_bytes, (int, long))
self._read_bytes = num_bytes
self._read_callback = stack_context.wrap(callback)
self._streaming_callback = stack_context.wrap(streaming_callback)
while True:
if self._read_from_buffer():
return
self._check_closed()
if self._read_to_buffer() == 0:
break
self._add_io_state(self.io_loop.READ)
self._try_inline_read()
def read_until_close(self, callback, streaming_callback=None):
"""Reads all data from the socket until it is closed.
@@ -192,12 +175,12 @@ class IOStream(object):
Subject to ``max_buffer_size`` limit from `IOStream` constructor if
a ``streaming_callback`` is not used.
"""
assert not self._read_callback, "Already reading"
self._set_read_callback(callback)
if self.closed():
self._run_callback(callback, self._consume(self._read_buffer_size))
self._read_callback = None
return
self._read_until_close = True
self._read_callback = stack_context.wrap(callback)
self._streaming_callback = stack_context.wrap(streaming_callback)
self._add_io_state(self.io_loop.READ)
@@ -211,10 +194,18 @@ class IOStream(object):
"""
assert isinstance(data, bytes_type)
self._check_closed()
# We use bool(_write_buffer) as a proxy for write_buffer_size>0,
# so never put empty strings in the buffer.
if data:
# We use bool(_write_buffer) as a proxy for write_buffer_size>0,
# so never put empty strings in the buffer.
self._write_buffer.append(data)
# Break up large contiguous strings before inserting them in the
# write buffer, so we don't have to recopy the entire thing
# as we slice off pieces to send to the socket.
WRITE_BUFFER_CHUNK_SIZE = 128 * 1024
if len(data) > WRITE_BUFFER_CHUNK_SIZE:
for i in range(0, len(data), WRITE_BUFFER_CHUNK_SIZE):
self._write_buffer.append(data[i:i + WRITE_BUFFER_CHUNK_SIZE])
else:
self._write_buffer.append(data)
self._write_callback = stack_context.wrap(callback)
self._handle_write()
if self._write_buffer:
@@ -228,6 +219,8 @@ class IOStream(object):
def close(self):
"""Close this stream."""
if self.socket is not None:
if any(sys.exc_info()):
self.error = sys.exc_info()[1]
if self._read_until_close:
callback = self._read_callback
self._read_callback = None
@@ -239,12 +232,16 @@ class IOStream(object):
self._state = None
self.socket.close()
self.socket = None
if self._close_callback and self._pending_callbacks == 0:
# if there are pending callbacks, don't run the close callback
# until they're done (see _maybe_add_error_handler)
cb = self._close_callback
self._close_callback = None
self._run_callback(cb)
self._maybe_run_close_callback()
def _maybe_run_close_callback(self):
if (self.socket is None and self._close_callback and
self._pending_callbacks == 0):
# if there are pending callbacks, don't run the close callback
# until they're done (see _maybe_add_error_handler)
cb = self._close_callback
self._close_callback = None
self._run_callback(cb)
def reading(self):
"""Returns true if we are currently reading from the stream."""
@@ -274,6 +271,9 @@ class IOStream(object):
if not self.socket:
return
if events & self.io_loop.ERROR:
errno = self.socket.getsockopt(socket.SOL_SOCKET,
socket.SO_ERROR)
self.error = socket.error(errno, os.strerror(errno))
# We may have queued up a user callback in _handle_read or
# _handle_write, so don't close the IOStream until those
# callbacks have had a chance to run.
@@ -332,22 +332,65 @@ class IOStream(object):
self.io_loop.add_callback(wrapper)
def _handle_read(self):
while True:
try:
try:
# Read from the socket until we get EWOULDBLOCK or equivalent.
# SSL sockets do some internal buffering, and if the data is
# sitting in the SSL object's buffer select() and friends
# can't see it; the only way to find out if it's there is to
# try to read it.
result = self._read_to_buffer()
except Exception:
self.close()
return
if result == 0:
break
else:
if self._read_from_buffer():
return
# Pretend to have a pending callback so that an EOF in
# _read_to_buffer doesn't trigger an immediate close
# callback. At the end of this method we'll either
# estabilsh a real pending callback via
# _read_from_buffer or run the close callback.
#
# We need two try statements here so that
# pending_callbacks is decremented before the `except`
# clause below (which calls `close` and does need to
# trigger the callback)
self._pending_callbacks += 1
while True:
# Read from the socket until we get EWOULDBLOCK or equivalent.
# SSL sockets do some internal buffering, and if the data is
# sitting in the SSL object's buffer select() and friends
# can't see it; the only way to find out if it's there is to
# try to read it.
if self._read_to_buffer() == 0:
break
finally:
self._pending_callbacks -= 1
except Exception:
logging.warning("error on read", exc_info=True)
self.close()
return
if self._read_from_buffer():
return
else:
self._maybe_run_close_callback()
def _set_read_callback(self, callback):
assert not self._read_callback, "Already reading"
self._read_callback = stack_context.wrap(callback)
def _try_inline_read(self):
"""Attempt to complete the current read operation from buffered data.
If the read can be completed without blocking, schedules the
read callback on the next IOLoop iteration; otherwise starts
listening for reads on the socket.
"""
# See if we've already got the data from a previous read
if self._read_from_buffer():
return
self._check_closed()
try:
# See comments in _handle_read about incrementing _pending_callbacks
self._pending_callbacks += 1
while True:
if self._read_to_buffer() == 0:
break
self._check_closed()
finally:
self._pending_callbacks -= 1
if self._read_from_buffer():
return
self._add_io_state(self.io_loop.READ)
def _read_from_socket(self):
"""Attempts to read from the socket.
@@ -397,20 +440,21 @@ class IOStream(object):
Returns True if the read was completed.
"""
if self._read_bytes is not None:
if self._streaming_callback is not None and self._read_buffer_size:
bytes_to_consume = min(self._read_bytes, self._read_buffer_size)
if self._streaming_callback is not None and self._read_buffer_size:
bytes_to_consume = self._read_buffer_size
if self._read_bytes is not None:
bytes_to_consume = min(self._read_bytes, bytes_to_consume)
self._read_bytes -= bytes_to_consume
self._run_callback(self._streaming_callback,
self._consume(bytes_to_consume))
if self._read_buffer_size >= self._read_bytes:
num_bytes = self._read_bytes
callback = self._read_callback
self._read_callback = None
self._streaming_callback = None
self._read_bytes = None
self._run_callback(callback, self._consume(num_bytes))
return True
self._run_callback(self._streaming_callback,
self._consume(bytes_to_consume))
if self._read_bytes is not None and self._read_buffer_size >= self._read_bytes:
num_bytes = self._read_bytes
callback = self._read_callback
self._read_callback = None
self._streaming_callback = None
self._read_bytes = None
self._run_callback(callback, self._consume(num_bytes))
return True
elif self._read_delimiter is not None:
# Multi-byte delimiters (e.g. '\r\n') may straddle two
# chunks in the read buffer, so we can't easily find them
@@ -420,56 +464,41 @@ class IOStream(object):
# to be in the first few chunks. Merge the buffer gradually
# since large merges are relatively expensive and get undone in
# consume().
loc = -1
if self._read_buffer:
loc = self._read_buffer[0].find(self._read_delimiter)
while loc == -1 and len(self._read_buffer) > 1:
# Grow by doubling, but don't split the second chunk just
# because the first one is small.
new_len = max(len(self._read_buffer[0]) * 2,
(len(self._read_buffer[0]) +
len(self._read_buffer[1])))
_merge_prefix(self._read_buffer, new_len)
loc = self._read_buffer[0].find(self._read_delimiter)
if loc != -1:
callback = self._read_callback
delimiter_len = len(self._read_delimiter)
self._read_callback = None
self._streaming_callback = None
self._read_delimiter = None
self._run_callback(callback,
self._consume(loc + delimiter_len))
return True
while True:
loc = self._read_buffer[0].find(self._read_delimiter)
if loc != -1:
callback = self._read_callback
delimiter_len = len(self._read_delimiter)
self._read_callback = None
self._streaming_callback = None
self._read_delimiter = None
self._run_callback(callback,
self._consume(loc + delimiter_len))
return True
if len(self._read_buffer) == 1:
break
_double_prefix(self._read_buffer)
elif self._read_regex is not None:
m = None
if self._read_buffer:
m = self._read_regex.search(self._read_buffer[0])
while m is None and len(self._read_buffer) > 1:
# Grow by doubling, but don't split the second chunk just
# because the first one is small.
new_len = max(len(self._read_buffer[0]) * 2,
(len(self._read_buffer[0]) +
len(self._read_buffer[1])))
_merge_prefix(self._read_buffer, new_len)
m = self._read_regex.search(self._read_buffer[0])
_merge_prefix(self._read_buffer, sys.maxint)
m = self._read_regex.search(self._read_buffer[0])
if m:
callback = self._read_callback
self._read_callback = None
self._streaming_callback = None
self._read_regex = None
self._run_callback(callback, self._consume(m.end()))
return True
elif self._read_until_close:
if self._streaming_callback is not None and self._read_buffer_size:
self._run_callback(self._streaming_callback,
self._consume(self._read_buffer_size))
while True:
m = self._read_regex.search(self._read_buffer[0])
if m is not None:
callback = self._read_callback
self._read_callback = None
self._streaming_callback = None
self._read_regex = None
self._run_callback(callback, self._consume(m.end()))
return True
if len(self._read_buffer) == 1:
break
_double_prefix(self._read_buffer)
return False
def _handle_connect(self):
err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if err != 0:
self.error = socket.error(err, os.strerror(err))
# IOLoop implementations may vary: some of them return
# an error state before the socket becomes writable, so
# in that case a connection failure would be handled by the
@@ -537,10 +566,7 @@ class IOStream(object):
def _maybe_add_error_listener(self):
if self._state is None and self._pending_callbacks == 0:
if self.socket is None:
cb = self._close_callback
if cb is not None:
self._close_callback = None
self._run_callback(cb)
self._maybe_run_close_callback()
else:
self._add_io_state(ioloop.IOLoop.READ)
@@ -628,7 +654,7 @@ class SSLIOStream(IOStream):
return self.close()
raise
except socket.error, err:
if err.args[0] == errno.ECONNABORTED:
if err.args[0] in (errno.ECONNABORTED, errno.ECONNRESET):
return self.close()
else:
self._ssl_accepting = False
@@ -655,7 +681,6 @@ class SSLIOStream(IOStream):
# until we've completed the SSL handshake (so certificates are
# available, etc).
def _read_from_socket(self):
if self._ssl_accepting:
# If the handshake hasn't finished yet, there can't be anything
@@ -686,6 +711,16 @@ class SSLIOStream(IOStream):
return None
return chunk
def _double_prefix(deque):
"""Grow by doubling, but don't split the second chunk just because the
first one is small.
"""
new_len = max(len(deque[0]) * 2,
(len(deque[0]) + len(deque[1])))
_merge_prefix(deque, new_len)
def _merge_prefix(deque, size):
"""Replace the first entries in a deque of strings with a single
string of up to size bytes.
@@ -723,6 +758,7 @@ def _merge_prefix(deque, size):
if not deque:
deque.appendleft(b(""))
def doctests():
import doctest
return doctest.DocTestSuite()
Regular → Executable
+55 -18
View File
@@ -39,17 +39,22 @@ supported by gettext and related tools). If neither method is called,
the locale.translate method will simply return the original string.
"""
from __future__ import absolute_import, division, with_statement
import csv
import datetime
import logging
import os
import re
from tornado import escape
_default_locale = "en_US"
_translations = {}
_supported_locales = frozenset([_default_locale])
_use_gettext = False
def get(*locale_codes):
"""Returns the closest match for the given locale codes.
@@ -109,17 +114,28 @@ def load_translations(directory):
global _supported_locales
_translations = {}
for path in os.listdir(directory):
if not path.endswith(".csv"): continue
if not path.endswith(".csv"):
continue
locale, extension = path.split(".")
if not re.match("[a-z]+(_[A-Z]+)?$", locale):
logging.error("Unrecognized locale %r (path: %s)", locale,
os.path.join(directory, path))
continue
f = open(os.path.join(directory, path), "r")
full_path = os.path.join(directory, path)
try:
# python 3: csv.reader requires a file open in text mode.
# Force utf8 to avoid dependence on $LANG environment variable.
f = open(full_path, "r", encoding="utf-8")
except TypeError:
# python 2: files return byte strings, which are decoded below.
# Once we drop python 2.5, this could use io.open instead
# on both 2 and 3.
f = open(full_path, "r")
_translations[locale] = {}
for i, row in enumerate(csv.reader(f)):
if not row or len(row) < 2: continue
row = [c.decode("utf-8").strip() for c in row]
if not row or len(row) < 2:
continue
row = [escape.to_unicode(c).strip() for c in row]
english, translation = row[:2]
if len(row) > 2:
plural = row[2] or "unknown"
@@ -132,7 +148,8 @@ def load_translations(directory):
_translations[locale].setdefault(plural, {})[english] = translation
f.close()
_supported_locales = frozenset(_translations.keys() + [_default_locale])
logging.info("Supported locales: %s", sorted(_supported_locales))
logging.debug("Supported locales: %s", sorted(_supported_locales))
def load_gettext_translations(directory, domain):
"""Loads translations from gettext's locale tree
@@ -158,10 +175,12 @@ def load_gettext_translations(directory, domain):
global _use_gettext
_translations = {}
for lang in os.listdir(directory):
if lang.startswith('.'): continue # skip .svn, etc
if os.path.isfile(os.path.join(directory, lang)): continue
if lang.startswith('.'):
continue # skip .svn, etc
if os.path.isfile(os.path.join(directory, lang)):
continue
try:
os.stat(os.path.join(directory, lang, "LC_MESSAGES", domain+".mo"))
os.stat(os.path.join(directory, lang, "LC_MESSAGES", domain + ".mo"))
_translations[lang] = gettext.translation(domain, directory,
languages=[lang])
except Exception, e:
@@ -169,10 +188,10 @@ def load_gettext_translations(directory, domain):
continue
_supported_locales = frozenset(_translations.keys() + [_default_locale])
_use_gettext = True
logging.info("Supported locales: %s", sorted(_supported_locales))
logging.debug("Supported locales: %s", sorted(_supported_locales))
def get_supported_locales(cls):
def get_supported_locales():
"""Returns a list of all the supported locale codes."""
return _supported_locales
@@ -187,7 +206,8 @@ class Locale(object):
def get_closest(cls, *locale_codes):
"""Returns the closest match for the given locale code."""
for code in locale_codes:
if not code: continue
if not code:
continue
code = code.replace("-", "_")
parts = code.split("_")
if len(parts) > 2:
@@ -289,16 +309,16 @@ class Locale(object):
if relative and days == 0:
if seconds < 50:
return _("1 second ago", "%(seconds)d seconds ago",
seconds) % { "seconds": seconds }
seconds) % {"seconds": seconds}
if seconds < 50 * 60:
minutes = round(seconds / 60.0)
return _("1 minute ago", "%(minutes)d minutes ago",
minutes) % { "minutes": minutes }
minutes) % {"minutes": minutes}
hours = round(seconds / (60.0 * 60))
return _("1 hour ago", "%(hours)d hours ago",
hours) % { "hours": hours }
hours) % {"hours": hours}
if days == 0:
format = _("%(time)s")
@@ -364,8 +384,10 @@ class Locale(object):
of size 1.
"""
_ = self.translate
if len(parts) == 0: return ""
if len(parts) == 1: return parts[0]
if len(parts) == 0:
return ""
if len(parts) == 1:
return parts[0]
comma = u' \u0648 ' if self.code.startswith("fa") else u", "
return _("%(commas)s and %(last)s") % {
"commas": comma.join(parts[:-1]),
@@ -383,6 +405,7 @@ class Locale(object):
value = value[:-3]
return ",".join(reversed(parts))
class CSVLocale(Locale):
"""Locale implementation using tornado's CSV translation format."""
def translate(self, message, plural_message=None, count=None):
@@ -397,14 +420,28 @@ class CSVLocale(Locale):
message_dict = self.translations.get("unknown", {})
return message_dict.get(message, message)
class GettextLocale(Locale):
"""Locale implementation using the gettext module."""
def __init__(self, code, translations):
try:
# python 2
self.ngettext = translations.ungettext
self.gettext = translations.ugettext
except AttributeError:
# python 3
self.ngettext = translations.ngettext
self.gettext = translations.gettext
# self.gettext must exist before __init__ is called, since it
# calls into self.translate
super(GettextLocale, self).__init__(code, translations)
def translate(self, message, plural_message=None, count=None):
if plural_message is not None:
assert count is not None
return self.translations.ungettext(message, plural_message, count)
return self.ngettext(message, plural_message, count)
else:
return self.translations.ugettext(message)
return self.gettext(message)
LOCALE_NAMES = {
"af_ZA": {"name_en": u"Afrikaans", "name": u"Afrikaans"},
Regular → Executable
+33 -4
View File
@@ -16,6 +16,8 @@
"""Miscellaneous network utility code."""
from __future__ import absolute_import, division, with_statement
import errno
import logging
import os
@@ -28,10 +30,11 @@ from tornado.iostream import IOStream, SSLIOStream
from tornado.platform.auto import set_close_exec
try:
import ssl # Python 2.6+
import ssl # Python 2.6+
except ImportError:
ssl = None
class TCPServer(object):
r"""A non-blocking, single-threaded TCP server.
@@ -89,6 +92,23 @@ class TCPServer(object):
self._pending_sockets = []
self._started = False
# Verify the SSL options. Otherwise we don't get errors until clients
# connect. This doesn't verify that the keys are legitimate, but
# the SSL module doesn't do that until there is a connected socket
# which seems like too much work
if self.ssl_options is not None:
# Only certfile is required: it can contain both keys
if 'certfile' not in self.ssl_options:
raise KeyError('missing key "certfile" in ssl_options')
if not os.path.exists(self.ssl_options['certfile']):
raise ValueError('certfile "%s" does not exist' %
self.ssl_options['certfile'])
if ('keyfile' in self.ssl_options and
not os.path.exists(self.ssl_options['keyfile'])):
raise ValueError('keyfile "%s" does not exist' %
self.ssl_options['keyfile'])
def listen(self, port, address=""):
"""Starts accepting connections on the given port.
@@ -231,19 +251,26 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128):
or socket.AF_INET6 to restrict to ipv4 or ipv6 addresses, otherwise
both will be used if available.
The ``backlog`` argument has the same meaning as for
The ``backlog`` argument has the same meaning as for
``socket.listen()``.
"""
sockets = []
if address == "":
address = None
flags = socket.AI_PASSIVE
if hasattr(socket, "AI_ADDRCONFIG"):
# AI_ADDRCONFIG ensures that we only try to bind on ipv6
# if the system is configured for it, but the flag doesn't
# exist on some platforms (specifically WinXP, although
# newer versions of windows have it)
flags |= socket.AI_ADDRCONFIG
for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM,
0, flags)):
af, socktype, proto, canonname, sockaddr = res
sock = socket.socket(af, socktype, proto)
set_close_exec(sock.fileno())
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if os.name != 'nt':
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if af == socket.AF_INET6:
# On linux, ipv6 sockets accept ipv4 too by default,
# but this makes it impossible to bind to both
@@ -269,7 +296,7 @@ if hasattr(socket, 'AF_UNIX'):
If any other file with that name exists, an exception will be
raised.
Returns a socket object (not a list of socket objects like
Returns a socket object (not a list of socket objects like
`bind_sockets`)
"""
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
@@ -291,6 +318,7 @@ if hasattr(socket, 'AF_UNIX'):
sock.listen(backlog)
return sock
def add_accept_handler(sock, callback, io_loop=None):
"""Adds an ``IOLoop`` event handler to accept new connections on ``sock``.
@@ -302,6 +330,7 @@ def add_accept_handler(sock, callback, io_loop=None):
"""
if io_loop is None:
io_loop = IOLoop.instance()
def accept_handler(fd, events):
while True:
try:
Regular → Executable
+198 -139
View File
@@ -48,12 +48,16 @@ kwarg to define). We also accept multi-value options. See the documentation
for define() below.
"""
from __future__ import absolute_import, division, with_statement
import datetime
import logging
import logging.handlers
import re
import sys
import os
import time
import textwrap
from tornado.escape import _unicode
@@ -64,136 +68,123 @@ except ImportError:
curses = None
def define(name, default=None, type=None, help=None, metavar=None,
multiple=False, group=None):
"""Defines a new command line option.
If type is given (one of str, float, int, datetime, or timedelta)
or can be inferred from the default, we parse the command line
arguments based on the given type. If multiple is True, we accept
comma-separated values, and the option value is always a list.
For multi-value integers, we also accept the syntax x:y, which
turns into range(x, y) - very useful for long integer ranges.
help and metavar are used to construct the automatically generated
command line help string. The help message is formatted like::
--name=METAVAR help string
group is used to group the defined options in logical groups. By default,
command line options are grouped by the defined file.
Command line option names must be unique globally. They can be parsed
from the command line with parse_command_line() or parsed from a
config file with parse_config_file.
"""
if name in options:
raise Error("Option %r already defined in %s", name,
options[name].file_name)
frame = sys._getframe(0)
options_file = frame.f_code.co_filename
file_name = frame.f_back.f_code.co_filename
if file_name == options_file: file_name = ""
if type is None:
if not multiple and default is not None:
type = default.__class__
else:
type = str
if group:
group_name = group
else:
group_name = file_name
options[name] = _Option(name, file_name=file_name, default=default,
type=type, help=help, metavar=metavar,
multiple=multiple, group_name=group_name)
def parse_command_line(args=None):
"""Parses all options given on the command line.
We return all command line arguments that are not options as a list.
"""
if args is None: args = sys.argv
remaining = []
for i in xrange(1, len(args)):
# All things after the last option are command line arguments
if not args[i].startswith("-"):
remaining = args[i:]
break
if args[i] == "--":
remaining = args[i+1:]
break
arg = args[i].lstrip("-")
name, equals, value = arg.partition("=")
name = name.replace('-', '_')
if not name in options:
print_help()
raise Error('Unrecognized command line option: %r' % name)
option = options[name]
if not equals:
if option.type == bool:
value = "true"
else:
raise Error('Option %r requires a value' % name)
option.parse(value)
if options.help:
print_help()
sys.exit(0)
# Set up log level and pretty console logging by default
if options.logging != 'none':
logging.getLogger().setLevel(getattr(logging, options.logging.upper()))
enable_pretty_logging()
return remaining
def parse_config_file(path):
"""Parses and loads the Python config file at the given path."""
config = {}
execfile(path, config, config)
for name in config:
if name in options:
options[name].set(config[name])
def print_help(file=sys.stdout):
"""Prints all the command line options to stdout."""
print >> file, "Usage: %s [OPTIONS]" % sys.argv[0]
print >> file, ""
print >> file, "Options:"
by_group = {}
for option in options.itervalues():
by_group.setdefault(option.group_name, []).append(option)
for filename, o in sorted(by_group.items()):
if filename: print >> file, filename
o.sort(key=lambda option: option.name)
for option in o:
prefix = option.name
if option.metavar:
prefix += "=" + option.metavar
print >> file, " --%-30s %s" % (prefix, option.help or "")
print >> file
class Error(Exception):
"""Exception raised by errors in the options module."""
pass
class _Options(dict):
"""Our global program options, an dictionary with object-like access."""
@classmethod
def instance(cls):
if not hasattr(cls, "_instance"):
cls._instance = cls()
return cls._instance
"""A collection of options, a dictionary with object-like access.
Normally accessed via static functions in the `tornado.options` module,
which reference a global instance.
"""
def __getattr__(self, name):
if isinstance(self.get(name), _Option):
return self[name].value()
raise AttributeError("Unrecognized option %r" % name)
def __setattr__(self, name, value):
if isinstance(self.get(name), _Option):
return self[name].set(value)
raise AttributeError("Unrecognized option %r" % name)
def define(self, name, default=None, type=None, help=None, metavar=None,
multiple=False, group=None):
if name in self:
raise Error("Option %r already defined in %s", name,
self[name].file_name)
frame = sys._getframe(0)
options_file = frame.f_code.co_filename
file_name = frame.f_back.f_code.co_filename
if file_name == options_file:
file_name = ""
if type is None:
if not multiple and default is not None:
type = default.__class__
else:
type = str
if group:
group_name = group
else:
group_name = file_name
self[name] = _Option(name, file_name=file_name, default=default,
type=type, help=help, metavar=metavar,
multiple=multiple, group_name=group_name)
def parse_command_line(self, args=None):
if args is None:
args = sys.argv
remaining = []
for i in xrange(1, len(args)):
# All things after the last option are command line arguments
if not args[i].startswith("-"):
remaining = args[i:]
break
if args[i] == "--":
remaining = args[i + 1:]
break
arg = args[i].lstrip("-")
name, equals, value = arg.partition("=")
name = name.replace('-', '_')
if not name in self:
print_help()
raise Error('Unrecognized command line option: %r' % name)
option = self[name]
if not equals:
if option.type == bool:
value = "true"
else:
raise Error('Option %r requires a value' % name)
option.parse(value)
if self.help:
print_help()
sys.exit(0)
# Set up log level and pretty console logging by default
if self.logging != 'none':
logging.getLogger().setLevel(getattr(logging, self.logging.upper()))
enable_pretty_logging()
return remaining
def parse_config_file(self, path):
config = {}
execfile(path, config, config)
for name in config:
if name in self:
self[name].set(config[name])
def print_help(self, file=sys.stdout):
"""Prints all the command line options to stdout."""
print >> file, "Usage: %s [OPTIONS]" % sys.argv[0]
print >> file, "\nOptions:\n"
by_group = {}
for option in self.itervalues():
by_group.setdefault(option.group_name, []).append(option)
for filename, o in sorted(by_group.items()):
if filename:
print >> file, "\n%s options:\n" % os.path.normpath(filename)
o.sort(key=lambda option: option.name)
for option in o:
prefix = option.name
if option.metavar:
prefix += "=" + option.metavar
description = option.help or ""
if option.default is not None and option.default != '':
description += " (default %s)" % option.default
lines = textwrap.wrap(description, 79 - 35)
if len(prefix) > 30 or len(lines) == 0:
lines.insert(0, '')
print >> file, " --%-30s %s" % (prefix, lines[0])
for line in lines[1:]:
print >> file, "%-34s %s" % (' ', line)
print >> file
class _Option(object):
def __init__(self, name, default=None, type=str, help=None, metavar=None,
def __init__(self, name, default=None, type=basestring, help=None, metavar=None,
multiple=False, file_name=None, group_name=None):
if default is None and multiple:
default = []
@@ -215,18 +206,17 @@ class _Option(object):
datetime.datetime: self._parse_datetime,
datetime.timedelta: self._parse_timedelta,
bool: self._parse_bool,
str: self._parse_string,
basestring: self._parse_string,
}.get(self.type, self.type)
if self.multiple:
if self._value is None:
self._value = []
self._value = []
for part in value.split(","):
if self.type in (int, long):
# allow ranges of the form X:Y (inclusive at both ends)
lo, _, hi = part.partition(":")
lo = _parse(lo)
hi = _parse(hi) if hi else lo
self._value.extend(range(lo, hi+1))
self._value.extend(range(lo, hi + 1))
else:
self._value.append(_parse(part))
else:
@@ -244,8 +234,8 @@ class _Option(object):
(self.name, self.type.__name__))
else:
if value != None and not isinstance(value, self.type):
raise Error("Option %r is required to be a %s" %
(self.name, self.type.__name__))
raise Error("Option %r is required to be a %s (%s given)" %
(self.name, self.type.__name__, type(value)))
self._value = value
# Supported date/time formats in our options
@@ -313,14 +303,64 @@ class _Option(object):
return _unicode(value)
class Error(Exception):
"""Exception raised by errors in the options module."""
pass
options = _Options()
"""Global options dictionary.
Supports both attribute-style and dict-style access.
"""
def enable_pretty_logging():
def define(name, default=None, type=None, help=None, metavar=None,
multiple=False, group=None):
"""Defines a new command line option.
If type is given (one of str, float, int, datetime, or timedelta)
or can be inferred from the default, we parse the command line
arguments based on the given type. If multiple is True, we accept
comma-separated values, and the option value is always a list.
For multi-value integers, we also accept the syntax x:y, which
turns into range(x, y) - very useful for long integer ranges.
help and metavar are used to construct the automatically generated
command line help string. The help message is formatted like::
--name=METAVAR help string
group is used to group the defined options in logical groups. By default,
command line options are grouped by the defined file.
Command line option names must be unique globally. They can be parsed
from the command line with parse_command_line() or parsed from a
config file with parse_config_file.
"""
return options.define(name, default=default, type=type, help=help,
metavar=metavar, multiple=multiple, group=group)
def parse_command_line(args=None):
"""Parses all options given on the command line (defaults to sys.argv).
Note that args[0] is ignored since it is the program name in sys.argv.
We return a list of all arguments that are not parsed as options.
"""
return options.parse_command_line(args)
def parse_config_file(path):
"""Parses and loads the Python config file at the given path."""
return options.parse_config_file(path)
def print_help(file=sys.stdout):
"""Prints all the command line options to stdout."""
return options.print_help(file)
def enable_pretty_logging(options=options):
"""Turns on formatted logging output as configured.
This is called automatically by `parse_command_line`.
"""
root_logger = logging.getLogger()
@@ -348,7 +388,6 @@ def enable_pretty_logging():
root_logger.addHandler(channel)
class _LogFormatter(logging.Formatter):
def __init__(self, color, *args, **kwargs):
logging.Formatter.__init__(self, *args, **kwargs)
@@ -366,13 +405,13 @@ class _LogFormatter(logging.Formatter):
if (3, 0) < sys.version_info < (3, 2, 3):
fg_color = unicode(fg_color, "ascii")
self._colors = {
logging.DEBUG: unicode(curses.tparm(fg_color, 4), # Blue
logging.DEBUG: unicode(curses.tparm(fg_color, 4), # Blue
"ascii"),
logging.INFO: unicode(curses.tparm(fg_color, 2), # Green
logging.INFO: unicode(curses.tparm(fg_color, 2), # Green
"ascii"),
logging.WARNING: unicode(curses.tparm(fg_color, 3), # Yellow
logging.WARNING: unicode(curses.tparm(fg_color, 3), # Yellow
"ascii"),
logging.ERROR: unicode(curses.tparm(fg_color, 1), # Red
logging.ERROR: unicode(curses.tparm(fg_color, 1), # Red
"ascii"),
}
self._normal = unicode(curses.tigetstr("sgr0"), "ascii")
@@ -382,6 +421,7 @@ class _LogFormatter(logging.Formatter):
record.message = record.getMessage()
except Exception, e:
record.message = "Bad message (%r): %r" % (e, record.__dict__)
assert isinstance(record.message, basestring) # guaranteed by logging
record.asctime = time.strftime(
"%y%m%d %H:%M:%S", self.converter(record.created))
prefix = '[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]' % \
@@ -389,7 +429,29 @@ class _LogFormatter(logging.Formatter):
if self._color:
prefix = (self._colors.get(record.levelno, self._normal) +
prefix + self._normal)
formatted = prefix + " " + record.message
# Encoding notes: The logging module prefers to work with character
# strings, but only enforces that log messages are instances of
# basestring. In python 2, non-ascii bytestrings will make
# their way through the logging framework until they blow up with
# an unhelpful decoding error (with this formatter it happens
# when we attach the prefix, but there are other opportunities for
# exceptions further along in the framework).
#
# If a byte string makes it this far, convert it to unicode to
# ensure it will make it out to the logs. Use repr() as a fallback
# to ensure that all byte strings can be converted successfully,
# but don't do it by default so we don't add extra quotes to ascii
# bytestrings. This is a bit of a hacky place to do this, but
# it's worth it since the encoding errors that would otherwise
# result are so useless (and tornado is fond of using utf8-encoded
# byte strings whereever possible).
try:
message = _unicode(record.message)
except UnicodeDecodeError:
message = repr(record.message)
formatted = prefix + " " + message
if record.exc_info:
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
@@ -398,9 +460,6 @@ class _LogFormatter(logging.Formatter):
return formatted.replace("\n", "\n ")
options = _Options.instance()
# Default options
define("help", type=bool, help="show this help information")
define("logging", default="info",
View File
Regular → Executable
+4 -1
View File
@@ -23,9 +23,12 @@ Most code that needs access to this functionality should do e.g.::
from tornado.platform.auto import set_close_exec
"""
from __future__ import absolute_import, division, with_statement
import os
if os.name == 'nt':
from tornado.platform.windows import set_close_exec, Waker
from tornado.platform.common import Waker
from tornado.platform.windows import set_close_exec
else:
from tornado.platform.posix import set_close_exec, Waker
+89
View File
@@ -0,0 +1,89 @@
"""Lowest-common-denominator implementations of platform functionality."""
from __future__ import absolute_import, division, with_statement
import errno
import socket
from tornado.platform import interface
from tornado.util import b
class Waker(interface.Waker):
"""Create an OS independent asynchronous pipe.
For use on platforms that don't have os.pipe() (or where pipes cannot
be passed to select()), but do have sockets. This includes Windows
and Jython.
"""
def __init__(self):
# Based on Zope async.py: http://svn.zope.org/zc.ngi/trunk/src/zc/ngi/async.py
self.writer = socket.socket()
# Disable buffering -- pulling the trigger sends 1 byte,
# and we want that sent immediately, to wake up ASAP.
self.writer.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
count = 0
while 1:
count += 1
# Bind to a local port; for efficiency, let the OS pick
# a free port for us.
# Unfortunately, stress tests showed that we may not
# be able to connect to that port ("Address already in
# use") despite that the OS picked it. This appears
# to be a race bug in the Windows socket implementation.
# So we loop until a connect() succeeds (almost always
# on the first try). See the long thread at
# http://mail.zope.org/pipermail/zope/2005-July/160433.html
# for hideous details.
a = socket.socket()
a.bind(("127.0.0.1", 0))
a.listen(1)
connect_address = a.getsockname() # assigned (host, port) pair
try:
self.writer.connect(connect_address)
break # success
except socket.error, detail:
if (not hasattr(errno, 'WSAEADDRINUSE') or
detail[0] != errno.WSAEADDRINUSE):
# "Address already in use" is the only error
# I've seen on two WinXP Pro SP2 boxes, under
# Pythons 2.3.5 and 2.4.1.
raise
# (10048, 'Address already in use')
# assert count <= 2 # never triggered in Tim's tests
if count >= 10: # I've never seen it go above 2
a.close()
self.writer.close()
raise socket.error("Cannot bind trigger!")
# Close `a` and try again. Note: I originally put a short
# sleep() here, but it didn't appear to help or hurt.
a.close()
self.reader, addr = a.accept()
self.reader.setblocking(0)
self.writer.setblocking(0)
a.close()
self.reader_fd = self.reader.fileno()
def fileno(self):
return self.reader.fileno()
def wake(self):
try:
self.writer.send(b("x"))
except (IOError, socket.error):
pass
def consume(self):
try:
while True:
result = self.reader.recv(1024)
if not result:
break
except (IOError, socket.error):
pass
def close(self):
self.reader.close()
self.writer.close()
+5 -3
View File
@@ -21,10 +21,14 @@ for other tornado.platform modules. Most code should import the appropriate
implementation from `tornado.platform.auto`.
"""
from __future__ import absolute_import, division, with_statement
def set_close_exec(fd):
"""Sets the close-on-exec bit (``FD_CLOEXEC``)for a file descriptor."""
raise NotImplementedError()
class Waker(object):
"""A socket-like object that can wake another thread from ``select()``.
@@ -36,7 +40,7 @@ class Waker(object):
"""
def fileno(self):
"""Returns a file descriptor for this waker.
Must be suitable for use with ``select()`` or equivalent on the
local platform.
"""
@@ -53,5 +57,3 @@ class Waker(object):
def close(self):
"""Closes the waker's file descriptor(s)."""
raise NotImplementedError()
Regular → Executable
+8 -2
View File
@@ -16,20 +16,25 @@
"""Posix implementations of platform-specific functionality."""
from __future__ import absolute_import, division, with_statement
import fcntl
import os
from tornado.platform import interface
from tornado.util import b
def set_close_exec(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFD)
fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
def _set_nonblocking(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
class Waker(interface.Waker):
def __init__(self):
r, w = os.pipe()
@@ -53,7 +58,8 @@ class Waker(interface.Waker):
try:
while True:
result = self.reader.read()
if not result: break;
if not result:
break
except IOError:
pass
Regular → Executable
+9 -9
View File
@@ -41,10 +41,10 @@ recommended to call::
before closing the `IOLoop`.
This module has been tested with Twisted versions 11.0.0 and 11.1.0.
This module has been tested with Twisted versions 11.0.0, 11.1.0, and 12.0.0
"""
from __future__ import with_statement, absolute_import
from __future__ import absolute_import, division, with_statement
import functools
import logging
@@ -56,7 +56,7 @@ from twisted.internet.interfaces import \
from twisted.python import failure, log
from twisted.internet import error
from zope.interface import implements
from zope.interface import implementer
import tornado
import tornado.ioloop
@@ -66,8 +66,6 @@ from tornado.ioloop import IOLoop
class TornadoDelayedCall(object):
"""DelayedCall object for Tornado."""
implements(IDelayedCall)
def __init__(self, reactor, seconds, f, *args, **kw):
self._reactor = reactor
self._func = functools.partial(f, *args, **kw)
@@ -106,6 +104,9 @@ class TornadoDelayedCall(object):
def active(self):
return self._active
# Fake class decorator for python 2.5 compatibility
TornadoDelayedCall = implementer(IDelayedCall)(TornadoDelayedCall)
class TornadoReactor(PosixReactorBase):
"""Twisted reactor built on the Tornado IOLoop.
@@ -117,15 +118,13 @@ class TornadoReactor(PosixReactorBase):
timed call functionality on top of `IOLoop.add_timeout` rather than
using the implementation in `PosixReactorBase`.
"""
implements(IReactorTime, IReactorFDSet)
def __init__(self, io_loop=None):
if not io_loop:
io_loop = tornado.ioloop.IOLoop.instance()
self._io_loop = io_loop
self._readers = {} # map of reader objects to fd
self._writers = {} # map of writer objects to fd
self._fds = {} # a map of fd to a (reader, writer) tuple
self._fds = {} # a map of fd to a (reader, writer) tuple
self._delayedCalls = {}
PosixReactorBase.__init__(self)
@@ -294,6 +293,8 @@ class TornadoReactor(PosixReactorBase):
self._io_loop.start()
if self._stopped:
self.fireSystemEvent("shutdown")
TornadoReactor = implementer(IReactorTime, IReactorFDSet)(TornadoReactor)
class _TestReactor(TornadoReactor):
"""Subclass of TornadoReactor for use in unittests.
@@ -319,7 +320,6 @@ class _TestReactor(TornadoReactor):
port, protocol, interface=interface, maxPacketSize=maxPacketSize)
def install(io_loop=None):
"""Install this package as the default Twisted reactor."""
if not io_loop:
Regular → Executable
+2 -79
View File
@@ -1,13 +1,10 @@
# NOTE: win32 support is currently experimental, and not recommended
# for production use.
from __future__ import absolute_import, division, with_statement
import ctypes
import ctypes.wintypes
import socket
import errno
from tornado.platform import interface
from tornado.util import b
# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx
SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation
@@ -21,77 +18,3 @@ def set_close_exec(fd):
success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0)
if not success:
raise ctypes.GetLastError()
class Waker(interface.Waker):
"""Create an OS independent asynchronous pipe"""
def __init__(self):
# Based on Zope async.py: http://svn.zope.org/zc.ngi/trunk/src/zc/ngi/async.py
self.writer = socket.socket()
# Disable buffering -- pulling the trigger sends 1 byte,
# and we want that sent immediately, to wake up ASAP.
self.writer.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
count = 0
while 1:
count += 1
# Bind to a local port; for efficiency, let the OS pick
# a free port for us.
# Unfortunately, stress tests showed that we may not
# be able to connect to that port ("Address already in
# use") despite that the OS picked it. This appears
# to be a race bug in the Windows socket implementation.
# So we loop until a connect() succeeds (almost always
# on the first try). See the long thread at
# http://mail.zope.org/pipermail/zope/2005-July/160433.html
# for hideous details.
a = socket.socket()
a.bind(("127.0.0.1", 0))
connect_address = a.getsockname() # assigned (host, port) pair
a.listen(1)
try:
self.writer.connect(connect_address)
break # success
except socket.error, detail:
if detail[0] != errno.WSAEADDRINUSE:
# "Address already in use" is the only error
# I've seen on two WinXP Pro SP2 boxes, under
# Pythons 2.3.5 and 2.4.1.
raise
# (10048, 'Address already in use')
# assert count <= 2 # never triggered in Tim's tests
if count >= 10: # I've never seen it go above 2
a.close()
self.writer.close()
raise socket.error("Cannot bind trigger!")
# Close `a` and try again. Note: I originally put a short
# sleep() here, but it didn't appear to help or hurt.
a.close()
self.reader, addr = a.accept()
self.reader.setblocking(0)
self.writer.setblocking(0)
a.close()
self.reader_fd = self.reader.fileno()
def fileno(self):
return self.reader.fileno()
def wake(self):
try:
self.writer.send(b("x"))
except IOError:
pass
def consume(self):
try:
while True:
result = self.reader.recv(1024)
if not result: break
except IOError:
pass
def close(self):
self.reader.close()
self.writer.close()
Regular → Executable
+12 -3
View File
@@ -16,6 +16,8 @@
"""Utilities for working with multiple processes."""
from __future__ import absolute_import, division, with_statement
import errno
import logging
import os
@@ -27,10 +29,11 @@ from binascii import hexlify
from tornado import ioloop
try:
import multiprocessing # Python 2.6+
import multiprocessing # Python 2.6+
except ImportError:
multiprocessing = None
def cpu_count():
"""Returns the number of processors on this machine."""
if multiprocessing is not None:
@@ -45,6 +48,7 @@ def cpu_count():
logging.error("Could not detect number of processors; assuming 1")
return 1
def _reseed_random():
if 'random' not in sys.modules:
return
@@ -61,6 +65,7 @@ def _reseed_random():
_task_id = None
def fork_processes(num_processes, max_restarts=100):
"""Starts multiple worker processes.
@@ -95,6 +100,7 @@ def fork_processes(num_processes, max_restarts=100):
"IOLoop.instance() before calling start_processes()")
logging.info("Starting %d processes", num_processes)
children = {}
def start_child(i):
pid = os.fork()
if pid == 0:
@@ -108,7 +114,8 @@ def fork_processes(num_processes, max_restarts=100):
return None
for i in range(num_processes):
id = start_child(i)
if id is not None: return id
if id is not None:
return id
num_restarts = 0
while children:
try:
@@ -133,13 +140,15 @@ def fork_processes(num_processes, max_restarts=100):
if num_restarts > max_restarts:
raise RuntimeError("Too many child restarts, giving up")
new_id = start_child(id)
if new_id is not None: return new_id
if new_id is not None:
return new_id
# All child processes exited cleanly, so exit the master process
# instead of just returning to right after the call to
# fork_processes (which will probably just start up another IOLoop
# unless the caller checks the return value).
sys.exit(0)
def task_id():
"""Returns the current task id, if any.
+59 -34
View File
@@ -1,12 +1,12 @@
#!/usr/bin/env python
from __future__ import with_statement
from __future__ import absolute_import, division, with_statement
from tornado.escape import utf8, _unicode, native_str
from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError, AsyncHTTPClient, main
from tornado.httputil import HTTPHeaders
from tornado.iostream import IOStream, SSLIOStream
from tornado import stack_context
from tornado.util import b
from tornado.util import b, GzipDecompressor
import base64
import collections
@@ -20,7 +20,6 @@ import socket
import sys
import time
import urlparse
import zlib
try:
from io import BytesIO # python 3
@@ -28,12 +27,13 @@ except ImportError:
from cStringIO import StringIO as BytesIO # python 2
try:
import ssl # python 2.6+
import ssl # python 2.6+
except ImportError:
ssl = None
_DEFAULT_CA_CERTS = os.path.dirname(__file__) + '/ca-certificates.crt'
class SimpleAsyncHTTPClient(AsyncHTTPClient):
"""Non-blocking HTTP client with no external dependencies.
@@ -93,8 +93,10 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
def fetch(self, request, callback, **kwargs):
if not isinstance(request, HTTPRequest):
request = HTTPRequest(url=request, **kwargs)
if not isinstance(request.headers, HTTPHeaders):
request.headers = HTTPHeaders(request.headers)
# We're going to modify this (to add Host, Accept-Encoding, etc),
# so make sure we don't modify the caller's object. This is also
# where normal dicts get converted to HTTPHeaders objects.
request.headers = HTTPHeaders(request.headers)
callback = stack_context.wrap(callback)
self.queue.append((request, callback))
self._process_queue()
@@ -119,9 +121,8 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
self._process_queue()
class _HTTPConnection(object):
_SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])
_SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
def __init__(self, io_loop, client, request, release_callback,
final_callback, max_buffer_size):
@@ -160,6 +161,7 @@ class _HTTPConnection(object):
if re.match(r'^\[.*\]$', host):
# raw ipv6 addresses in urls are enclosed in brackets
host = host[1:-1]
parsed_hostname = host # save final parsed host for _on_connect
if self.client.hostname_mapping is not None:
host = self.client.hostname_mapping.get(host, host)
@@ -198,7 +200,7 @@ class _HTTPConnection(object):
# compatibility with servers configured for TLSv1 only,
# but nearly all servers support SSLv3:
# http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
if sys.version_info >= (2,7):
if sys.version_info >= (2, 7):
ssl_options["ciphers"] = "DEFAULT:!SSLv2"
else:
# This is really only necessary for pre-1.0 versions
@@ -218,30 +220,33 @@ class _HTTPConnection(object):
if timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + timeout,
self._on_timeout)
stack_context.wrap(self._on_timeout))
self.stream.set_close_callback(self._on_close)
self.stream.connect(sockaddr,
functools.partial(self._on_connect, parsed))
functools.partial(self._on_connect, parsed,
parsed_hostname))
def _on_timeout(self):
self._timeout = None
self._run_callback(HTTPResponse(self.request, 599,
request_time=time.time() - self.start_time,
error=HTTPError(599, "Timeout")))
self.stream.close()
if self.final_callback is not None:
raise HTTPError(599, "Timeout")
def _on_connect(self, parsed):
def _on_connect(self, parsed, parsed_hostname):
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
if self.request.request_timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + self.request.request_timeout,
self._on_timeout)
stack_context.wrap(self._on_timeout))
if (self.request.validate_cert and
isinstance(self.stream, SSLIOStream)):
match_hostname(self.stream.socket.getpeercert(),
parsed.hostname)
# ipv6 addresses are broken (in
# parsed.hostname) until 2.7, here is
# correctly parsed value calculated in
# __init__
parsed_hostname)
if (self.request.method not in self._SUPPORTED_METHODS and
not self.request.allow_nonstandard_methods):
raise KeyError("unknown method %s" % self.request.method)
@@ -250,8 +255,13 @@ class _HTTPConnection(object):
'proxy_username', 'proxy_password'):
if getattr(self.request, key, None):
raise NotImplementedError('%s not supported' % key)
if "Connection" not in self.request.headers:
self.request.headers["Connection"] = "close"
if "Host" not in self.request.headers:
self.request.headers["Host"] = parsed.netloc
if '@' in parsed.netloc:
self.request.headers["Host"] = parsed.netloc.rpartition('@')[-1]
else:
self.request.headers["Host"] = parsed.netloc
username, password = None, None
if parsed.username is not None:
username, password = parsed.username, parsed.password
@@ -265,7 +275,7 @@ class _HTTPConnection(object):
if self.request.user_agent:
self.request.headers["User-Agent"] = self.request.user_agent
if not self.request.allow_nonstandard_methods:
if self.request.method in ("POST", "PUT"):
if self.request.method in ("POST", "PATCH", "PUT"):
assert self.request.body is not None
else:
assert self.request.body is None
@@ -313,12 +323,12 @@ class _HTTPConnection(object):
self._run_callback(HTTPResponse(self.request, 599, error=e,
request_time=time.time() - self.start_time,
))
if hasattr(self, "stream"):
self.stream.close()
def _on_close(self):
self._run_callback(HTTPResponse(
self.request, 599,
request_time=time.time() - self.start_time,
error=HTTPError(599, "Connection closed")))
if self.final_callback is not None:
raise HTTPError(599, "Connection closed")
def _on_headers(self, data):
data = native_str(data.decode("latin1"))
@@ -354,16 +364,16 @@ class _HTTPConnection(object):
if 100 <= self.code < 200 or self.code in (204, 304):
# These response codes never have bodies
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
assert "Transfer-Encoding" not in self.headers
assert content_length in (None, 0)
if ("Transfer-Encoding" in self.headers or
content_length not in (None, 0)):
raise ValueError("Response with code %d should not have body" %
self.code)
self._on_body(b(""))
return
if (self.request.use_gzip and
self.headers.get("Content-Encoding") == "gzip"):
# Magic parameter makes zlib module understand gzip header
# http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
self._decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
self._decompressor = GzipDecompressor()
if self.headers.get("Transfer-Encoding") == "chunked":
self.chunks = []
self.stream.read_until(b("\r\n"), self._on_chunk_length)
@@ -405,7 +415,8 @@ class _HTTPConnection(object):
self.stream.close()
return
if self._decompressor:
data = self._decompressor.decompress(data)
data = (self._decompressor.decompress(data) +
self._decompressor.flush())
if self.request.streaming_callback:
if self.chunks is None:
# if chunks is not None, we already called streaming_callback
@@ -413,7 +424,7 @@ class _HTTPConnection(object):
self.request.streaming_callback(data)
buffer = BytesIO()
else:
buffer = BytesIO(data) # TODO: don't require one big string?
buffer = BytesIO(data) # TODO: don't require one big string?
response = HTTPResponse(original_request,
self.code, headers=self.headers,
request_time=time.time() - self.start_time,
@@ -426,9 +437,21 @@ class _HTTPConnection(object):
# TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
length = int(data.strip(), 16)
if length == 0:
# all the data has been decompressed, so we don't need to
# decompress again in _on_body
self._decompressor = None
if self._decompressor is not None:
tail = self._decompressor.flush()
if tail:
# I believe the tail will always be empty (i.e.
# decompress will return all it can). The purpose
# of the flush call is to detect errors such
# as truncated input. But in case it ever returns
# anything, treat it as an extra chunk
if self.request.streaming_callback is not None:
self.request.streaming_callback(tail)
else:
self.chunks.append(tail)
# all the data has been decompressed, so we don't need to
# decompress again in _on_body
self._decompressor = None
self._on_body(b('').join(self.chunks))
else:
self.stream.read_bytes(length + 2, # chunk ends with \r\n
@@ -452,6 +475,7 @@ class _HTTPConnection(object):
class CertificateError(ValueError):
pass
def _dnsname_to_pat(dn):
pats = []
for frag in dn.split(r'.'):
@@ -465,6 +489,7 @@ def _dnsname_to_pat(dn):
pats.append(frag.replace(r'\*', '[^.]*'))
return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
def match_hostname(cert, hostname):
"""Verify that *cert* (in decoded format as returned by
SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 rules
Regular → Executable
+42 -13
View File
@@ -66,19 +66,24 @@ Here are a few rules of thumb for when it's necessary:
block that references your `StackContext`.
'''
from __future__ import with_statement
from __future__ import absolute_import, division, with_statement
import contextlib
import functools
import itertools
import operator
import sys
import threading
from tornado.util import raise_exc_info
class _State(threading.local):
def __init__(self):
self.contexts = ()
_state = _State()
class StackContext(object):
'''Establishes the given context as a StackContext that will be transferred.
@@ -91,24 +96,33 @@ class StackContext(object):
StackContext takes the function itself rather than its result::
with StackContext(my_context):
The result of ``with StackContext() as cb:`` is a deactivation
callback. Run this callback when the StackContext is no longer
needed to ensure that it is not propagated any further (note that
deactivating a context does not affect any instances of that
context that are currently pending). This is an advanced feature
and not necessary in most applications.
'''
def __init__(self, context_factory):
def __init__(self, context_factory, _active_cell=None):
self.context_factory = context_factory
self.active_cell = _active_cell or [True]
# Note that some of this code is duplicated in ExceptionStackContext
# below. ExceptionStackContext is more common and doesn't need
# the full generality of this class.
def __enter__(self):
self.old_contexts = _state.contexts
# _state.contexts is a tuple of (class, arg) pairs
# _state.contexts is a tuple of (class, arg, active_cell) tuples
_state.contexts = (self.old_contexts +
((StackContext, self.context_factory),))
((StackContext, self.context_factory, self.active_cell),))
try:
self.context = self.context_factory()
self.context.__enter__()
except Exception:
_state.contexts = self.old_contexts
raise
return lambda: operator.setitem(self.active_cell, 0, False)
def __exit__(self, type, value, traceback):
try:
@@ -116,6 +130,7 @@ class StackContext(object):
finally:
_state.contexts = self.old_contexts
class ExceptionStackContext(object):
'''Specialization of StackContext for exception handling.
@@ -129,13 +144,16 @@ class ExceptionStackContext(object):
If the exception handler returns true, the exception will be
consumed and will not be propagated to other exception handlers.
'''
def __init__(self, exception_handler):
def __init__(self, exception_handler, _active_cell=None):
self.exception_handler = exception_handler
self.active_cell = _active_cell or [True]
def __enter__(self):
self.old_contexts = _state.contexts
_state.contexts = (self.old_contexts +
((ExceptionStackContext, self.exception_handler),))
((ExceptionStackContext, self.exception_handler,
self.active_cell),))
return lambda: operator.setitem(self.active_cell, 0, False)
def __exit__(self, type, value, traceback):
try:
@@ -144,6 +162,7 @@ class ExceptionStackContext(object):
finally:
_state.contexts = self.old_contexts
class NullContext(object):
'''Resets the StackContext.
@@ -158,9 +177,11 @@ class NullContext(object):
def __exit__(self, type, value, traceback):
_state.contexts = self.old_contexts
class _StackContextWrapper(functools.partial):
pass
def wrap(fn):
'''Returns a callable object that will restore the current StackContext
when executed.
@@ -173,12 +194,17 @@ def wrap(fn):
return fn
# functools.wraps doesn't appear to work on functools.partial objects
#@functools.wraps(fn)
def wrapped(callback, contexts, *args, **kwargs):
def wrapped(*args, **kwargs):
callback, contexts, args = args[0], args[1], args[2:]
if contexts is _state.contexts or not contexts:
callback(*args, **kwargs)
return
if not _state.contexts:
new_contexts = [cls(arg) for (cls, arg) in contexts]
new_contexts = [cls(arg, active_cell)
for (cls, arg, active_cell) in contexts
if active_cell[0]]
# If we're moving down the stack, _state.contexts is a prefix
# of contexts. For each element of contexts not in that prefix,
# create a new StackContext object.
@@ -190,10 +216,13 @@ def wrap(fn):
for a, b in itertools.izip(_state.contexts, contexts))):
# contexts have been removed or changed, so start over
new_contexts = ([NullContext()] +
[cls(arg) for (cls,arg) in contexts])
[cls(arg, active_cell)
for (cls, arg, active_cell) in contexts
if active_cell[0]])
else:
new_contexts = [cls(arg)
for (cls, arg) in contexts[len(_state.contexts):]]
new_contexts = [cls(arg, active_cell)
for (cls, arg, active_cell) in contexts[len(_state.contexts):]
if active_cell[0]]
if len(new_contexts) > 1:
with _nested(*new_contexts):
callback(*args, **kwargs)
@@ -207,6 +236,7 @@ def wrap(fn):
else:
return _StackContextWrapper(fn)
@contextlib.contextmanager
def _nested(*managers):
"""Support multiple context managers in a single with-statement.
@@ -240,5 +270,4 @@ def _nested(*managers):
# Don't rely on sys.exc_info() still containing
# the right information. Another exception may
# have been raised and caught by an exit method
raise exc[0], exc[1], exc[2]
raise_exc_info(exc)
Regular → Executable
+44 -16
View File
@@ -101,6 +101,11 @@ with ``{# ... #}``.
{% apply linkify %}{{name}} said: {{message}}{% end %}
Note that as an implementation detail apply blocks are implemented
as nested functions and thus may interact strangely with variables
set via ``{% set %}``, or the use of ``{% break %}`` or ``{% continue %}``
within loops.
``{% autoescape *function* %}``
Sets the autoescape mode for the current file. This does not affect
other files, even those referenced by ``{% include %}``. Note that
@@ -134,8 +139,9 @@ with ``{# ... #}``.
tag will be ignored. For an example, see the ``{% block %}`` tag.
``{% for *var* in *expr* %}...{% end %}``
Same as the python ``for`` statement.
Same as the python ``for`` statement. ``{% break %}`` and
``{% continue %}`` may be used inside the loop.
``{% from *x* import *y* %}``
Same as the python ``import`` statement.
@@ -165,14 +171,15 @@ with ``{# ... #}``.
``{% set *x* = *y* %}``
Sets a local variable.
``{% try %}...{% except %}...{% finally %}...{% end %}``
``{% try %}...{% except %}...{% finally %}...{% else %}...{% end %}``
Same as the python ``try`` statement.
``{% while *condition* %}... {% end %}``
Same as the python ``while`` statement.
Same as the python ``while`` statement. ``{% break %}`` and
``{% continue %}`` may be used inside the loop.
"""
from __future__ import with_statement
from __future__ import absolute_import, division, with_statement
import cStringIO
import datetime
@@ -189,6 +196,7 @@ from tornado.util import bytes_type, ObjectDict
_DEFAULT_AUTOESCAPE = "xhtml_escape"
_UNSET = object()
class Template(object):
"""A compiled template.
@@ -217,7 +225,7 @@ class Template(object):
# the module name used in __name__ below.
self.compiled = compile(
escape.to_unicode(self.code),
"%s.generated.py" % self.name.replace('.','_'),
"%s.generated.py" % self.name.replace('.', '_'),
"exec")
except Exception:
formatted_code = _format_code(self.code).rstrip()
@@ -326,6 +334,7 @@ class BaseLoader(object):
def _create_template(self, name):
raise NotImplementedError()
class Loader(BaseLoader):
"""A template loader that loads from a single root directory.
@@ -350,7 +359,7 @@ class Loader(BaseLoader):
def _create_template(self, name):
path = os.path.join(self.root, name)
f = open(path, "r")
f = open(path, "rb")
template = Template(f.read(), name=name, loader=self)
f.close()
return template
@@ -404,7 +413,6 @@ class _File(_Node):
return (self.body,)
class _ChunkList(_Node):
def __init__(self, chunks):
self.chunks = chunks
@@ -531,11 +539,13 @@ class _Expression(_Node):
writer.current_template.autoescape, self.line)
writer.write_line("_append(_tmp)", self.line)
class _Module(_Expression):
def __init__(self, expression, line):
super(_Module, self).__init__("_modules." + expression, line,
raw=True)
class _Text(_Node):
def __init__(self, value, line):
self.value = value
@@ -608,7 +618,7 @@ class _CodeWriter(object):
ancestors = ["%s:%d" % (tmpl.name, lineno)
for (tmpl, lineno) in self.include_stack]
line_comment += ' (via %s)' % ', '.join(reversed(ancestors))
print >> self.file, " "*indent + line + line_comment
print >> self.file, " " * indent + line + line_comment
class _TemplateReader(object):
@@ -651,9 +661,12 @@ class _TemplateReader(object):
if type(key) is slice:
size = len(self)
start, stop, step = key.indices(size)
if start is None: start = self.pos
else: start += self.pos
if stop is not None: stop += self.pos
if start is None:
start = self.pos
else:
start += self.pos
if stop is not None:
stop += self.pos
return self.text[slice(start, stop, step)]
elif key < 0:
return self.text[key]
@@ -670,7 +683,7 @@ def _format_code(code):
return "".join([format % (i + 1, line) for (i, line) in enumerate(lines)])
def _parse(reader, template, in_block=None):
def _parse(reader, template, in_block=None, in_loop=None):
body = _ChunkList([])
while True:
# Find next template directive
@@ -751,7 +764,7 @@ def _parse(reader, template, in_block=None):
# Intermediate ("else", "elif", etc) blocks
intermediate_blocks = {
"else": set(["if", "for", "while"]),
"else": set(["if", "for", "while", "try"]),
"elif": set(["if"]),
"except": set(["try"]),
"finally": set(["try"]),
@@ -796,7 +809,8 @@ def _parse(reader, template, in_block=None):
block = _Statement(suffix, line)
elif operator == "autoescape":
fn = suffix.strip()
if fn == "None": fn = None
if fn == "None":
fn = None
template.autoescape = fn
continue
elif operator == "raw":
@@ -808,7 +822,15 @@ def _parse(reader, template, in_block=None):
elif operator in ("apply", "block", "try", "if", "for", "while"):
# parse inner body recursively
block_body = _parse(reader, template, operator)
if operator in ("for", "while"):
block_body = _parse(reader, template, operator, operator)
elif operator == "apply":
# apply creates a nested function so syntactically it's not
# in the loop.
block_body = _parse(reader, template, operator, None)
else:
block_body = _parse(reader, template, operator, in_loop)
if operator == "apply":
if not suffix:
raise ParseError("apply missing method name on line %d" % line)
@@ -822,5 +844,11 @@ def _parse(reader, template, in_block=None):
body.chunks.append(block)
continue
elif operator in ("break", "continue"):
if not in_loop:
raise ParseError("%s outside %s block" % (operator, set(["for", "while"])))
body.chunks.append(_Statement(contents, line))
continue
else:
raise ParseError("unknown operator: %r" % operator)
Regular → Executable
+85 -21
View File
@@ -18,12 +18,13 @@ inheritance. See the docstrings for each class/function below for more
information.
"""
from __future__ import with_statement
from __future__ import absolute_import, division, with_statement
from cStringIO import StringIO
try:
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.ioloop import IOLoop
except ImportError:
# These modules are not importable on app engine. Parts of this module
@@ -31,15 +32,20 @@ except ImportError:
AsyncHTTPClient = None
HTTPServer = None
IOLoop = None
SimpleAsyncHTTPClient = None
from tornado.stack_context import StackContext, NullContext
from tornado.util import raise_exc_info
import contextlib
import logging
import os
import signal
import sys
import time
import unittest
_next_port = 10000
def get_unused_port():
"""Returns a (hopefully) unused port number."""
global _next_port
@@ -47,6 +53,7 @@ def get_unused_port():
_next_port = _next_port + 1
return port
class AsyncTestCase(unittest.TestCase):
"""TestCase subclass for testing IOLoop-based asynchronous code.
@@ -104,6 +111,7 @@ class AsyncTestCase(unittest.TestCase):
self.__running = False
self.__failure = None
self.__stop_args = None
self.__timeout = None
def setUp(self):
super(AsyncTestCase, self).setUp()
@@ -134,9 +142,18 @@ class AsyncTestCase(unittest.TestCase):
self.__failure = sys.exc_info()
self.stop()
def __rethrow(self):
if self.__failure is not None:
failure = self.__failure
self.__failure = None
raise_exc_info(failure)
def run(self, result=None):
with StackContext(self._stack_context):
super(AsyncTestCase, self).run(result)
# In case an exception escaped super.run or the StackContext caught
# an exception when there wasn't a wait() to re-raise it, do so here.
self.__rethrow()
def stop(self, _arg=None, **kwargs):
'''Stops the ioloop, causing one pending (or future) call to wait()
@@ -165,12 +182,14 @@ class AsyncTestCase(unittest.TestCase):
def timeout_func():
try:
raise self.failureException(
'Async operation timed out after %d seconds' %
'Async operation timed out after %s seconds' %
timeout)
except Exception:
self.__failure = sys.exc_info()
self.stop()
self.io_loop.add_timeout(time.time() + timeout, timeout_func)
if self.__timeout is not None:
self.io_loop.remove_timeout(self.__timeout)
self.__timeout = self.io_loop.add_timeout(time.time() + timeout, timeout_func)
while True:
self.__running = True
with NullContext():
@@ -183,13 +202,7 @@ class AsyncTestCase(unittest.TestCase):
break
assert self.__stopped
self.__stopped = False
if self.__failure is not None:
# 2to3 isn't smart enough to convert three-argument raise
# statements correctly in some cases.
if isinstance(self.__failure[1], self.__failure[0]):
raise self.__failure[1], None, self.__failure[2]
else:
raise self.__failure[0], self.__failure[1], self.__failure[2]
self.__rethrow()
result = self.__stop_args
self.__stop_args = None
return result
@@ -222,12 +235,19 @@ class AsyncHTTPTestCase(AsyncTestCase):
super(AsyncHTTPTestCase, self).setUp()
self.__port = None
self.http_client = AsyncHTTPClient(io_loop=self.io_loop)
self.http_client = self.get_http_client()
self._app = self.get_app()
self.http_server = HTTPServer(self._app, io_loop=self.io_loop,
**self.get_httpserver_options())
self.http_server = self.get_http_server()
self.http_server.listen(self.get_http_port(), address="127.0.0.1")
def get_http_client(self):
return AsyncHTTPClient(io_loop=self.io_loop)
def get_http_server(self):
return HTTPServer(self._app, io_loop=self.io_loop,
**self.get_httpserver_options())
def get_app(self):
"""Should be overridden by subclasses to return a
tornado.web.Application or other HTTPServer callback.
@@ -247,12 +267,12 @@ class AsyncHTTPTestCase(AsyncTestCase):
def get_httpserver_options(self):
"""May be overridden by subclasses to return additional
keyword arguments for HTTPServer.
keyword arguments for the server.
"""
return {}
def get_http_port(self):
"""Returns the port used by the HTTPServer.
"""Returns the port used by the server.
A new port is chosen for each test.
"""
@@ -260,15 +280,53 @@ class AsyncHTTPTestCase(AsyncTestCase):
self.__port = get_unused_port()
return self.__port
def get_protocol(self):
return 'http'
def get_url(self, path):
"""Returns an absolute url for the given path on the test server."""
return 'http://localhost:%s%s' % (self.get_http_port(), path)
return '%s://localhost:%s%s' % (self.get_protocol(),
self.get_http_port(), path)
def tearDown(self):
self.http_server.stop()
self.http_client.close()
super(AsyncHTTPTestCase, self).tearDown()
class AsyncHTTPSTestCase(AsyncHTTPTestCase):
"""A test case that starts an HTTPS server.
Interface is generally the same as `AsyncHTTPTestCase`.
"""
def get_http_client(self):
# Some versions of libcurl have deadlock bugs with ssl,
# so always run these tests with SimpleAsyncHTTPClient.
return SimpleAsyncHTTPClient(io_loop=self.io_loop, force_instance=True)
def get_httpserver_options(self):
return dict(ssl_options=self.get_ssl_options())
def get_ssl_options(self):
"""May be overridden by subclasses to select SSL options.
By default includes a self-signed testing certificate.
"""
# Testing keys were generated with:
# openssl req -new -keyout tornado/test/test.key -out tornado/test/test.crt -nodes -days 3650 -x509
module_dir = os.path.dirname(__file__)
return dict(
certfile=os.path.join(module_dir, 'test', 'test.crt'),
keyfile=os.path.join(module_dir, 'test', 'test.key'))
def get_protocol(self):
return 'https'
def fetch(self, path, **kwargs):
return AsyncHTTPTestCase.fetch(self, path, validate_cert=False,
**kwargs)
class LogTrapTestCase(unittest.TestCase):
"""A test case that captures and discards all logging output
if the test passes.
@@ -308,7 +366,8 @@ class LogTrapTestCase(unittest.TestCase):
finally:
handler.stream = old_stream
def main():
def main(**kwargs):
"""A simple test runner.
This test runner is essentially equivalent to `unittest.main` from
@@ -329,10 +388,15 @@ def main():
be overridden by naming a single test on the command line::
# Runs all tests
tornado/test/runtests.py
python -m tornado.test.runtests
# Runs one test
tornado/test/runtests.py tornado.test.stack_context_test
python -m tornado.test.runtests tornado.test.stack_context_test
Additional keyword arguments passed through to ``unittest.main()``.
For example, use ``tornado.testing.main(verbosity=2)``
to show many test details as they are run.
See http://docs.python.org/library/unittest.html#unittest.main
for full argument list.
"""
from tornado.options import define, options, parse_command_line
@@ -364,9 +428,9 @@ def main():
# test discovery, which is incompatible with auto2to3), so don't
# set module if we're not asking for a specific test.
if len(argv) > 1:
unittest.main(module=None, argv=argv)
unittest.main(module=None, argv=argv, **kwargs)
else:
unittest.main(defaultTest="all", argv=argv)
unittest.main(defaultTest="all", argv=argv, **kwargs)
except SystemExit, e:
if e.code == 0:
logging.info('PASS')
Regular → Executable
+54
View File
@@ -1,5 +1,10 @@
"""Miscellaneous utility functions."""
from __future__ import absolute_import, division, with_statement
import zlib
class ObjectDict(dict):
"""Makes a dictionary behave like an object."""
def __getattr__(self, name):
@@ -12,6 +17,36 @@ class ObjectDict(dict):
self[name] = value
class GzipDecompressor(object):
"""Streaming gzip decompressor.
The interface is like that of `zlib.decompressobj` (without the
optional arguments, but it understands gzip headers and checksums.
"""
def __init__(self):
# Magic parameter makes zlib module understand gzip header
# http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
# This works on cpython and pypy, but not jython.
self.decompressobj = zlib.decompressobj(16 + zlib.MAX_WBITS)
def decompress(self, value):
"""Decompress a chunk, returning newly-available data.
Some data may be buffered for later processing; `flush` must
be called when there is no more input data to ensure that
all data was processed.
"""
return self.decompressobj.decompress(value)
def flush(self):
"""Return any remaining buffered data not yet returned by decompress.
Also checks for errors such as truncated input.
No other methods may be called on this object after `flush`.
"""
return self.decompressobj.flush()
def import_object(name):
"""Imports an object by name.
@@ -42,6 +77,25 @@ else:
return s
bytes_type = str
def raise_exc_info(exc_info):
"""Re-raise an exception (with original traceback) from an exc_info tuple.
The argument is a ``(type, value, traceback)`` tuple as returned by
`sys.exc_info`.
"""
# 2to3 isn't smart enough to convert three-argument raise
# statements correctly in some cases.
if isinstance(exc_info[1], exc_info[0]):
raise exc_info[1], None, exc_info[2]
# After 2to3: raise exc_info[1].with_traceback(exc_info[2])
else:
# I think this branch is only taken for string exceptions,
# which were removed in Python 2.6.
raise exc_info[0], exc_info[1], exc_info[2]
# After 2to3: raise exc_info[0](exc_info[1]).with_traceback(exc_info[2])
def doctests():
import doctest
return doctest.DocTestSuite()
Regular → Executable
+147 -72
View File
@@ -49,7 +49,7 @@ threads it is important to use IOLoop.add_callback to transfer control
back to the main thread before finishing the request.
"""
from __future__ import with_statement
from __future__ import absolute_import, division, with_statement
import Cookie
import base64
@@ -83,13 +83,14 @@ from tornado import locale
from tornado import stack_context
from tornado import template
from tornado.escape import utf8, _unicode
from tornado.util import b, bytes_type, import_object, ObjectDict
from tornado.util import b, bytes_type, import_object, ObjectDict, raise_exc_info
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
class RequestHandler(object):
"""Subclass this class and define get() or post() to make a handler.
@@ -97,7 +98,8 @@ class RequestHandler(object):
should override the class variable SUPPORTED_METHODS in your
RequestHandler class.
"""
SUPPORTED_METHODS = ("GET", "HEAD", "POST", "DELETE", "PUT", "OPTIONS")
SUPPORTED_METHODS = ("GET", "HEAD", "POST", "DELETE", "PATCH", "PUT",
"OPTIONS")
_template_loaders = {} # {path: template.BaseLoader}
_template_loader_lock = threading.Lock()
@@ -121,7 +123,7 @@ class RequestHandler(object):
self.ui["modules"] = self.ui["_modules"]
self.clear()
# Check since connection is not available in WSGI
if hasattr(self.request, "connection"):
if getattr(self.request, "connection", None):
self.request.connection.stream.set_close_callback(
self.on_connection_close)
self.initialize(**kwargs)
@@ -164,6 +166,9 @@ class RequestHandler(object):
def delete(self, *args, **kwargs):
raise HTTPError(405)
def patch(self, *args, **kwargs):
raise HTTPError(405)
def put(self, *args, **kwargs):
raise HTTPError(405)
@@ -208,7 +213,7 @@ class RequestHandler(object):
"""Resets all headers and content for this response."""
# The performance cost of tornado.httputil.HTTPHeaders is significant
# (slowing down a benchmark with a trivial handler by more than 10%),
# and its case-normalization is not generally necessary for
# and its case-normalization is not generally necessary for
# headers we generate on the server side, so use a plain dict
# and list instead.
self._headers = {
@@ -259,6 +264,15 @@ class RequestHandler(object):
"""
self._list_headers.append((name, self._convert_header_value(value)))
def clear_header(self, name):
"""Clears an outgoing header, undoing a previous `set_header` call.
Note that this method does not apply to multi-valued headers
set by `add_header`.
"""
if name in self._headers:
del self._headers[name]
def _convert_header_value(self, value):
if isinstance(value, bytes_type):
pass
@@ -279,8 +293,8 @@ class RequestHandler(object):
raise ValueError("Unsafe header value %r", value)
return value
_ARG_DEFAULT = []
def get_argument(self, name, default=_ARG_DEFAULT, strip=True):
"""Returns the value of the argument with the given name.
@@ -358,25 +372,27 @@ class RequestHandler(object):
if re.search(r"[\x00-\x20]", name + value):
# Don't let us accidentally inject bad stuff
raise ValueError("Invalid cookie %r: %r" % (name, value))
if not hasattr(self, "_new_cookies"):
self._new_cookies = []
new_cookie = Cookie.SimpleCookie()
self._new_cookies.append(new_cookie)
new_cookie[name] = value
if not hasattr(self, "_new_cookie"):
self._new_cookie = Cookie.SimpleCookie()
if name in self._new_cookie:
del self._new_cookie[name]
self._new_cookie[name] = value
morsel = self._new_cookie[name]
if domain:
new_cookie[name]["domain"] = domain
morsel["domain"] = domain
if expires_days is not None and not expires:
expires = datetime.datetime.utcnow() + datetime.timedelta(
days=expires_days)
if expires:
timestamp = calendar.timegm(expires.utctimetuple())
new_cookie[name]["expires"] = email.utils.formatdate(
morsel["expires"] = email.utils.formatdate(
timestamp, localtime=False, usegmt=True)
if path:
new_cookie[name]["path"] = path
morsel["path"] = path
for k, v in kwargs.iteritems():
if k == 'max_age': k = 'max-age'
new_cookie[name][k] = v
if k == 'max_age':
k = 'max-age'
morsel[k] = v
def clear_cookie(self, name, path="/", domain=None):
"""Deletes the cookie with the given name."""
@@ -401,6 +417,9 @@ class RequestHandler(object):
Note that the ``expires_days`` parameter sets the lifetime of the
cookie in the browser, but is independent of the ``max_age_days``
parameter to `get_secure_cookie`.
Secure cookies may contain arbitrary byte values, not just unicode
strings (unlike regular cookies)
"""
self.set_cookie(name, self.create_signed_value(name, value),
expires_days=expires_days, **kwargs)
@@ -417,9 +436,14 @@ class RequestHandler(object):
name, value)
def get_secure_cookie(self, name, value=None, max_age_days=31):
"""Returns the given signed cookie if it validates, or None."""
"""Returns the given signed cookie if it validates, or None.
The decoded cookie value is returned as a byte string (unlike
`get_cookie`).
"""
self.require_setting("cookie_secret", "secure cookies")
if value is None: value = self.get_cookie(name)
if value is None:
value = self.get_cookie(name)
return decode_signed_value(self.application.settings["cookie_secret"],
name, value, max_age_days=max_age_days)
@@ -482,7 +506,8 @@ class RequestHandler(object):
html_bodies = []
for module in getattr(self, "_active_modules", {}).itervalues():
embed_part = module.embedded_javascript()
if embed_part: js_embed.append(utf8(embed_part))
if embed_part:
js_embed.append(utf8(embed_part))
file_part = module.javascript_files()
if file_part:
if isinstance(file_part, (unicode, bytes_type)):
@@ -490,7 +515,8 @@ class RequestHandler(object):
else:
js_files.extend(file_part)
embed_part = module.embedded_css()
if embed_part: css_embed.append(utf8(embed_part))
if embed_part:
css_embed.append(utf8(embed_part))
file_part = module.css_files()
if file_part:
if isinstance(file_part, (unicode, bytes_type)):
@@ -498,9 +524,12 @@ class RequestHandler(object):
else:
css_files.extend(file_part)
head_part = module.html_head()
if head_part: html_heads.append(utf8(head_part))
if head_part:
html_heads.append(utf8(head_part))
body_part = module.html_body()
if body_part: html_bodies.append(utf8(body_part))
if body_part:
html_bodies.append(utf8(body_part))
def is_absolute(path):
return any(path.startswith(x) for x in ["/", "http:", "https:"])
if js_files:
@@ -579,7 +608,7 @@ class RequestHandler(object):
_=self.locale.translate,
static_url=self.static_url,
xsrf_form_html=self.xsrf_form_html,
reverse_url=self.application.reverse_url
reverse_url=self.reverse_url
)
args.update(self.ui)
args.update(kwargs)
@@ -596,10 +625,9 @@ class RequestHandler(object):
kwargs["autoescape"] = settings["autoescape"]
return template.Loader(template_path, **kwargs)
def flush(self, include_footers=False, callback=None):
"""Flushes the current output buffer to the network.
The ``callback`` argument, if given, can be used for flow control:
it will be run when all flushed data has been written to the socket.
Note that only one flush callback can be outstanding at a time;
@@ -614,8 +642,9 @@ class RequestHandler(object):
if not self._headers_written:
self._headers_written = True
for transform in self._transforms:
self._headers, chunk = transform.transform_first_chunk(
self._headers, chunk, include_footers)
self._status_code, self._headers, chunk = \
transform.transform_first_chunk(
self._status_code, self._headers, chunk, include_footers)
headers = self._generate_headers()
else:
for transform in self._transforms:
@@ -624,11 +653,11 @@ class RequestHandler(object):
# Ignore the chunk and only write the headers for HEAD requests
if self.request.method == "HEAD":
if headers: self.request.write(headers, callback=callback)
if headers:
self.request.write(headers, callback=callback)
return
if headers or chunk:
self.request.write(headers + chunk, callback=callback)
self.request.write(headers + chunk, callback=callback)
def finish(self, chunk=None):
"""Finishes this response, ending the HTTP request."""
@@ -637,7 +666,8 @@ class RequestHandler(object):
"by using async operations without the "
"@asynchronous decorator.")
if chunk is not None: self.write(chunk)
if chunk is not None:
self.write(chunk)
# Automatically support ETags and add the Content-Length header if
# we have not flushed any content yet.
@@ -647,13 +677,15 @@ class RequestHandler(object):
"Etag" not in self._headers):
etag = self.compute_etag()
if etag is not None:
self.set_header("Etag", etag)
inm = self.request.headers.get("If-None-Match")
if inm and inm.find(etag) != -1:
self._write_buffer = []
self.set_status(304)
else:
self.set_header("Etag", etag)
if "Content-Length" not in self._headers:
if self._status_code == 304:
assert not self._write_buffer, "Cannot send body with 304"
self._clear_headers_for_304()
elif "Content-Length" not in self._headers:
content_length = sum(len(part) for part in self._write_buffer)
self.set_header("Content-Length", content_length)
@@ -720,7 +752,7 @@ class RequestHandler(object):
kwargs['exception'] = exc_info[1]
try:
# Put the traceback into sys.exc_info()
raise exc_info[0], exc_info[1], exc_info[2]
raise_exc_info(exc_info)
except Exception:
self.finish(self.get_error_html(status_code, **kwargs))
else:
@@ -733,7 +765,7 @@ class RequestHandler(object):
self.write(line)
self.finish()
else:
self.finish("<html><title>%(code)d: %(message)s</title>"
self.finish("<html><title>%(code)d: %(message)s</title>"
"<body>%(code)d: %(message)s</body></html>" % {
"code": status_code,
"message": httplib.responses[status_code],
@@ -926,6 +958,7 @@ class RequestHandler(object):
return None
if args or kwargs:
callback = functools.partial(callback, *args, **kwargs)
def wrapper(*args, **kwargs):
try:
return callback(*args, **kwargs)
@@ -964,7 +997,7 @@ class RequestHandler(object):
# the exception value instead of the full triple,
# so re-raise the exception to ensure that it's in
# sys.exc_info()
raise type, value, traceback
raise_exc_info((type, value, traceback))
except Exception:
self._handle_request_exception(value)
return True
@@ -984,7 +1017,7 @@ class RequestHandler(object):
if not self._finished:
args = [self.decode_argument(arg) for arg in args]
kwargs = dict((k, self.decode_argument(v, name=k))
for (k,v) in kwargs.iteritems())
for (k, v) in kwargs.iteritems())
getattr(self, self.request.method.lower())(*args, **kwargs)
if self._auto_finish and not self._finished:
self.finish()
@@ -995,10 +1028,10 @@ class RequestHandler(object):
lines = [utf8(self.request.version + " " +
str(self._status_code) +
" " + httplib.responses[self._status_code])]
lines.extend([(utf8(n) + b(": ") + utf8(v)) for n, v in
lines.extend([(utf8(n) + b(": ") + utf8(v)) for n, v in
itertools.chain(self._headers.iteritems(), self._list_headers)])
for cookie_dict in getattr(self, "_new_cookies", []):
for cookie in cookie_dict.values():
if hasattr(self, "_new_cookie"):
for cookie in self._new_cookie.values():
lines.append(utf8("Set-Cookie: " + cookie.OutputString(None)))
return b("\r\n").join(lines) + b("\r\n\r\n")
@@ -1044,6 +1077,17 @@ class RequestHandler(object):
def _ui_method(self, method):
return lambda *args, **kwargs: method(self, *args, **kwargs)
def _clear_headers_for_304(self):
# 304 responses should not contain entity headers (defined in
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.1)
# not explicitly allowed by
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.5
headers = ["Allow", "Content-Encoding", "Content-Language",
"Content-Length", "Content-MD5", "Content-Range",
"Content-Type", "Last-Modified"]
for h in headers:
self.clear_header(h)
def asynchronous(method):
"""Wrap request handler methods with this if they are asynchronous.
@@ -1088,8 +1132,9 @@ def removeslash(method):
if self.request.method in ("GET", "HEAD"):
uri = self.request.path.rstrip("/")
if uri: # don't try to redirect '/' to ''
if self.request.query: uri += "?" + self.request.query
self.redirect(uri)
if self.request.query:
uri += "?" + self.request.query
self.redirect(uri, permanent=True)
return
else:
raise HTTPError(404)
@@ -1109,8 +1154,9 @@ def addslash(method):
if not self.request.path.endswith("/"):
if self.request.method in ("GET", "HEAD"):
uri = self.request.path + "/"
if self.request.query: uri += "?" + self.request.query
self.redirect(uri)
if self.request.query:
uri += "?" + self.request.query
self.redirect(uri, permanent=True)
return
raise HTTPError(404)
return method(self, *args, **kwargs)
@@ -1162,9 +1208,15 @@ class Application(object):
.. attribute:: settings
Additonal keyword arguments passed to the constructor are saved in the
Additional keyword arguments passed to the constructor are saved in the
`settings` dictionary, and are often referred to in documentation as
"application settings".
.. attribute:: debug
If `True` the application runs in debug mode, described in
:ref:`debug-mode`. This is an application setting in the `settings`
dictionary, so handlers can access it.
"""
def __init__(self, handlers=None, default_host="", transforms=None,
wsgi=False, **settings):
@@ -1200,7 +1252,8 @@ class Application(object):
r"/(favicon\.ico)", r"/(robots\.txt)"]:
handlers.insert(0, (pattern, static_handler_class,
static_handler_args))
if handlers: self.add_handlers(".*$", handlers)
if handlers:
self.add_handlers(".*$", handlers)
# Automatically reload modified modules
if self.settings.get("debug") and not wsgi:
@@ -1292,7 +1345,8 @@ class Application(object):
self._load_ui_methods(dict((n, getattr(methods, n))
for n in dir(methods)))
elif isinstance(methods, list):
for m in methods: self._load_ui_methods(m)
for m in methods:
self._load_ui_methods(m)
else:
for name, fn in methods.iteritems():
if not name.startswith("_") and hasattr(fn, "__call__") \
@@ -1304,7 +1358,8 @@ class Application(object):
self._load_ui_modules(dict((n, getattr(modules, n))
for n in dir(modules)))
elif isinstance(modules, list):
for m in modules: self._load_ui_modules(m)
for m in modules:
self._load_ui_modules(m)
else:
assert isinstance(modules, dict)
for name, cls in modules.iteritems():
@@ -1333,7 +1388,8 @@ class Application(object):
# None-safe wrapper around url_unescape to handle
# unmatched optional groups correctly
def unquote(s):
if s is None: return s
if s is None:
return s
return escape.url_unescape(s, encoding=None)
# Pass matched groups to the handler. Since
# match.groups() includes both named and unnamed groups,
@@ -1343,7 +1399,7 @@ class Application(object):
if spec.regex.groupindex:
kwargs = dict(
(k, unquote(v))
(str(k), unquote(v))
for (k, v) in match.groupdict().iteritems())
else:
args = [unquote(s) for s in match.groups()]
@@ -1365,7 +1421,11 @@ class Application(object):
def reverse_url(self, name, *args):
"""Returns a URL path for handler named `name`
The handler must be added to the application as a named URLSpec
The handler must be added to the application as a named URLSpec.
Args will be substituted for capturing groups in the URLSpec regex.
They will be converted to strings if necessary, encoded as utf8,
and url-escaped.
"""
if name in self.named_handlers:
return self.named_handlers[name].reverse(*args)
@@ -1393,7 +1453,6 @@ class Application(object):
handler._request_summary(), request_time)
class HTTPError(Exception):
"""An exception that will turn into an HTTP error response."""
def __init__(self, status_code, log_message=None, *args):
@@ -1455,7 +1514,7 @@ class StaticFileHandler(RequestHandler):
/static/images/myimage.png?v=xxx. Override ``get_cache_time`` method for
more fine-grained cache control.
"""
CACHE_MAX_AGE = 86400*365*10 #10 years
CACHE_MAX_AGE = 86400 * 365 * 10 # 10 years
_static_hashes = {}
_lock = threading.Lock() # protects _static_hashes
@@ -1554,7 +1613,7 @@ class StaticFileHandler(RequestHandler):
This method may be overridden in subclasses (but note that it is
a class method rather than an instance method).
``settings`` is the `Application.settings` dictionary. ``path``
is the static path being requested. The url returned should be
relative to the current host.
@@ -1639,8 +1698,8 @@ class OutputTransform(object):
def __init__(self, request):
pass
def transform_first_chunk(self, headers, chunk, finishing):
return headers, chunk
def transform_first_chunk(self, status_code, headers, chunk, finishing):
return status_code, headers, chunk
def transform_chunk(self, chunk, finishing):
return chunk
@@ -1652,7 +1711,7 @@ class GZipContentEncoding(OutputTransform):
See http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.11
"""
CONTENT_TYPES = set([
"text/plain", "text/html", "text/css", "text/xml", "application/javascript",
"text/plain", "text/html", "text/css", "text/xml", "application/javascript",
"application/x-javascript", "application/xml", "application/atom+xml",
"text/javascript", "application/json", "application/xhtml+xml"])
MIN_LENGTH = 5
@@ -1661,7 +1720,7 @@ class GZipContentEncoding(OutputTransform):
self._gzipping = request.supports_http_1_1() and \
"gzip" in request.headers.get("Accept-Encoding", "")
def transform_first_chunk(self, headers, chunk, finishing):
def transform_first_chunk(self, status_code, headers, chunk, finishing):
if self._gzipping:
ctype = _unicode(headers.get("Content-Type", "")).split(";")[0]
self._gzipping = (ctype in self.CONTENT_TYPES) and \
@@ -1675,7 +1734,7 @@ class GZipContentEncoding(OutputTransform):
chunk = self.transform_chunk(chunk, finishing)
if "Content-Length" in headers:
headers["Content-Length"] = str(len(chunk))
return headers, chunk
return status_code, headers, chunk
def transform_chunk(self, chunk, finishing):
if self._gzipping:
@@ -1698,15 +1757,17 @@ class ChunkedTransferEncoding(OutputTransform):
def __init__(self, request):
self._chunking = request.supports_http_1_1()
def transform_first_chunk(self, headers, chunk, finishing):
if self._chunking:
def transform_first_chunk(self, status_code, headers, chunk, finishing):
# 304 responses have no body (not even a zero-length body), and so
# should not have either Content-Length or Transfer-Encoding headers.
if self._chunking and status_code != 304:
# No need to chunk the output if a Content-Length is specified
if "Content-Length" in headers or "Transfer-Encoding" in headers:
self._chunking = False
else:
headers["Transfer-Encoding"] = "chunked"
chunk = self.transform_chunk(chunk, finishing)
return headers, chunk
return status_code, headers, chunk
def transform_chunk(self, block, finishing):
if self._chunking:
@@ -1786,14 +1847,17 @@ class UIModule(object):
"""Renders a template and returns it as a string."""
return self.handler.render_string(path, **kwargs)
class _linkify(UIModule):
def render(self, text, **kwargs):
return escape.linkify(text, **kwargs)
class _xsrf_form_html(UIModule):
def render(self):
return self.handler.xsrf_form_html()
class TemplateModule(UIModule):
"""UIModule that simply renders the given template.
@@ -1806,7 +1870,7 @@ class TemplateModule(UIModule):
inside the template and give it keyword arguments corresponding to
the methods on UIModule: {{ set_resources(js_files=static_url("my.js")) }}
Note that these resources are output once per template file, not once
per instantiation of the template, so they must not depend on
per instantiation of the template, so they must not depend on
any arguments to the template.
"""
def __init__(self, handler):
@@ -1862,10 +1926,9 @@ class TemplateModule(UIModule):
return "".join(self._get_resources("html_body"))
class URLSpec(object):
"""Specifies mappings between URLs and handlers."""
def __init__(self, pattern, handler_class, kwargs={}, name=None):
def __init__(self, pattern, handler_class, kwargs=None, name=None):
"""Creates a URLSpec.
Parameters:
@@ -1889,7 +1952,7 @@ class URLSpec(object):
("groups in url regexes must either be all named or all "
"positional: %r" % self.regex.pattern)
self.handler_class = handler_class
self.kwargs = kwargs
self.kwargs = kwargs or {}
self.name = name
self._path, self._group_count = self._find_groups()
@@ -1928,7 +1991,12 @@ class URLSpec(object):
"not found"
if not len(args):
return self._path
return self._path % tuple([str(a) for a in args])
converted_args = []
for a in args:
if not isinstance(a, (unicode, bytes_type)):
a = str(a)
converted_args.append(escape.url_escape(utf8(a)))
return self._path % tuple(converted_args)
url = URLSpec
@@ -1938,13 +2006,14 @@ def _time_independent_equals(a, b):
return False
result = 0
if type(a[0]) is int: # python3 byte strings
for x, y in zip(a,b):
for x, y in zip(a, b):
result |= x ^ y
else: # python2
for x, y in zip(a, b):
result |= ord(x) ^ ord(y)
return result == 0
def create_signed_value(secret, name, value):
timestamp = utf8(str(int(time.time())))
value = base64.b64encode(utf8(value))
@@ -1952,10 +2021,13 @@ def create_signed_value(secret, name, value):
value = b("|").join([value, timestamp, signature])
return value
def decode_signed_value(secret, name, value, max_age_days=31):
if not value: return None
if not value:
return None
parts = utf8(value).split(b("|"))
if len(parts) != 3: return None
if len(parts) != 3:
return None
signature = _create_signature(secret, name, parts[0], parts[1])
if not _time_independent_equals(parts[2], signature):
logging.warning("Invalid cookie signature %r", value)
@@ -1974,12 +2046,15 @@ def decode_signed_value(secret, name, value, max_age_days=31):
return None
if parts[1].startswith(b("0")):
logging.warning("Tampered cookie %r", value)
return None
try:
return base64.b64decode(parts[0])
except Exception:
return None
def _create_signature(secret, *parts):
hash = hmac.new(utf8(secret), digestmod=hashlib.sha1)
for part in parts: hash.update(utf8(part))
for part in parts:
hash.update(utf8(part))
return utf8(hash.hexdigest())
Regular → Executable
+12 -8
View File
@@ -16,6 +16,8 @@ communication between the browser and server.
overriding `WebSocketHandler.allow_draft76` (see that method's
documentation for caveats).
"""
from __future__ import absolute_import, division, with_statement
# Author: Jacob Kristhammar, 2010
import array
@@ -30,6 +32,7 @@ import tornado.web
from tornado.util import bytes_type, b
class WebSocketHandler(tornado.web.RequestHandler):
"""Subclass this class to create a basic WebSocket handler.
@@ -202,7 +205,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
may wish to override this if they are using an SSL proxy
that does not provide the X-Scheme header as understood
by HTTPServer.
Note that this is only used by the draft76 protocol.
"""
return "wss" if self.request.protocol == "https" else "ws"
@@ -249,6 +252,7 @@ class WebSocketProtocol(object):
"""
if args or kwargs:
callback = functools.partial(callback, *args, **kwargs)
def wrapper(*args, **kwargs):
try:
return callback(*args, **kwargs)
@@ -471,7 +475,7 @@ class WebSocketProtocol13(WebSocketProtocol):
sha1 = hashlib.sha1()
sha1.update(tornado.escape.utf8(
self.request.headers.get("Sec-Websocket-Key")))
sha1.update(b("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) # Magic value
sha1.update(b("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) # Magic value
return tornado.escape.native_str(base64.b64encode(sha1.digest()))
def _accept_connection(self):
@@ -552,12 +556,12 @@ class WebSocketProtocol13(WebSocketProtocol):
self.stream.read_bytes(8, self._on_frame_length_64)
def _on_frame_length_16(self, data):
self._frame_length = struct.unpack("!H", data)[0];
self.stream.read_bytes(4, self._on_masking_key);
self._frame_length = struct.unpack("!H", data)[0]
self.stream.read_bytes(4, self._on_masking_key)
def _on_frame_length_64(self, data):
self._frame_length = struct.unpack("!Q", data)[0];
self.stream.read_bytes(4, self._on_masking_key);
self._frame_length = struct.unpack("!Q", data)[0]
self.stream.read_bytes(4, self._on_masking_key)
def _on_masking_key(self, data):
self._frame_mask = array.array("B", data)
@@ -604,9 +608,9 @@ class WebSocketProtocol13(WebSocketProtocol):
if not self.client_terminated:
self._receive_frame()
def _handle_message(self, opcode, data):
if self.client_terminated: return
if self.client_terminated:
return
if opcode == 0x1:
# UTF-8 data
Regular → Executable
+45 -28
View File
@@ -20,7 +20,7 @@ WSGI is the Python standard for web servers, and allows for interoperability
between Tornado and other Python web frameworks and servers. This module
provides WSGI support in two ways:
* `WSGIApplication` is a version of `tornado.web.Application` that can run
* `WSGIApplication` is a version of `tornado.web.Application` that can run
inside a WSGI server. This is useful for running a Tornado app on another
HTTP server, such as Google App Engine. See the `WSGIApplication` class
documentation for limitations that apply.
@@ -29,8 +29,9 @@ provides WSGI support in two ways:
and Tornado handlers in a single server.
"""
from __future__ import absolute_import, division, with_statement
import Cookie
import cgi
import httplib
import logging
import sys
@@ -41,14 +42,37 @@ import urllib
from tornado import escape
from tornado import httputil
from tornado import web
from tornado.escape import native_str, utf8
from tornado.util import b
from tornado.escape import native_str, utf8, parse_qs_bytes
from tornado.util import b, bytes_type
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
# PEP 3333 specifies that WSGI on python 3 generally deals with byte strings
# that are smuggled inside objects of type unicode (via the latin1 encoding).
# These functions are like those in the tornado.escape module, but defined
# here to minimize the temptation to use them in non-wsgi contexts.
if str is unicode:
def to_wsgi_str(s):
assert isinstance(s, bytes_type)
return s.decode('latin1')
def from_wsgi_str(s):
assert isinstance(s, str)
return s.encode('latin1')
else:
def to_wsgi_str(s):
assert isinstance(s, bytes_type)
return s
def from_wsgi_str(s):
assert isinstance(s, str)
return s
class WSGIApplication(web.Application):
"""A WSGI equivalent of `tornado.web.Application`.
@@ -81,7 +105,7 @@ class WSGIApplication(web.Application):
Since no asynchronous methods are available for WSGI applications, the
httpclient and auth modules are both not available for WSGI applications.
We support the same interface, but handlers running in a WSGIApplication
do not support flush() or asynchronous methods.
do not support flush() or asynchronous methods.
"""
def __init__(self, handlers=None, default_host="", **settings):
web.Application.__init__(self, handlers, default_host, transforms=[],
@@ -92,12 +116,12 @@ class WSGIApplication(web.Application):
assert handler._finished
status = str(handler._status_code) + " " + \
httplib.responses[handler._status_code]
headers = handler._headers.items()
for cookie_dict in getattr(handler, "_new_cookies", []):
for cookie in cookie_dict.values():
headers = handler._headers.items() + handler._list_headers
if hasattr(handler, "_new_cookie"):
for cookie in handler._new_cookie.values():
headers.append(("Set-Cookie", cookie.OutputString(None)))
start_response(status,
[(native_str(k), native_str(v)) for (k,v) in headers])
[(native_str(k), native_str(v)) for (k, v) in headers])
return handler._write_buffer
@@ -106,17 +130,18 @@ class HTTPRequest(object):
def __init__(self, environ):
"""Parses the given WSGI environ to construct the request."""
self.method = environ["REQUEST_METHOD"]
self.path = urllib.quote(environ.get("SCRIPT_NAME", ""))
self.path += urllib.quote(environ.get("PATH_INFO", ""))
self.path = urllib.quote(from_wsgi_str(environ.get("SCRIPT_NAME", "")))
self.path += urllib.quote(from_wsgi_str(environ.get("PATH_INFO", "")))
self.uri = self.path
self.arguments = {}
self.query = environ.get("QUERY_STRING", "")
if self.query:
self.uri += "?" + self.query
arguments = cgi.parse_qs(self.query)
arguments = parse_qs_bytes(native_str(self.query))
for name, values in arguments.iteritems():
values = [v for v in values if v]
if values: self.arguments[name] = values
if values:
self.arguments[name] = values
self.version = "HTTP/1.1"
self.headers = httputil.HTTPHeaders()
if environ.get("CONTENT_TYPE"):
@@ -140,18 +165,8 @@ class HTTPRequest(object):
# Parse request body
self.files = {}
content_type = self.headers.get("Content-Type", "")
if content_type.startswith("application/x-www-form-urlencoded"):
for name, values in cgi.parse_qs(self.body).iteritems():
self.arguments.setdefault(name, []).extend(values)
elif content_type.startswith("multipart/form-data"):
if 'boundary=' in content_type:
boundary = content_type.split('boundary=',1)[1]
if boundary:
httputil.parse_multipart_form_data(
utf8(boundary), self.body, self.arguments, self.files)
else:
logging.warning("Invalid multipart/form-data")
httputil.parse_body_arguments(self.headers.get("Content-Type", ""),
self.body, self.arguments, self.files)
self._start_time = time.time()
self._finish_time = None
@@ -215,6 +230,7 @@ class WSGIContainer(object):
def __call__(self, request):
data = {}
response = []
def start_response(status, response_headers, exc_info=None):
data["status"] = status
data["headers"] = response_headers
@@ -225,11 +241,12 @@ class WSGIContainer(object):
body = b("").join(response)
if hasattr(app_response, "close"):
app_response.close()
if not data: raise Exception("WSGI app did not call start_response")
if not data:
raise Exception("WSGI app did not call start_response")
status_code = int(data["status"].split()[0])
headers = data["headers"]
header_set = set(k.lower() for (k,v) in headers)
header_set = set(k.lower() for (k, v) in headers)
body = escape.utf8(body)
if "content-length" not in header_set:
headers.append(("Content-Length", str(len(body))))
@@ -261,7 +278,7 @@ class WSGIContainer(object):
environ = {
"REQUEST_METHOD": request.method,
"SCRIPT_NAME": "",
"PATH_INFO": urllib.unquote(request.path),
"PATH_INFO": to_wsgi_str(escape.url_unescape(request.path, encoding=None)),
"QUERY_STRING": request.query,
"REMOTE_ADDR": request.remote_ip,
"SERVER_NAME": host,