Update Tornado webserver to 4.2.dev1 (609dbb9).

Conflicts:
	CHANGES.md
This commit is contained in:
JackDandy 2015-04-27 20:06:19 +01:00
parent bed370f811
commit 84fb3e5df9
65 changed files with 4882 additions and 1164 deletions

View file

@ -1,5 +1,6 @@
### 0.x.x (2015-xx-xx xx:xx:xx UTC)
* Update Tornado webserver to 4.2.dev1 (609dbb9)
* Change network names to only display on top line of Day by Day layout on Episode View
* Reposition country part of network name into the hover over in Day by Day layout
* Add ToTV provider

View file

@ -25,5 +25,5 @@ from __future__ import absolute_import, division, print_function, with_statement
# is zero for an official release, positive for a development branch,
# or negative for a release candidate or beta (after the base version
# number has been incremented)
version = "4.1.dev1"
version_info = (4, 1, 0, -100)
version = "4.2.dev1"
version_info = (4, 2, 0, -100)

View file

@ -32,7 +32,9 @@ They all take slightly different arguments due to the fact all these
services implement authentication and authorization slightly differently.
See the individual service classes below for complete documentation.
Example usage for Google OpenID::
Example usage for Google OAuth:
.. testcode::
class GoogleOAuth2LoginHandler(tornado.web.RequestHandler,
tornado.auth.GoogleOAuth2Mixin):
@ -51,6 +53,10 @@ Example usage for Google OpenID::
response_type='code',
extra_params={'approval_prompt': 'auto'})
.. testoutput::
:hide:
.. versionchanged:: 4.0
All of the callback interfaces in this module are now guaranteed
to run their callback with an argument of ``None`` on error.
@ -69,7 +75,7 @@ import hmac
import time
import uuid
from tornado.concurrent import TracebackFuture, chain_future, return_future
from tornado.concurrent import TracebackFuture, return_future
from tornado import gen
from tornado import httpclient
from tornado import escape
@ -123,6 +129,7 @@ def _auth_return_future(f):
if callback is not None:
future.add_done_callback(
functools.partial(_auth_future_to_callback, callback))
def handle_exception(typ, value, tb):
if future.done():
return False
@ -138,9 +145,6 @@ def _auth_return_future(f):
class OpenIdMixin(object):
"""Abstract implementation of OpenID and Attribute Exchange.
See `GoogleMixin` below for a customized example (which also
includes OAuth support).
Class attributes:
* ``_OPENID_ENDPOINT``: the identity provider's URI.
@ -312,8 +316,7 @@ class OpenIdMixin(object):
class OAuthMixin(object):
"""Abstract implementation of OAuth 1.0 and 1.0a.
See `TwitterMixin` and `FriendFeedMixin` below for example implementations,
or `GoogleMixin` for an OAuth/OpenID hybrid.
See `TwitterMixin` below for an example implementation.
Class attributes:
@ -565,7 +568,8 @@ class OAuthMixin(object):
class OAuth2Mixin(object):
"""Abstract implementation of OAuth 2.0.
See `FacebookGraphMixin` below for an example implementation.
See `FacebookGraphMixin` or `GoogleOAuth2Mixin` below for example
implementations.
Class attributes:
@ -629,7 +633,9 @@ class TwitterMixin(OAuthMixin):
URL you registered as your application's callback URL.
When your application is set up, you can use this mixin like this
to authenticate the user with Twitter and get access to their stream::
to authenticate the user with Twitter and get access to their stream:
.. testcode::
class TwitterLoginHandler(tornado.web.RequestHandler,
tornado.auth.TwitterMixin):
@ -641,6 +647,9 @@ class TwitterMixin(OAuthMixin):
else:
yield self.authorize_redirect()
.. testoutput::
:hide:
The user object returned by `~OAuthMixin.get_authenticated_user`
includes the attributes ``username``, ``name``, ``access_token``,
and all of the custom Twitter user attributes described at
@ -689,7 +698,9 @@ class TwitterMixin(OAuthMixin):
`~OAuthMixin.get_authenticated_user`. The user returned through that
process includes an 'access_token' attribute that can be used
to make authenticated requests via this method. Example
usage::
usage:
.. testcode::
class MainHandler(tornado.web.RequestHandler,
tornado.auth.TwitterMixin):
@ -706,6 +717,9 @@ class TwitterMixin(OAuthMixin):
return
self.finish("Posted a message!")
.. testoutput::
:hide:
"""
if path.startswith('http:') or path.startswith('https:'):
# Raw urls are useful for e.g. search which doesn't follow the
@ -757,223 +771,6 @@ class TwitterMixin(OAuthMixin):
raise gen.Return(user)
class FriendFeedMixin(OAuthMixin):
"""FriendFeed OAuth authentication.
To authenticate with FriendFeed, register your application with
FriendFeed at http://friendfeed.com/api/applications. Then copy
your Consumer Key and Consumer Secret to the application
`~tornado.web.Application.settings` ``friendfeed_consumer_key``
and ``friendfeed_consumer_secret``. Use this mixin on the handler
for the URL you registered as your application's Callback URL.
When your application is set up, you can use this mixin like this
to authenticate the user with FriendFeed and get access to their feed::
class FriendFeedLoginHandler(tornado.web.RequestHandler,
tornado.auth.FriendFeedMixin):
@tornado.gen.coroutine
def get(self):
if self.get_argument("oauth_token", None):
user = yield self.get_authenticated_user()
# Save the user using e.g. set_secure_cookie()
else:
yield self.authorize_redirect()
The user object returned by `~OAuthMixin.get_authenticated_user()` includes the
attributes ``username``, ``name``, and ``description`` in addition to
``access_token``. You should save the access token with the user;
it is required to make requests on behalf of the user later with
`friendfeed_request()`.
"""
_OAUTH_VERSION = "1.0"
_OAUTH_REQUEST_TOKEN_URL = "https://friendfeed.com/account/oauth/request_token"
_OAUTH_ACCESS_TOKEN_URL = "https://friendfeed.com/account/oauth/access_token"
_OAUTH_AUTHORIZE_URL = "https://friendfeed.com/account/oauth/authorize"
_OAUTH_NO_CALLBACKS = True
_OAUTH_VERSION = "1.0"
@_auth_return_future
def friendfeed_request(self, path, callback, access_token=None,
post_args=None, **args):
"""Fetches the given relative API path, e.g., "/bret/friends"
If the request is a POST, ``post_args`` should be provided. Query
string arguments should be given as keyword arguments.
All the FriendFeed methods are documented at
http://friendfeed.com/api/documentation.
Many methods require an OAuth access token which you can
obtain through `~OAuthMixin.authorize_redirect` and
`~OAuthMixin.get_authenticated_user`. The user returned
through that process includes an ``access_token`` attribute that
can be used to make authenticated requests via this
method.
Example usage::
class MainHandler(tornado.web.RequestHandler,
tornado.auth.FriendFeedMixin):
@tornado.web.authenticated
@tornado.gen.coroutine
def get(self):
new_entry = yield self.friendfeed_request(
"/entry",
post_args={"body": "Testing Tornado Web Server"},
access_token=self.current_user["access_token"])
if not new_entry:
# Call failed; perhaps missing permission?
yield self.authorize_redirect()
return
self.finish("Posted a message!")
"""
# Add the OAuth resource request signature if we have credentials
url = "http://friendfeed-api.com/v2" + path
if access_token:
all_args = {}
all_args.update(args)
all_args.update(post_args or {})
method = "POST" if post_args is not None else "GET"
oauth = self._oauth_request_parameters(
url, access_token, all_args, method=method)
args.update(oauth)
if args:
url += "?" + urllib_parse.urlencode(args)
callback = functools.partial(self._on_friendfeed_request, callback)
http = self.get_auth_http_client()
if post_args is not None:
http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args),
callback=callback)
else:
http.fetch(url, callback=callback)
def _on_friendfeed_request(self, future, response):
if response.error:
future.set_exception(AuthError(
"Error response %s fetching %s" % (response.error,
response.request.url)))
return
future.set_result(escape.json_decode(response.body))
def _oauth_consumer_token(self):
self.require_setting("friendfeed_consumer_key", "FriendFeed OAuth")
self.require_setting("friendfeed_consumer_secret", "FriendFeed OAuth")
return dict(
key=self.settings["friendfeed_consumer_key"],
secret=self.settings["friendfeed_consumer_secret"])
@gen.coroutine
def _oauth_get_user_future(self, access_token, callback):
user = yield self.friendfeed_request(
"/feedinfo/" + access_token["username"],
include="id,name,description", access_token=access_token)
if user:
user["username"] = user["id"]
callback(user)
def _parse_user_response(self, callback, user):
if user:
user["username"] = user["id"]
callback(user)
class GoogleMixin(OpenIdMixin, OAuthMixin):
"""Google Open ID / OAuth authentication.
.. deprecated:: 4.0
New applications should use `GoogleOAuth2Mixin`
below instead of this class. As of May 19, 2014, Google has stopped
supporting registration-free authentication.
No application registration is necessary to use Google for
authentication or to access Google resources on behalf of a user.
Google implements both OpenID and OAuth in a hybrid mode. If you
just need the user's identity, use
`~OpenIdMixin.authenticate_redirect`. If you need to make
requests to Google on behalf of the user, use
`authorize_redirect`. On return, parse the response with
`~OpenIdMixin.get_authenticated_user`. We send a dict containing
the values for the user, including ``email``, ``name``, and
``locale``.
Example usage::
class GoogleLoginHandler(tornado.web.RequestHandler,
tornado.auth.GoogleMixin):
@tornado.gen.coroutine
def get(self):
if self.get_argument("openid.mode", None):
user = yield self.get_authenticated_user()
# Save the user with e.g. set_secure_cookie()
else:
yield self.authenticate_redirect()
"""
_OPENID_ENDPOINT = "https://www.google.com/accounts/o8/ud"
_OAUTH_ACCESS_TOKEN_URL = "https://www.google.com/accounts/OAuthGetAccessToken"
@return_future
def authorize_redirect(self, oauth_scope, callback_uri=None,
ax_attrs=["name", "email", "language", "username"],
callback=None):
"""Authenticates and authorizes for the given Google resource.
Some of the available resources which can be used in the ``oauth_scope``
argument are:
* Gmail Contacts - http://www.google.com/m8/feeds/
* Calendar - http://www.google.com/calendar/feeds/
* Finance - http://finance.google.com/finance/feeds/
You can authorize multiple resources by separating the resource
URLs with a space.
.. versionchanged:: 3.1
Returns a `.Future` and takes an optional callback. These are
not strictly necessary as this method is synchronous,
but they are supplied for consistency with
`OAuthMixin.authorize_redirect`.
"""
callback_uri = callback_uri or self.request.uri
args = self._openid_args(callback_uri, ax_attrs=ax_attrs,
oauth_scope=oauth_scope)
self.redirect(self._OPENID_ENDPOINT + "?" + urllib_parse.urlencode(args))
callback()
@_auth_return_future
def get_authenticated_user(self, callback):
"""Fetches the authenticated user data upon redirect."""
# Look to see if we are doing combined OpenID/OAuth
oauth_ns = ""
for name, values in self.request.arguments.items():
if name.startswith("openid.ns.") and \
values[-1] == b"http://specs.openid.net/extensions/oauth/1.0":
oauth_ns = name[10:]
break
token = self.get_argument("openid." + oauth_ns + ".request_token", "")
if token:
http = self.get_auth_http_client()
token = dict(key=token, secret="")
http.fetch(self._oauth_access_token_url(token),
functools.partial(self._on_access_token, callback))
else:
chain_future(OpenIdMixin.get_authenticated_user(self),
callback)
def _oauth_consumer_token(self):
self.require_setting("google_consumer_key", "Google OAuth")
self.require_setting("google_consumer_secret", "Google OAuth")
return dict(
key=self.settings["google_consumer_key"],
secret=self.settings["google_consumer_secret"])
def _oauth_get_user_future(self, access_token):
return OpenIdMixin.get_authenticated_user(self)
class GoogleOAuth2Mixin(OAuth2Mixin):
"""Google authentication using OAuth2.
@ -1001,7 +798,9 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
def get_authenticated_user(self, redirect_uri, code, callback):
"""Handles the login for the Google user, returning a user object.
Example usage::
Example usage:
.. testcode::
class GoogleOAuth2LoginHandler(tornado.web.RequestHandler,
tornado.auth.GoogleOAuth2Mixin):
@ -1019,6 +818,10 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
scope=['profile', 'email'],
response_type='code',
extra_params={'approval_prompt': 'auto'})
.. testoutput::
:hide:
"""
http = self.get_auth_http_client()
body = urllib_parse.urlencode({
@ -1051,217 +854,6 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
return httpclient.AsyncHTTPClient()
class FacebookMixin(object):
"""Facebook Connect authentication.
.. deprecated:: 1.1
New applications should use `FacebookGraphMixin`
below instead of this class. This class does not support the
Future-based interface seen on other classes in this module.
To authenticate with Facebook, register your application with
Facebook at http://www.facebook.com/developers/apps.php. Then
copy your API Key and Application Secret to the application settings
``facebook_api_key`` and ``facebook_secret``.
When your application is set up, you can use this mixin like this
to authenticate the user with Facebook::
class FacebookHandler(tornado.web.RequestHandler,
tornado.auth.FacebookMixin):
@tornado.web.asynchronous
def get(self):
if self.get_argument("session", None):
self.get_authenticated_user(self._on_auth)
return
yield self.authenticate_redirect()
def _on_auth(self, user):
if not user:
raise tornado.web.HTTPError(500, "Facebook auth failed")
# Save the user using, e.g., set_secure_cookie()
The user object returned by `get_authenticated_user` includes the
attributes ``facebook_uid`` and ``name`` in addition to session attributes
like ``session_key``. You should save the session key with the user; it is
required to make requests on behalf of the user later with
`facebook_request`.
"""
@return_future
def authenticate_redirect(self, callback_uri=None, cancel_uri=None,
extended_permissions=None, callback=None):
"""Authenticates/installs this app for the current user.
.. versionchanged:: 3.1
Returns a `.Future` and takes an optional callback. These are
not strictly necessary as this method is synchronous,
but they are supplied for consistency with
`OAuthMixin.authorize_redirect`.
"""
self.require_setting("facebook_api_key", "Facebook Connect")
callback_uri = callback_uri or self.request.uri
args = {
"api_key": self.settings["facebook_api_key"],
"v": "1.0",
"fbconnect": "true",
"display": "page",
"next": urlparse.urljoin(self.request.full_url(), callback_uri),
"return_session": "true",
}
if cancel_uri:
args["cancel_url"] = urlparse.urljoin(
self.request.full_url(), cancel_uri)
if extended_permissions:
if isinstance(extended_permissions, (unicode_type, bytes)):
extended_permissions = [extended_permissions]
args["req_perms"] = ",".join(extended_permissions)
self.redirect("http://www.facebook.com/login.php?" +
urllib_parse.urlencode(args))
callback()
def authorize_redirect(self, extended_permissions, callback_uri=None,
cancel_uri=None, callback=None):
"""Redirects to an authorization request for the given FB resource.
The available resource names are listed at
http://wiki.developers.facebook.com/index.php/Extended_permission.
The most common resource types include:
* publish_stream
* read_stream
* email
* sms
extended_permissions can be a single permission name or a list of
names. To get the session secret and session key, call
get_authenticated_user() just as you would with
authenticate_redirect().
.. versionchanged:: 3.1
Returns a `.Future` and takes an optional callback. These are
not strictly necessary as this method is synchronous,
but they are supplied for consistency with
`OAuthMixin.authorize_redirect`.
"""
return self.authenticate_redirect(callback_uri, cancel_uri,
extended_permissions,
callback=callback)
def get_authenticated_user(self, callback):
"""Fetches the authenticated Facebook user.
The authenticated user includes the special Facebook attributes
'session_key' and 'facebook_uid' in addition to the standard
user attributes like 'name'.
"""
self.require_setting("facebook_api_key", "Facebook Connect")
session = escape.json_decode(self.get_argument("session"))
self.facebook_request(
method="facebook.users.getInfo",
callback=functools.partial(
self._on_get_user_info, callback, session),
session_key=session["session_key"],
uids=session["uid"],
fields="uid,first_name,last_name,name,locale,pic_square,"
"profile_url,username")
def facebook_request(self, method, callback, **args):
"""Makes a Facebook API REST request.
We automatically include the Facebook API key and signature, but
it is the callers responsibility to include 'session_key' and any
other required arguments to the method.
The available Facebook methods are documented here:
http://wiki.developers.facebook.com/index.php/API
Here is an example for the stream.get() method::
class MainHandler(tornado.web.RequestHandler,
tornado.auth.FacebookMixin):
@tornado.web.authenticated
@tornado.web.asynchronous
def get(self):
self.facebook_request(
method="stream.get",
callback=self._on_stream,
session_key=self.current_user["session_key"])
def _on_stream(self, stream):
if stream is None:
# Not authorized to read the stream yet?
self.redirect(self.authorize_redirect("read_stream"))
return
self.render("stream.html", stream=stream)
"""
self.require_setting("facebook_api_key", "Facebook Connect")
self.require_setting("facebook_secret", "Facebook Connect")
if not method.startswith("facebook."):
method = "facebook." + method
args["api_key"] = self.settings["facebook_api_key"]
args["v"] = "1.0"
args["method"] = method
args["call_id"] = str(long(time.time() * 1e6))
args["format"] = "json"
args["sig"] = self._signature(args)
url = "http://api.facebook.com/restserver.php?" + \
urllib_parse.urlencode(args)
http = self.get_auth_http_client()
http.fetch(url, callback=functools.partial(
self._parse_response, callback))
def _on_get_user_info(self, callback, session, users):
if users is None:
callback(None)
return
callback({
"name": users[0]["name"],
"first_name": users[0]["first_name"],
"last_name": users[0]["last_name"],
"uid": users[0]["uid"],
"locale": users[0]["locale"],
"pic_square": users[0]["pic_square"],
"profile_url": users[0]["profile_url"],
"username": users[0].get("username"),
"session_key": session["session_key"],
"session_expires": session.get("expires"),
})
def _parse_response(self, callback, response):
if response.error:
gen_log.warning("HTTP error from Facebook: %s", response.error)
callback(None)
return
try:
json = escape.json_decode(response.body)
except Exception:
gen_log.warning("Invalid JSON from Facebook: %r", response.body)
callback(None)
return
if isinstance(json, dict) and json.get("error_code"):
gen_log.warning("Facebook error: %d: %r", json["error_code"],
json.get("error_msg"))
callback(None)
return
callback(json)
def _signature(self, args):
parts = ["%s=%s" % (n, args[n]) for n in sorted(args.keys())]
body = "".join(parts) + self.settings["facebook_secret"]
if isinstance(body, unicode_type):
body = body.encode("utf-8")
return hashlib.md5(body).hexdigest()
def get_auth_http_client(self):
"""Returns the `.AsyncHTTPClient` instance to be used for auth requests.
May be overridden by subclasses to use an HTTP client other than
the default.
"""
return httpclient.AsyncHTTPClient()
class FacebookGraphMixin(OAuth2Mixin):
"""Facebook authentication using the new Graph API and OAuth2."""
_OAUTH_ACCESS_TOKEN_URL = "https://graph.facebook.com/oauth/access_token?"
@ -1274,9 +866,12 @@ class FacebookGraphMixin(OAuth2Mixin):
code, callback, extra_fields=None):
"""Handles the login for the Facebook user, returning a user object.
Example usage::
Example usage:
class FacebookGraphLoginHandler(LoginHandler, tornado.auth.FacebookGraphMixin):
.. testcode::
class FacebookGraphLoginHandler(tornado.web.RequestHandler,
tornado.auth.FacebookGraphMixin):
@tornado.gen.coroutine
def get(self):
if self.get_argument("code", False):
@ -1291,6 +886,10 @@ class FacebookGraphMixin(OAuth2Mixin):
redirect_uri='/auth/facebookgraph/',
client_id=self.settings["facebook_api_key"],
extra_params={"scope": "read_stream,offline_access"})
.. testoutput::
:hide:
"""
http = self.get_auth_http_client()
args = {
@ -1358,7 +957,9 @@ class FacebookGraphMixin(OAuth2Mixin):
process includes an ``access_token`` attribute that can be
used to make authenticated requests via this method.
Example usage::
Example usage:
..testcode::
class MainHandler(tornado.web.RequestHandler,
tornado.auth.FacebookGraphMixin):
@ -1376,6 +977,9 @@ class FacebookGraphMixin(OAuth2Mixin):
return
self.finish("Posted a message!")
.. testoutput::
:hide:
The given path is relative to ``self._FACEBOOK_BASE_URL``,
by default "https://graph.facebook.com".

View file

@ -100,6 +100,14 @@ try:
except ImportError:
signal = None
# os.execv is broken on Windows and can't properly parse command line
# arguments and executable name if they contain whitespaces. subprocess
# fixes that behavior.
# This distinction is also important because when we use execv, we want to
# close the IOLoop and all its file descriptors, to guard against any
# file descriptors that were not set CLOEXEC. When execv is not available,
# we must not close the IOLoop because we want the process to exit cleanly.
_has_execv = sys.platform != 'win32'
_watched_files = set()
_reload_hooks = []
@ -108,13 +116,18 @@ _io_loops = weakref.WeakKeyDictionary()
def start(io_loop=None, check_time=500):
"""Begins watching source files for changes using the given `.IOLoop`. """
"""Begins watching source files for changes.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
io_loop = io_loop or ioloop.IOLoop.current()
if io_loop in _io_loops:
return
_io_loops[io_loop] = True
if len(_io_loops) > 1:
gen_log.warning("tornado.autoreload started more than once in the same process")
if _has_execv:
add_reload_hook(functools.partial(io_loop.close, all_fds=True))
modify_times = {}
callback = functools.partial(_reload_on_update, modify_times)
@ -162,7 +175,7 @@ def _reload_on_update(modify_times):
# processes restarted themselves, they'd all restart and then
# all call fork_processes again.
return
for module in sys.modules.values():
for module in list(sys.modules.values()):
# Some modules play games with sys.modules (e.g. email/__init__.py
# in the standard library), and occasionally this can cause strange
# failures in getattr. Just ignore anything that's not an ordinary
@ -211,10 +224,7 @@ def _reload():
not os.environ.get("PYTHONPATH", "").startswith(path_prefix)):
os.environ["PYTHONPATH"] = (path_prefix +
os.environ.get("PYTHONPATH", ""))
if sys.platform == 'win32':
# os.execv is broken on Windows and can't properly parse command line
# arguments and executable name if they contain whitespaces. subprocess
# fixes that behavior.
if not _has_execv:
subprocess.Popen([sys.executable] + sys.argv)
sys.exit(0)
else:
@ -234,7 +244,10 @@ def _reload():
# this error specifically.
os.spawnv(os.P_NOWAIT, sys.executable,
[sys.executable] + sys.argv)
sys.exit(0)
# At this point the IOLoop has been closed and finally
# blocks will experience errors if we allow the stack to
# unwind, so just exit uncleanly.
os._exit(0)
_USAGE = """\
Usage:

View file

@ -25,11 +25,13 @@ module.
from __future__ import absolute_import, division, print_function, with_statement
import functools
import platform
import traceback
import sys
from tornado.log import app_log
from tornado.stack_context import ExceptionStackContext, wrap
from tornado.util import raise_exc_info, ArgReplacer
from tornado.log import app_log
try:
from concurrent import futures
@ -37,9 +39,90 @@ except ImportError:
futures = None
# Can the garbage collector handle cycles that include __del__ methods?
# This is true in cpython beginning with version 3.4 (PEP 442).
_GC_CYCLE_FINALIZERS = (platform.python_implementation() == 'CPython' and
sys.version_info >= (3, 4))
class ReturnValueIgnoredError(Exception):
pass
# This class and associated code in the future object is derived
# from the Trollius project, a backport of asyncio to Python 2.x - 3.x
class _TracebackLogger(object):
"""Helper to log a traceback upon destruction if not cleared.
This solves a nasty problem with Futures and Tasks that have an
exception set: if nobody asks for the exception, the exception is
never logged. This violates the Zen of Python: 'Errors should
never pass silently. Unless explicitly silenced.'
However, we don't want to log the exception as soon as
set_exception() is called: if the calling code is written
properly, it will get the exception and handle it properly. But
we *do* want to log it if result() or exception() was never called
-- otherwise developers waste a lot of time wondering why their
buggy code fails silently.
An earlier attempt added a __del__() method to the Future class
itself, but this backfired because the presence of __del__()
prevents garbage collection from breaking cycles. A way out of
this catch-22 is to avoid having a __del__() method on the Future
class itself, but instead to have a reference to a helper object
with a __del__() method that logs the traceback, where we ensure
that the helper object doesn't participate in cycles, and only the
Future has a reference to it.
The helper object is added when set_exception() is called. When
the Future is collected, and the helper is present, the helper
object is also collected, and its __del__() method will log the
traceback. When the Future's result() or exception() method is
called (and a helper object is present), it removes the the helper
object, after calling its clear() method to prevent it from
logging.
One downside is that we do a fair amount of work to extract the
traceback from the exception, even when it is never logged. It
would seem cheaper to just store the exception object, but that
references the traceback, which references stack frames, which may
reference the Future, which references the _TracebackLogger, and
then the _TracebackLogger would be included in a cycle, which is
what we're trying to avoid! As an optimization, we don't
immediately format the exception; we only do the work when
activate() is called, which call is delayed until after all the
Future's callbacks have run. Since usually a Future has at least
one callback (typically set by 'yield From') and usually that
callback extracts the callback, thereby removing the need to
format the exception.
PS. I don't claim credit for this solution. I first heard of it
in a discussion about closing files when they are collected.
"""
__slots__ = ('exc_info', 'formatted_tb')
def __init__(self, exc_info):
self.exc_info = exc_info
self.formatted_tb = None
def activate(self):
exc_info = self.exc_info
if exc_info is not None:
self.exc_info = None
self.formatted_tb = traceback.format_exception(*exc_info)
def clear(self):
self.exc_info = None
self.formatted_tb = None
def __del__(self):
if self.formatted_tb:
app_log.error('Future exception was never retrieved: %s',
''.join(self.formatted_tb).rstrip())
class Future(object):
"""Placeholder for an asynchronous result.
@ -68,12 +151,23 @@ class Future(object):
if that package was available and fall back to the thread-unsafe
implementation if it was not.
.. versionchanged:: 4.1
If a `.Future` contains an error but that error is never observed
(by calling ``result()``, ``exception()``, or ``exc_info()``),
a stack trace will be logged when the `.Future` is garbage collected.
This normally indicates an error in the application, but in cases
where it results in undesired logging it may be necessary to
suppress the logging by ensuring that the exception is observed:
``f.add_done_callback(lambda f: f.exception())``.
"""
def __init__(self):
self._done = False
self._result = None
self._exception = None
self._exc_info = None
self._log_traceback = False # Used for Python >= 3.4
self._tb_logger = None # Used for Python <= 3.3
self._callbacks = []
def cancel(self):
@ -100,16 +194,21 @@ class Future(object):
"""Returns True if the future has finished running."""
return self._done
def _clear_tb_log(self):
self._log_traceback = False
if self._tb_logger is not None:
self._tb_logger.clear()
self._tb_logger = None
def result(self, timeout=None):
"""If the operation succeeded, return its result. If it failed,
re-raise its exception.
"""
self._clear_tb_log()
if self._result is not None:
return self._result
if self._exc_info is not None:
raise_exc_info(self._exc_info)
elif self._exception is not None:
raise self._exception
self._check_done()
return self._result
@ -117,8 +216,9 @@ class Future(object):
"""If the operation raised an exception, return the `Exception`
object. Otherwise returns None.
"""
if self._exception is not None:
return self._exception
self._clear_tb_log()
if self._exc_info is not None:
return self._exc_info[1]
else:
self._check_done()
return None
@ -147,14 +247,17 @@ class Future(object):
def set_exception(self, exception):
"""Sets the exception of a ``Future.``"""
self._exception = exception
self._set_done()
self.set_exc_info(
(exception.__class__,
exception,
getattr(exception, '__traceback__', None)))
def exc_info(self):
"""Returns a tuple in the same format as `sys.exc_info` or None.
.. versionadded:: 4.0
"""
self._clear_tb_log()
return self._exc_info
def set_exc_info(self, exc_info):
@ -165,7 +268,18 @@ class Future(object):
.. versionadded:: 4.0
"""
self._exc_info = exc_info
self.set_exception(exc_info[1])
self._log_traceback = True
if not _GC_CYCLE_FINALIZERS:
self._tb_logger = _TracebackLogger(exc_info)
try:
self._set_done()
finally:
# Activate the logger after all callbacks have had a
# chance to call result() or exception().
if self._log_traceback and self._tb_logger is not None:
self._tb_logger.activate()
self._exc_info = exc_info
def _check_done(self):
if not self._done:
@ -177,10 +291,25 @@ class Future(object):
try:
cb(self)
except Exception:
app_log.exception('exception calling callback %r for %r',
app_log.exception('Exception in callback %r for %r',
cb, self)
self._callbacks = None
# On Python 3.3 or older, objects with a destructor part of a reference
# cycle are never destroyed. It's no longer the case on Python 3.4 thanks to
# the PEP 442.
if _GC_CYCLE_FINALIZERS:
def __del__(self):
if not self._log_traceback:
# set_exception() was not called, or result() or exception()
# has consumed the exception
return
tb = traceback.format_exception(*self._exc_info)
app_log.error('Future %r exception was never retrieved: %s',
self, ''.join(tb).rstrip())
TracebackFuture = Future
if futures is None:
@ -208,24 +337,42 @@ class DummyExecutor(object):
dummy_executor = DummyExecutor()
def run_on_executor(fn):
def run_on_executor(*args, **kwargs):
"""Decorator to run a synchronous method asynchronously on an executor.
The decorated method may be called with a ``callback`` keyword
argument and returns a future.
This decorator should be used only on methods of objects with attributes
``executor`` and ``io_loop``.
The `.IOLoop` and executor to be used are determined by the ``io_loop``
and ``executor`` attributes of ``self``. To use different attributes,
pass keyword arguments to the decorator::
@run_on_executor(executor='_thread_pool')
def foo(self):
pass
.. versionchanged:: 4.2
Added keyword arguments to use alternative attributes.
"""
def run_on_executor_decorator(fn):
executor = kwargs.get("executor", "executor")
io_loop = kwargs.get("io_loop", "io_loop")
@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
callback = kwargs.pop("callback", None)
future = self.executor.submit(fn, self, *args, **kwargs)
future = getattr(self, executor).submit(fn, self, *args, **kwargs)
if callback:
self.io_loop.add_future(future,
lambda future: callback(future.result()))
getattr(self, io_loop).add_future(
future, lambda future: callback(future.result()))
return future
return wrapper
if args and kwargs:
raise ValueError("cannot combine positional and keyword args")
if len(args) == 1:
return run_on_executor_decorator(args[0])
elif len(args) != 0:
raise ValueError("expected 1 argument, got %d", len(args))
return run_on_executor_decorator
_NO_RESULT = object()
@ -250,7 +397,9 @@ def return_future(f):
wait for the function to complete (perhaps by yielding it in a
`.gen.engine` function, or passing it to `.IOLoop.add_future`).
Usage::
Usage:
.. testcode::
@return_future
def future_func(arg1, arg2, callback):
@ -262,6 +411,8 @@ def return_future(f):
yield future_func(arg1, arg2)
callback()
..
Note that ``@return_future`` and ``@gen.engine`` can be applied to the
same function, provided ``@return_future`` appears first. However,
consider using ``@gen.coroutine`` instead of this combination.
@ -293,7 +444,7 @@ def return_future(f):
# If the initial synchronous part of f() raised an exception,
# go ahead and raise it to the caller directly without waiting
# for them to inspect the Future.
raise_exc_info(exc_info)
future.result()
# If the caller passed in a callback, schedule it to be called
# when the future resolves. It is important that this happens

View file

@ -28,12 +28,13 @@ from io import BytesIO
from tornado import httputil
from tornado import ioloop
from tornado.log import gen_log
from tornado import stack_context
from tornado.escape import utf8, native_str
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main
curl_log = logging.getLogger('tornado.curl_httpclient')
class CurlAsyncHTTPClient(AsyncHTTPClient):
def initialize(self, io_loop, max_clients=10, defaults=None):
@ -207,8 +208,24 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
"callback": callback,
"curl_start_time": time.time(),
}
self._curl_setup_request(curl, request, curl.info["buffer"],
try:
self._curl_setup_request(
curl, request, curl.info["buffer"],
curl.info["headers"])
except Exception as e:
# If there was an error in setup, pass it on
# to the callback. Note that allowing the
# error to escape here will appear to work
# most of the time since we are still in the
# caller's original stack frame, but when
# _process_queue() is called from
# _finish_pending_requests the exceptions have
# nowhere to go.
callback(HTTPResponse(
request=request,
code=599,
error=e))
else:
self._multi.add_handle(curl)
if not started:
@ -257,7 +274,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
def _curl_create(self):
curl = pycurl.Curl()
if gen_log.isEnabledFor(logging.DEBUG):
if curl_log.isEnabledFor(logging.DEBUG):
curl.setopt(pycurl.VERBOSE, 1)
curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug)
return curl
@ -288,8 +305,8 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
functools.partial(self._curl_header_callback,
headers, request.header_callback))
if request.streaming_callback:
write_function = lambda chunk: self.io_loop.add_callback(
request.streaming_callback, chunk)
def write_function(chunk):
self.io_loop.add_callback(request.streaming_callback, chunk)
else:
write_function = buffer.write
if bytes is str: # py2
@ -381,6 +398,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
% request.method)
request_buffer = BytesIO(utf8(request.body))
def ioctl(cmd):
if cmd == curl.IOCMD_RESTARTREAD:
request_buffer.seek(0)
@ -403,11 +421,11 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
raise ValueError("Unsupported auth_mode %s" % request.auth_mode)
curl.setopt(pycurl.USERPWD, native_str(userpwd))
gen_log.debug("%s %s (username: %r)", request.method, request.url,
curl_log.debug("%s %s (username: %r)", request.method, request.url,
request.auth_username)
else:
curl.unsetopt(pycurl.USERPWD)
gen_log.debug("%s %s", request.method, request.url)
curl_log.debug("%s %s", request.method, request.url)
if request.client_cert is not None:
curl.setopt(pycurl.SSLCERT, request.client_cert)
@ -415,6 +433,9 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
if request.client_key is not None:
curl.setopt(pycurl.SSLKEY, request.client_key)
if request.ssl_options is not None:
raise ValueError("ssl_options not supported in curl_httpclient")
if threading.activeCount() > 1:
# libcurl/pycurl is not thread-safe by default. When multiple threads
# are used, signals should be disabled. This has the side effect
@ -448,12 +469,12 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
def _curl_debug(self, debug_type, debug_msg):
debug_types = ('I', '<', '>', '<', '>')
if debug_type == 0:
gen_log.debug('%s', debug_msg.strip())
curl_log.debug('%s', debug_msg.strip())
elif debug_type in (1, 2):
for line in debug_msg.splitlines():
gen_log.debug('%s %s', debug_types[debug_type], line)
curl_log.debug('%s %s', debug_types[debug_type], line)
elif debug_type == 4:
gen_log.debug('%s %r', debug_types[debug_type], debug_msg)
curl_log.debug('%s %r', debug_types[debug_type], debug_msg)
class CurlError(HTTPError):

View file

@ -82,7 +82,7 @@ def json_encode(value):
# JSON permits but does not require forward slashes to be escaped.
# This is useful when json data is emitted in a <script> tag
# in HTML, as it prevents </script> tags from prematurely terminating
# the javscript. Some json libraries do this escaping by default,
# the javascript. Some json libraries do this escaping by default,
# although python's standard library does not, so we do it here.
# http://stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped
return json.dumps(value).replace("</", "<\\/")

View file

@ -3,7 +3,9 @@ work in an asynchronous environment. Code using the ``gen`` module
is technically asynchronous, but it is written as a single generator
instead of a collection of separate functions.
For example, the following asynchronous handler::
For example, the following asynchronous handler:
.. testcode::
class AsyncHandler(RequestHandler):
@asynchronous
@ -16,7 +18,12 @@ For example, the following asynchronous handler::
do_something_with_response(response)
self.render("template.html")
could be written with ``gen`` as::
.. testoutput::
:hide:
could be written with ``gen`` as:
.. testcode::
class GenAsyncHandler(RequestHandler):
@gen.coroutine
@ -26,12 +33,17 @@ could be written with ``gen`` as::
do_something_with_response(response)
self.render("template.html")
.. testoutput::
:hide:
Most asynchronous functions in Tornado return a `.Future`;
yielding this object returns its `~.Future.result`.
You can also yield a list or dict of ``Futures``, which will be
started at the same time and run in parallel; a list or dict of results will
be returned when they are all finished::
be returned when they are all finished:
.. testcode::
@gen.coroutine
def get(self):
@ -43,8 +55,24 @@ be returned when they are all finished::
response3 = response_dict['response3']
response4 = response_dict['response4']
.. testoutput::
:hide:
If the `~functools.singledispatch` library is available (standard in
Python 3.4, available via the `singledispatch
<https://pypi.python.org/pypi/singledispatch>`_ package on older
versions), additional types of objects may be yielded. Tornado includes
support for ``asyncio.Future`` and Twisted's ``Deferred`` class when
``tornado.platform.asyncio`` and ``tornado.platform.twisted`` are imported.
See the `convert_yielded` function to extend this mechanism.
.. versionchanged:: 3.2
Dict support added.
.. versionchanged:: 4.1
Support added for yielding ``asyncio`` Futures and Twisted Deferreds
via ``singledispatch``.
"""
from __future__ import absolute_import, division, print_function, with_statement
@ -53,10 +81,21 @@ import functools
import itertools
import sys
import types
import weakref
from tornado.concurrent import Future, TracebackFuture, is_future, chain_future
from tornado.ioloop import IOLoop
from tornado.log import app_log
from tornado import stack_context
from tornado.util import raise_exc_info
try:
from functools import singledispatch # py34+
except ImportError as e:
try:
from singledispatch import singledispatch # backport
except ImportError:
singledispatch = None
class KeyReuseError(Exception):
@ -101,9 +140,11 @@ def engine(func):
which use ``self.finish()`` in place of a callback argument.
"""
func = _make_coroutine_wrapper(func, replace_callback=False)
@functools.wraps(func)
def wrapper(*args, **kwargs):
future = func(*args, **kwargs)
def final_callback(future):
if future.result() is not None:
raise ReturnValueIgnoredError(
@ -241,6 +282,113 @@ class Return(Exception):
self.value = value
class WaitIterator(object):
"""Provides an iterator to yield the results of futures as they finish.
Yielding a set of futures like this:
``results = yield [future1, future2]``
pauses the coroutine until both ``future1`` and ``future2``
return, and then restarts the coroutine with the results of both
futures. If either future is an exception, the expression will
raise that exception and all the results will be lost.
If you need to get the result of each future as soon as possible,
or if you need the result of some futures even if others produce
errors, you can use ``WaitIterator``::
wait_iterator = gen.WaitIterator(future1, future2)
while not wait_iterator.done():
try:
result = yield wait_iterator.next()
except Exception as e:
print("Error {} from {}".format(e, wait_iterator.current_future))
else:
print("Result {} received from {} at {}".format(
result, wait_iterator.current_future,
wait_iterator.current_index))
Because results are returned as soon as they are available the
output from the iterator *will not be in the same order as the
input arguments*. If you need to know which future produced the
current result, you can use the attributes
``WaitIterator.current_future``, or ``WaitIterator.current_index``
to get the index of the future from the input list. (if keyword
arguments were used in the construction of the `WaitIterator`,
``current_index`` will use the corresponding keyword).
.. versionadded:: 4.1
"""
def __init__(self, *args, **kwargs):
if args and kwargs:
raise ValueError(
"You must provide args or kwargs, not both")
if kwargs:
self._unfinished = dict((f, k) for (k, f) in kwargs.items())
futures = list(kwargs.values())
else:
self._unfinished = dict((f, i) for (i, f) in enumerate(args))
futures = args
self._finished = collections.deque()
self.current_index = self.current_future = None
self._running_future = None
# Use a weak reference to self to avoid cycles that may delay
# garbage collection.
self_ref = weakref.ref(self)
for future in futures:
future.add_done_callback(functools.partial(
self._done_callback, self_ref))
def done(self):
"""Returns True if this iterator has no more results."""
if self._finished or self._unfinished:
return False
# Clear the 'current' values when iteration is done.
self.current_index = self.current_future = None
return True
def next(self):
"""Returns a `.Future` that will yield the next available result.
Note that this `.Future` will not be the same object as any of
the inputs.
"""
self._running_future = TracebackFuture()
# As long as there is an active _running_future, we must
# ensure that the WaitIterator is not GC'd (due to the
# use of weak references in __init__). Add a callback that
# references self so there is a hard reference that will be
# cleared automatically when this Future finishes.
self._running_future.add_done_callback(lambda f: self)
if self._finished:
self._return_result(self._finished.popleft())
return self._running_future
@staticmethod
def _done_callback(self_ref, done):
self = self_ref()
if self is not None:
if self._running_future and not self._running_future.done():
self._return_result(done)
else:
self._finished.append(done)
def _return_result(self, done):
"""Called set the returned future's state that of the future
we yielded, and set the current future for the iterator.
"""
chain_future(done, self._running_future)
self.current_future = done
self.current_index = self._unfinished.pop(done)
class YieldPoint(object):
"""Base class for objects that may be yielded from the generator.
@ -355,11 +503,13 @@ def Task(func, *args, **kwargs):
yielded.
"""
future = Future()
def handle_exception(typ, value, tb):
if future.done():
return False
future.set_exc_info((typ, value, tb))
return True
def set_result(result):
if future.done():
return
@ -371,6 +521,11 @@ def Task(func, *args, **kwargs):
class YieldFuture(YieldPoint):
def __init__(self, future, io_loop=None):
"""Adapts a `.Future` to the `YieldPoint` interface.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
self.future = future
self.io_loop = io_loop or IOLoop.current()
@ -382,7 +537,7 @@ class YieldFuture(YieldPoint):
self.io_loop.add_future(self.future, runner.result_callback(self.key))
else:
self.runner = None
self.result = self.future.result()
self.result_fn = self.future.result
def is_ready(self):
if self.runner is not None:
@ -394,7 +549,7 @@ class YieldFuture(YieldPoint):
if self.runner is not None:
return self.runner.pop_result(self.key).result()
else:
return self.result
return self.result_fn()
class Multi(YieldPoint):
@ -408,8 +563,18 @@ class Multi(YieldPoint):
Instead of a list, the argument may also be a dictionary whose values are
Futures, in which case a parallel dictionary is returned mapping the same
keys to their results.
It is not normally necessary to call this class directly, as it
will be created automatically as needed. However, calling it directly
allows you to use the ``quiet_exceptions`` argument to control
the logging of multiple exceptions.
.. versionchanged:: 4.2
If multiple ``YieldPoints`` fail, any exceptions after the first
(which is raised) will be logged. Added the ``quiet_exceptions``
argument to suppress this logging for selected exception types.
"""
def __init__(self, children):
def __init__(self, children, quiet_exceptions=()):
self.keys = None
if isinstance(children, dict):
self.keys = list(children.keys())
@ -421,6 +586,7 @@ class Multi(YieldPoint):
self.children.append(i)
assert all(isinstance(i, YieldPoint) for i in self.children)
self.unfinished_children = set(self.children)
self.quiet_exceptions = quiet_exceptions
def start(self, runner):
for i in self.children:
@ -433,14 +599,27 @@ class Multi(YieldPoint):
return not self.unfinished_children
def get_result(self):
result = (i.get_result() for i in self.children)
if self.keys is not None:
return dict(zip(self.keys, result))
result_list = []
exc_info = None
for f in self.children:
try:
result_list.append(f.get_result())
except Exception as e:
if exc_info is None:
exc_info = sys.exc_info()
else:
return list(result)
if not isinstance(e, self.quiet_exceptions):
app_log.error("Multiple exceptions in yield list",
exc_info=True)
if exc_info is not None:
raise_exc_info(exc_info)
if self.keys is not None:
return dict(zip(self.keys, result_list))
else:
return list(result_list)
def multi_future(children):
def multi_future(children, quiet_exceptions=()):
"""Wait for multiple asynchronous futures in parallel.
Takes a list of ``Futures`` (but *not* other ``YieldPoints``) and returns
@ -453,12 +632,21 @@ def multi_future(children):
Futures, in which case a parallel dictionary is returned mapping the same
keys to their results.
It is not necessary to call `multi_future` explcitly, since the engine will
do so automatically when the generator yields a list of `Futures`.
This function is faster than the `Multi` `YieldPoint` because it does not
require the creation of a stack context.
It is not normally necessary to call `multi_future` explcitly,
since the engine will do so automatically when the generator
yields a list of ``Futures``. However, calling it directly
allows you to use the ``quiet_exceptions`` argument to control
the logging of multiple exceptions.
This function is faster than the `Multi` `YieldPoint` because it
does not require the creation of a stack context.
.. versionadded:: 4.0
.. versionchanged:: 4.2
If multiple ``Futures`` fail, any exceptions after the first (which is
raised) will be logged. Added the ``quiet_exceptions``
argument to suppress this logging for selected exception types.
"""
if isinstance(children, dict):
keys = list(children.keys())
@ -471,19 +659,31 @@ def multi_future(children):
future = Future()
if not children:
future.set_result({} if keys is not None else [])
def callback(f):
unfinished_children.remove(f)
if not unfinished_children:
result_list = []
for f in children:
try:
result_list = [i.result() for i in children]
except Exception:
future.set_exc_info(sys.exc_info())
result_list.append(f.result())
except Exception as e:
if future.done():
if not isinstance(e, quiet_exceptions):
app_log.error("Multiple exceptions in yield list",
exc_info=True)
else:
future.set_exc_info(sys.exc_info())
if not future.done():
if keys is not None:
future.set_result(dict(zip(keys, result_list)))
else:
future.set_result(result_list)
listening = set()
for f in children:
if f not in listening:
listening.add(f)
f.add_done_callback(callback)
return future
@ -504,7 +704,7 @@ def maybe_future(x):
return fut
def with_timeout(timeout, future, io_loop=None):
def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()):
"""Wraps a `.Future` in a timeout.
Raises `TimeoutError` if the input future does not complete before
@ -512,9 +712,17 @@ def with_timeout(timeout, future, io_loop=None):
`.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time
relative to `.IOLoop.time`)
If the wrapped `.Future` fails after it has timed out, the exception
will be logged unless it is of a type contained in ``quiet_exceptions``
(which may be an exception type or a sequence of types).
Currently only supports Futures, not other `YieldPoint` classes.
.. versionadded:: 4.0
.. versionchanged:: 4.1
Added the ``quiet_exceptions`` argument and the logging of unhandled
exceptions.
"""
# TODO: allow yield points in addition to futures?
# Tricky to do with stack_context semantics.
@ -528,9 +736,21 @@ def with_timeout(timeout, future, io_loop=None):
chain_future(future, result)
if io_loop is None:
io_loop = IOLoop.current()
def error_callback(future):
try:
future.result()
except Exception as e:
if not isinstance(e, quiet_exceptions):
app_log.error("Exception in Future %r after timeout",
future, exc_info=True)
def timeout_callback():
result.set_exception(TimeoutError("Timeout"))
# In case the wrapped future goes on to fail, log it.
future.add_done_callback(error_callback)
timeout_handle = io_loop.add_timeout(
timeout,
lambda: result.set_exception(TimeoutError("Timeout")))
timeout, timeout_callback)
if isinstance(future, Future):
# We know this future will resolve on the IOLoop, so we don't
# need the extra thread-safety of IOLoop.add_future (and we also
@ -545,6 +765,25 @@ def with_timeout(timeout, future, io_loop=None):
return result
def sleep(duration):
"""Return a `.Future` that resolves after the given number of seconds.
When used with ``yield`` in a coroutine, this is a non-blocking
analogue to `time.sleep` (which should not be used in coroutines
because it is blocking)::
yield gen.sleep(0.5)
Note that calling this function on its own does nothing; you must
wait on the `.Future` it returns (usually by yielding it).
.. versionadded:: 4.1
"""
f = Future()
IOLoop.current().call_later(duration, lambda: f.set_result(None))
return f
_null_future = Future()
_null_future.set_result(None)
@ -638,13 +877,20 @@ class Runner(object):
self.future = None
try:
orig_stack_contexts = stack_context._state.contexts
exc_info = None
try:
value = future.result()
except Exception:
self.had_exception = True
yielded = self.gen.throw(*sys.exc_info())
exc_info = sys.exc_info()
if exc_info is not None:
yielded = self.gen.throw(*exc_info)
exc_info = None
else:
yielded = self.gen.send(value)
if stack_context._state.contexts is not orig_stack_contexts:
self.gen.throw(
stack_context.StackContextInconsistentError(
@ -678,19 +924,20 @@ class Runner(object):
self.running = False
def handle_yield(self, yielded):
if isinstance(yielded, list):
if all(is_future(f) for f in yielded):
yielded = multi_future(yielded)
else:
# Lists containing YieldPoints require stack contexts;
# other lists are handled via multi_future in convert_yielded.
if (isinstance(yielded, list) and
any(isinstance(f, YieldPoint) for f in yielded)):
yielded = Multi(yielded)
elif isinstance(yielded, dict):
if all(is_future(f) for f in yielded.values()):
yielded = multi_future(yielded)
else:
elif (isinstance(yielded, dict) and
any(isinstance(f, YieldPoint) for f in yielded.values())):
yielded = Multi(yielded)
if isinstance(yielded, YieldPoint):
# YieldPoints are too closely coupled to the Runner to go
# through the generic convert_yielded mechanism.
self.future = TracebackFuture()
def start_yield_point():
try:
yielded.start(self)
@ -702,12 +949,14 @@ class Runner(object):
except Exception:
self.future = TracebackFuture()
self.future.set_exc_info(sys.exc_info())
if self.stack_context_deactivate is None:
# Start a stack context if this is the first
# YieldPoint we've seen.
with stack_context.ExceptionStackContext(
self.handle_exception) as deactivate:
self.stack_context_deactivate = deactivate
def cb():
start_yield_point()
self.run()
@ -715,16 +964,17 @@ class Runner(object):
return False
else:
start_yield_point()
elif is_future(yielded):
self.future = yielded
else:
try:
self.future = convert_yielded(yielded)
except BadYieldError:
self.future = TracebackFuture()
self.future.set_exc_info(sys.exc_info())
if not self.future.done() or self.future is moment:
self.io_loop.add_future(
self.future, lambda f: self.run())
return False
else:
self.future = TracebackFuture()
self.future.set_exception(BadYieldError(
"yielded unknown object %r" % (yielded,)))
return True
def result_callback(self, key):
@ -763,3 +1013,30 @@ def _argument_adapter(callback):
else:
callback(None)
return wrapper
def convert_yielded(yielded):
"""Convert a yielded object into a `.Future`.
The default implementation accepts lists, dictionaries, and Futures.
If the `~functools.singledispatch` library is available, this function
may be extended to support additional types. For example::
@convert_yielded.register(asyncio.Future)
def _(asyncio_future):
return tornado.platform.asyncio.to_tornado_future(asyncio_future)
.. versionadded:: 4.1
"""
# Lists and dicts containing YieldPoints were handled separately
# via Multi().
if isinstance(yielded, (list, dict)):
return multi_future(yielded)
elif is_future(yielded):
return yielded
else:
raise BadYieldError("yielded unknown object %r" % (yielded,))
if singledispatch is not None:
convert_yielded = singledispatch(convert_yielded)

View file

@ -37,6 +37,7 @@ class _QuietException(Exception):
def __init__(self):
pass
class _ExceptionLoggingContext(object):
"""Used with the ``with`` statement when calling delegate methods to
log any exceptions with the given logger. Any exceptions caught are
@ -53,6 +54,7 @@ class _ExceptionLoggingContext(object):
self.logger.error("Uncaught exception", exc_info=(typ, value, tb))
raise _QuietException
class HTTP1ConnectionParameters(object):
"""Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`.
"""
@ -162,7 +164,8 @@ class HTTP1Connection(httputil.HTTPConnection):
header_data = yield gen.with_timeout(
self.stream.io_loop.time() + self.params.header_timeout,
header_future,
io_loop=self.stream.io_loop)
io_loop=self.stream.io_loop,
quiet_exceptions=iostream.StreamClosedError)
except gen.TimeoutError:
self.close()
raise gen.Return(False)
@ -221,7 +224,8 @@ class HTTP1Connection(httputil.HTTPConnection):
try:
yield gen.with_timeout(
self.stream.io_loop.time() + self._body_timeout,
body_future, self.stream.io_loop)
body_future, self.stream.io_loop,
quiet_exceptions=iostream.StreamClosedError)
except gen.TimeoutError:
gen_log.info("Timeout reading body from %s",
self.context)
@ -326,8 +330,10 @@ class HTTP1Connection(httputil.HTTPConnection):
def write_headers(self, start_line, headers, chunk=None, callback=None):
"""Implements `.HTTPConnection.write_headers`."""
lines = []
if self.is_client:
self._request_start_line = start_line
lines.append(utf8('%s %s HTTP/1.1' % (start_line[0], start_line[1])))
# Client requests with a non-empty body must have either a
# Content-Length or a Transfer-Encoding.
self._chunking_output = (
@ -336,6 +342,7 @@ class HTTP1Connection(httputil.HTTPConnection):
'Transfer-Encoding' not in headers)
else:
self._response_start_line = start_line
lines.append(utf8('HTTP/1.1 %s %s' % (start_line[1], start_line[2])))
self._chunking_output = (
# TODO: should this use
# self._request_start_line.version or
@ -365,7 +372,6 @@ class HTTP1Connection(httputil.HTTPConnection):
self._expected_content_remaining = int(headers['Content-Length'])
else:
self._expected_content_remaining = None
lines = [utf8("%s %s %s" % start_line)]
lines.extend([utf8(n) + b": " + utf8(v) for n, v in headers.get_all()])
for line in lines:
if b'\n' in line:
@ -374,6 +380,7 @@ class HTTP1Connection(httputil.HTTPConnection):
if self.stream.closed():
future = self._write_future = Future()
future.set_exception(iostream.StreamClosedError())
future.exception()
else:
if callback is not None:
self._write_callback = stack_context.wrap(callback)
@ -412,6 +419,7 @@ class HTTP1Connection(httputil.HTTPConnection):
if self.stream.closed():
future = self._write_future = Future()
self._write_future.set_exception(iostream.StreamClosedError())
self._write_future.exception()
else:
if callback is not None:
self._write_callback = stack_context.wrap(callback)
@ -451,6 +459,9 @@ class HTTP1Connection(httputil.HTTPConnection):
self._pending_write.add_done_callback(self._finish_request)
def _on_write_complete(self, future):
exc = future.exception()
if exc is not None and not isinstance(exc, iostream.StreamClosedError):
future.result()
if self._write_callback is not None:
callback = self._write_callback
self._write_callback = None
@ -491,8 +502,9 @@ class HTTP1Connection(httputil.HTTPConnection):
# we SHOULD ignore at least one empty line before the request.
# http://tools.ietf.org/html/rfc7230#section-3.5
data = native_str(data.decode('latin1')).lstrip("\r\n")
eol = data.find("\r\n")
start_line = data[:eol]
# RFC 7230 section allows for both CRLF and bare LF.
eol = data.find("\n")
start_line = data[:eol].rstrip("\r")
try:
headers = httputil.HTTPHeaders.parse(data[eol:])
except ValueError:
@ -686,8 +698,7 @@ class HTTP1ServerConnection(object):
# This exception was already logged.
conn.close()
return
except Exception as e:
if 1 != e.errno:
except Exception:
gen_log.error("Uncaught exception", exc_info=True)
conn.close()
return

View file

@ -72,7 +72,7 @@ class HTTPClient(object):
http_client.close()
"""
def __init__(self, async_client_class=None, **kwargs):
self._io_loop = IOLoop()
self._io_loop = IOLoop(make_current=False)
if async_client_class is None:
async_client_class = AsyncHTTPClient
self._async_client = async_client_class(self._io_loop, **kwargs)
@ -95,11 +95,11 @@ class HTTPClient(object):
If it is a string, we construct an `HTTPRequest` using any additional
kwargs: ``HTTPRequest(request, **kwargs)``
If an error occurs during the fetch, we raise an `HTTPError`.
If an error occurs during the fetch, we raise an `HTTPError` unless
the ``raise_error`` keyword argument is set to False.
"""
response = self._io_loop.run_sync(functools.partial(
self._async_client.fetch, request, **kwargs))
response.rethrow()
return response
@ -136,6 +136,9 @@ class AsyncHTTPClient(Configurable):
# or with force_instance:
client = AsyncHTTPClient(force_instance=True,
defaults=dict(user_agent="MyUserAgent"))
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
@classmethod
def configurable_base(cls):
@ -200,7 +203,7 @@ class AsyncHTTPClient(Configurable):
raise RuntimeError("inconsistent AsyncHTTPClient cache")
del self._instance_cache[self.io_loop]
def fetch(self, request, callback=None, **kwargs):
def fetch(self, request, callback=None, raise_error=True, **kwargs):
"""Executes a request, asynchronously returning an `HTTPResponse`.
The request may be either a string URL or an `HTTPRequest` object.
@ -208,8 +211,10 @@ class AsyncHTTPClient(Configurable):
kwargs: ``HTTPRequest(request, **kwargs)``
This method returns a `.Future` whose result is an
`HTTPResponse`. The ``Future`` will raise an `HTTPError` if
the request returned a non-200 response code.
`HTTPResponse`. By default, the ``Future`` will raise an `HTTPError`
if the request returned a non-200 response code. Instead, if
``raise_error`` is set to False, the response will always be
returned regardless of the response code.
If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
In the callback interface, `HTTPError` is not automatically raised.
@ -243,7 +248,7 @@ class AsyncHTTPClient(Configurable):
future.add_done_callback(handle_future)
def handle_response(response):
if response.error:
if raise_error and response.error:
future.set_exception(response.error)
else:
future.set_result(response)
@ -304,7 +309,8 @@ class HTTPRequest(object):
validate_cert=None, ca_certs=None,
allow_ipv6=None,
client_key=None, client_cert=None, body_producer=None,
expect_100_continue=False, decompress_response=None):
expect_100_continue=False, decompress_response=None,
ssl_options=None):
r"""All parameters except ``url`` are optional.
:arg string url: URL to fetch
@ -374,12 +380,15 @@ class HTTPRequest(object):
:arg string ca_certs: filename of CA certificates in PEM format,
or None to use defaults. See note below when used with
``curl_httpclient``.
:arg bool allow_ipv6: Use IPv6 when available? Default is false in
``simple_httpclient`` and true in ``curl_httpclient``
:arg string client_key: Filename for client SSL key, if any. See
note below when used with ``curl_httpclient``.
:arg string client_cert: Filename for client SSL certificate, if any.
See note below when used with ``curl_httpclient``.
:arg ssl.SSLContext ssl_options: `ssl.SSLContext` object for use in
``simple_httpclient`` (unsupported by ``curl_httpclient``).
Overrides ``validate_cert``, ``ca_certs``, ``client_key``,
and ``client_cert``.
:arg bool allow_ipv6: Use IPv6 when available? Default is true.
:arg bool expect_100_continue: If true, send the
``Expect: 100-continue`` header and wait for a continue response
before sending the request body. Only supported with
@ -402,6 +411,9 @@ class HTTPRequest(object):
.. versionadded:: 4.0
The ``body_producer`` and ``expect_100_continue`` arguments.
.. versionadded:: 4.2
The ``ssl_options`` argument.
"""
# Note that some of these attributes go through property setters
# defined below.
@ -439,6 +451,7 @@ class HTTPRequest(object):
self.allow_ipv6 = allow_ipv6
self.client_key = client_key
self.client_cert = client_cert
self.ssl_options = ssl_options
self.expect_100_continue = expect_100_continue
self.start_time = time.time()

View file

@ -37,35 +37,17 @@ from tornado import httputil
from tornado import iostream
from tornado import netutil
from tornado.tcpserver import TCPServer
from tornado.util import Configurable
class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
class HTTPServer(TCPServer, Configurable,
httputil.HTTPServerConnectionDelegate):
r"""A non-blocking, single-threaded HTTP server.
A server is defined by either a request callback that takes a
`.HTTPServerRequest` as an argument or a `.HTTPServerConnectionDelegate`
instance.
A simple example server that echoes back the URI you requested::
import tornado.httpserver
import tornado.ioloop
from tornado import httputil
def handle_request(request):
message = "You requested %s\n" % request.uri
request.connection.write_headers(
httputil.ResponseStartLine('HTTP/1.1', 200, 'OK'),
{"Content-Length": str(len(message))})
request.connection.write(message)
request.connection.finish()
http_server = tornado.httpserver.HTTPServer(handle_request)
http_server.listen(8888)
tornado.ioloop.IOLoop.instance().start()
Applications should use the methods of `.HTTPConnection` to write
their response.
A server is defined by a subclass of `.HTTPServerConnectionDelegate`,
or, for backwards compatibility, a callback that takes an
`.HTTPServerRequest` as an argument. The delegate is usually a
`tornado.web.Application`.
`HTTPServer` supports keep-alive connections by default
(automatically for HTTP/1.1, or for HTTP/1.0 when the client
@ -80,15 +62,15 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
if Tornado is run behind an SSL-decoding proxy that does not set one of
the supported ``xheaders``.
To make this server serve SSL traffic, send the ``ssl_options`` dictionary
argument with the arguments required for the `ssl.wrap_socket` method,
including ``certfile`` and ``keyfile``. (In Python 3.2+ you can pass
an `ssl.SSLContext` object instead of a dict)::
To make this server serve SSL traffic, send the ``ssl_options`` keyword
argument with an `ssl.SSLContext` object. For compatibility with older
versions of Python ``ssl_options`` may also be a dictionary of keyword
arguments for the `ssl.wrap_socket` method.::
HTTPServer(applicaton, ssl_options={
"certfile": os.path.join(data_dir, "mydomain.crt"),
"keyfile": os.path.join(data_dir, "mydomain.key"),
})
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"),
os.path.join(data_dir, "mydomain.key"))
HTTPServer(applicaton, ssl_options=ssl_ctx)
`HTTPServer` initialization follows one of three patterns (the
initialization methods are defined on `tornado.tcpserver.TCPServer`):
@ -97,7 +79,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
server = HTTPServer(app)
server.listen(8888)
IOLoop.instance().start()
IOLoop.current().start()
In many cases, `tornado.web.Application.listen` can be used to avoid
the need to explicitly create the `HTTPServer`.
@ -108,7 +90,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
server = HTTPServer(app)
server.bind(8888)
server.start(0) # Forks multiple sub-processes
IOLoop.instance().start()
IOLoop.current().start()
When using this interface, an `.IOLoop` must *not* be passed
to the `HTTPServer` constructor. `~.TCPServer.start` will always start
@ -120,7 +102,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
tornado.process.fork_processes(0)
server = HTTPServer(app)
server.add_sockets(sockets)
IOLoop.instance().start()
IOLoop.current().start()
The `~.TCPServer.add_sockets` interface is more complicated,
but it can be used with `tornado.process.fork_processes` to
@ -134,8 +116,24 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
``idle_connection_timeout``, ``body_timeout``, ``max_body_size``
arguments. Added support for `.HTTPServerConnectionDelegate`
instances as ``request_callback``.
.. versionchanged:: 4.1
`.HTTPServerConnectionDelegate.start_request` is now called with
two arguments ``(server_conn, request_conn)`` (in accordance with the
documentation) instead of one ``(request_conn)``.
.. versionchanged:: 4.2
`HTTPServer` is now a subclass of `tornado.util.Configurable`.
"""
def __init__(self, request_callback, no_keep_alive=False, io_loop=None,
def __init__(self, *args, **kwargs):
# Ignore args to __init__; real initialization belongs in
# initialize since we're Configurable. (there's something
# weird in initialization order between this class,
# Configurable, and TCPServer so we can't leave __init__ out
# completely)
pass
def initialize(self, request_callback, no_keep_alive=False, io_loop=None,
xheaders=False, ssl_options=None, protocol=None,
decompress_request=False,
chunk_size=None, max_header_size=None,
@ -157,6 +155,14 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
read_chunk_size=chunk_size)
self._connections = set()
@classmethod
def configurable_base(cls):
return HTTPServer
@classmethod
def configurable_default(cls):
return HTTPServer
@gen.coroutine
def close_all_connections(self):
while self._connections:
@ -173,7 +179,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
conn.start_serving(self)
def start_request(self, server_conn, request_conn):
return _ServerRequestAdapter(self, request_conn)
return _ServerRequestAdapter(self, server_conn, request_conn)
def on_close(self, server_conn):
self._connections.remove(server_conn)
@ -246,13 +252,14 @@ class _ServerRequestAdapter(httputil.HTTPMessageDelegate):
"""Adapts the `HTTPMessageDelegate` interface to the interface expected
by our clients.
"""
def __init__(self, server, connection):
def __init__(self, server, server_conn, request_conn):
self.server = server
self.connection = connection
self.connection = request_conn
self.request = None
if isinstance(server.request_callback,
httputil.HTTPServerConnectionDelegate):
self.delegate = server.request_callback.start_request(connection)
self.delegate = server.request_callback.start_request(
server_conn, request_conn)
self._chunks = None
else:
self.delegate = None

View file

@ -62,6 +62,11 @@ except ImportError:
pass
# RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line
# terminator and ignore any preceding CR.
_CRLF_RE = re.compile(r'\r?\n')
class _NormalizedHeaderCache(dict):
"""Dynamic cached mapping of header names to Http-Header-Case.
@ -193,7 +198,7 @@ class HTTPHeaders(dict):
[('Content-Length', '42'), ('Content-Type', 'text/html')]
"""
h = cls()
for line in headers.splitlines():
for line in _CRLF_RE.split(headers):
if line:
h.parse_line(line)
return h
@ -229,6 +234,14 @@ class HTTPHeaders(dict):
# default implementation returns dict(self), not the subclass
return HTTPHeaders(self)
# Use our overridden copy method for the copy.copy module.
__copy__ = copy
def __deepcopy__(self, memo_dict):
# Our values are immutable strings, so our standard copy is
# effectively a deep copy.
return self.copy()
class HTTPServerRequest(object):
"""A single HTTP request.
@ -331,7 +344,7 @@ class HTTPServerRequest(object):
self.uri = uri
self.version = version
self.headers = headers or HTTPHeaders()
self.body = body or ""
self.body = body or b""
# set remote IP and protocol
context = getattr(connection, 'context', None)
@ -380,6 +393,8 @@ class HTTPServerRequest(object):
to write the response.
"""
assert isinstance(chunk, bytes)
assert self.version.startswith("HTTP/1."), \
"deprecated interface ony supported in HTTP/1.x"
self.connection.write(chunk, callback=callback)
def finish(self):
@ -406,15 +421,14 @@ class HTTPServerRequest(object):
def get_ssl_certificate(self, binary_form=False):
"""Returns the client's SSL certificate, if any.
To use client certificates, the HTTPServer must have been constructed
with cert_reqs set in ssl_options, e.g.::
To use client certificates, the HTTPServer's
`ssl.SSLContext.verify_mode` field must be set, e.g.::
server = HTTPServer(app,
ssl_options=dict(
certfile="foo.crt",
keyfile="foo.key",
cert_reqs=ssl.CERT_REQUIRED,
ca_certs="cacert.crt"))
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_ctx.load_cert_chain("foo.crt", "foo.key")
ssl_ctx.load_verify_locations("cacerts.pem")
ssl_ctx.verify_mode = ssl.CERT_REQUIRED
server = HTTPServer(app, ssl_options=ssl_ctx)
By default, the return value is a dictionary (or None, if no
client certificate is present). If ``binary_form`` is true, a
@ -543,6 +557,8 @@ class HTTPConnection(object):
headers.
:arg callback: a callback to be run when the write is complete.
The ``version`` field of ``start_line`` is ignored.
Returns a `.Future` if no callback is given.
"""
raise NotImplementedError()
@ -689,6 +705,7 @@ def parse_body_arguments(content_type, body, arguments, files, headers=None):
if values:
arguments.setdefault(name, []).extend(values)
elif content_type.startswith("multipart/form-data"):
try:
fields = content_type.split(";")
for field in fields:
k, sep, v = field.strip().partition("=")
@ -696,7 +713,9 @@ def parse_body_arguments(content_type, body, arguments, files, headers=None):
parse_multipart_form_data(utf8(v), body, arguments, files)
break
else:
gen_log.warning("Invalid multipart/form-data")
raise ValueError("multipart boundary not found")
except Exception as e:
gen_log.warning("Invalid multipart/form-data: %s", e)
def parse_multipart_form_data(boundary, data, arguments, files):
@ -782,7 +801,7 @@ def parse_request_start_line(line):
method, path, version = line.split(" ")
except ValueError:
raise HTTPInputError("Malformed HTTP request line")
if not version.startswith("HTTP/"):
if not re.match(r"^HTTP/1\.[0-9]$", version):
raise HTTPInputError(
"Malformed HTTP version in HTTP Request-Line: %r" % version)
return RequestStartLine(method, path, version)
@ -801,7 +820,7 @@ def parse_response_start_line(line):
ResponseStartLine(version='HTTP/1.1', code=200, reason='OK')
"""
line = native_str(line)
match = re.match("(HTTP/1.[01]) ([0-9]+) ([^\r]*)", line)
match = re.match("(HTTP/1.[0-9]) ([0-9]+) ([^\r]*)", line)
if not match:
raise HTTPInputError("Error parsing response start line")
return ResponseStartLine(match.group(1), int(match.group(2)),
@ -873,3 +892,20 @@ def _encode_header(key, pdict):
def doctests():
import doctest
return doctest.DocTestSuite()
def split_host_and_port(netloc):
"""Returns ``(host, port)`` tuple from ``netloc``.
Returned ``port`` will be ``None`` if not present.
.. versionadded:: 4.1
"""
match = re.match(r'^(.+):(\d+)$', netloc)
if match:
host = match.group(1)
port = int(match.group(2))
else:
host = netloc
port = None
return (host, port)

View file

@ -41,6 +41,7 @@ import sys
import threading
import time
import traceback
import math
from tornado.concurrent import TracebackFuture, is_future
from tornado.log import app_log, gen_log
@ -76,35 +77,52 @@ class IOLoop(Configurable):
simultaneous connections, you should use a system that supports
either ``epoll`` or ``kqueue``.
Example usage for a simple TCP server::
Example usage for a simple TCP server:
.. testcode::
import errno
import functools
import ioloop
import tornado.ioloop
import socket
def connection_ready(sock, fd, events):
while True:
try:
connection, address = sock.accept()
except socket.error, e:
except socket.error as e:
if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN):
raise
return
connection.setblocking(0)
handle_connection(connection, address)
if __name__ == '__main__':
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setblocking(0)
sock.bind(("", port))
sock.listen(128)
io_loop = ioloop.IOLoop.instance()
io_loop = tornado.ioloop.IOLoop.current()
callback = functools.partial(connection_ready, sock)
io_loop.add_handler(sock.fileno(), callback, io_loop.READ)
io_loop.start()
.. testoutput::
:hide:
By default, a newly-constructed `IOLoop` becomes the thread's current
`IOLoop`, unless there already is a current `IOLoop`. This behavior
can be controlled with the ``make_current`` argument to the `IOLoop`
constructor: if ``make_current=True``, the new `IOLoop` will always
try to become current and it raises an error if there is already a
current instance. If ``make_current=False``, the new `IOLoop` will
not try to become current.
.. versionchanged:: 4.2
Added the ``make_current`` keyword argument to the `IOLoop`
constructor.
"""
# Constants from the epoll module
_EPOLLIN = 0x001
@ -133,7 +151,8 @@ class IOLoop(Configurable):
Most applications have a single, global `IOLoop` running on the
main thread. Use this method to get this instance from
another thread. To get the current thread's `IOLoop`, use `current()`.
another thread. In most other cases, it is better to use `current()`
to get the current thread's `IOLoop`.
"""
if not hasattr(IOLoop, "_instance"):
with IOLoop._instance_lock:
@ -167,28 +186,26 @@ class IOLoop(Configurable):
del IOLoop._instance
@staticmethod
def current():
def current(instance=True):
"""Returns the current thread's `IOLoop`.
If an `IOLoop` is currently running or has been marked as current
by `make_current`, returns that instance. Otherwise returns
`IOLoop.instance()`, i.e. the main thread's `IOLoop`.
A common pattern for classes that depend on ``IOLoops`` is to use
a default argument to enable programs with multiple ``IOLoops``
but not require the argument for simpler applications::
class MyClass(object):
def __init__(self, io_loop=None):
self.io_loop = io_loop or IOLoop.current()
If an `IOLoop` is currently running or has been marked as
current by `make_current`, returns that instance. If there is
no current `IOLoop`, returns `IOLoop.instance()` (i.e. the
main thread's `IOLoop`, creating one if necessary) if ``instance``
is true.
In general you should use `IOLoop.current` as the default when
constructing an asynchronous object, and use `IOLoop.instance`
when you mean to communicate to the main thread from a different
one.
.. versionchanged:: 4.1
Added ``instance`` argument to control the fallback to
`IOLoop.instance()`.
"""
current = getattr(IOLoop._current, "instance", None)
if current is None:
if current is None and instance:
return IOLoop.instance()
return current
@ -200,6 +217,10 @@ class IOLoop(Configurable):
`make_current` explicitly before starting the `IOLoop`,
so that code run at startup time can find the right
instance.
.. versionchanged:: 4.1
An `IOLoop` created while there is no current `IOLoop`
will automatically become current.
"""
IOLoop._current.instance = self
@ -223,8 +244,14 @@ class IOLoop(Configurable):
from tornado.platform.select import SelectIOLoop
return SelectIOLoop
def initialize(self):
pass
def initialize(self, make_current=None):
if make_current is None:
if IOLoop.current(instance=False) is None:
self.make_current()
elif make_current:
if IOLoop.current(instance=False) is None:
raise RuntimeError("current IOLoop already exists")
self.make_current()
def close(self, all_fds=False):
"""Closes the `IOLoop`, freeing any resources used.
@ -390,7 +417,7 @@ class IOLoop(Configurable):
# do stuff...
if __name__ == '__main__':
IOLoop.instance().run_sync(main)
IOLoop.current().run_sync(main)
"""
future_cell = [None]
@ -633,8 +660,8 @@ class PollIOLoop(IOLoop):
(Linux), `tornado.platform.kqueue.KQueueIOLoop` (BSD and Mac), or
`tornado.platform.select.SelectIOLoop` (all platforms).
"""
def initialize(self, impl, time_func=None):
super(PollIOLoop, self).initialize()
def initialize(self, impl, time_func=None, **kwargs):
super(PollIOLoop, self).initialize(**kwargs)
self._impl = impl
if hasattr(self._impl, 'fileno'):
set_close_exec(self._impl.fileno())
@ -739,8 +766,10 @@ class PollIOLoop(IOLoop):
# IOLoop is just started once at the beginning.
signal.set_wakeup_fd(old_wakeup_fd)
old_wakeup_fd = None
except ValueError: # non-main thread
pass
except ValueError:
# Non-main thread, or the previous value of wakeup_fd
# is no longer valid.
old_wakeup_fd = None
try:
while True:
@ -944,8 +973,16 @@ class PeriodicCallback(object):
"""Schedules the given callback to be called periodically.
The callback is called every ``callback_time`` milliseconds.
Note that the timeout is given in milliseconds, while most other
time-related functions in Tornado use seconds.
If the callback runs for longer than ``callback_time`` milliseconds,
subsequent invocations will be skipped to get back on schedule.
`start` must be called after the `PeriodicCallback` is created.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
def __init__(self, callback, callback_time, io_loop=None):
self.callback = callback
@ -969,6 +1006,13 @@ class PeriodicCallback(object):
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
def is_running(self):
"""Return True if this `.PeriodicCallback` has been started.
.. versionadded:: 4.1
"""
return self._running
def _run(self):
if not self._running:
return
@ -982,6 +1026,9 @@ class PeriodicCallback(object):
def _schedule_next(self):
if self._running:
current_time = self.io_loop.time()
while self._next_timeout <= current_time:
self._next_timeout += self.callback_time / 1000.0
if self._next_timeout <= current_time:
callback_time_sec = self.callback_time / 1000.0
self._next_timeout += (math.floor((current_time - self._next_timeout) / callback_time_sec) + 1) * callback_time_sec
self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run)

View file

@ -37,7 +37,7 @@ import re
from tornado.concurrent import TracebackFuture
from tornado import ioloop
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError
from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError, _client_ssl_defaults, _server_ssl_defaults
from tornado import stack_context
from tornado.util import errno_from_exception
@ -68,13 +68,21 @@ _ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE,
if hasattr(errno, "WSAECONNRESET"):
_ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT)
if sys.platform == 'darwin':
# OSX appears to have a race condition that causes send(2) to return
# EPROTOTYPE if called while a socket is being torn down:
# http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/
# Since the socket is being closed anyway, treat this as an ECONNRESET
# instead of an unexpected error.
_ERRNO_CONNRESET += (errno.EPROTOTYPE,)
# More non-portable errnos:
_ERRNO_INPROGRESS = (errno.EINPROGRESS,)
if hasattr(errno, "WSAEINPROGRESS"):
_ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,)
#######################################################
class StreamClosedError(IOError):
"""Exception raised by `IOStream` methods when the stream is closed.
@ -122,6 +130,7 @@ class BaseIOStream(object):
"""`BaseIOStream` constructor.
:arg io_loop: The `.IOLoop` to use; defaults to `.IOLoop.current`.
Deprecated since Tornado 4.1.
:arg max_buffer_size: Maximum amount of incoming data to buffer;
defaults to 100MB.
:arg read_chunk_size: Amount of data to read at one time from the
@ -160,6 +169,11 @@ class BaseIOStream(object):
self._close_callback = None
self._connect_callback = None
self._connect_future = None
# _ssl_connect_future should be defined in SSLIOStream
# but it's here so we can clean it up in maybe_run_close_callback.
# TODO: refactor that so subclasses can add additional futures
# to be cancelled.
self._ssl_connect_future = None
self._connecting = False
self._state = None
self._pending_callbacks = 0
@ -230,6 +244,12 @@ class BaseIOStream(object):
gen_log.info("Unsatisfiable read, closing connection: %s" % e)
self.close(exc_info=True)
return future
except:
if future is not None:
# Ensure that the future doesn't log an error because its
# failure was never examined.
future.add_done_callback(lambda f: f.exception())
raise
return future
def read_until(self, delimiter, callback=None, max_bytes=None):
@ -257,6 +277,10 @@ class BaseIOStream(object):
gen_log.info("Unsatisfiable read, closing connection: %s" % e)
self.close(exc_info=True)
return future
except:
if future is not None:
future.add_done_callback(lambda f: f.exception())
raise
return future
def read_bytes(self, num_bytes, callback=None, streaming_callback=None,
@ -281,7 +305,12 @@ class BaseIOStream(object):
self._read_bytes = num_bytes
self._read_partial = partial
self._streaming_callback = stack_context.wrap(streaming_callback)
try:
self._try_inline_read()
except:
if future is not None:
future.add_done_callback(lambda f: f.exception())
raise
return future
def read_until_close(self, callback=None, streaming_callback=None):
@ -293,9 +322,16 @@ class BaseIOStream(object):
If a callback is given, it will be run with the data as an argument;
if not, this method returns a `.Future`.
Note that if a ``streaming_callback`` is used, data will be
read from the socket as quickly as it becomes available; there
is no way to apply backpressure or cancel the reads. If flow
control or cancellation are desired, use a loop with
`read_bytes(partial=True) <.read_bytes>` instead.
.. versionchanged:: 4.0
The callback argument is now optional and a `.Future` will
be returned if it is omitted.
"""
future = self._set_read_callback(callback)
self._streaming_callback = stack_context.wrap(streaming_callback)
@ -305,7 +341,11 @@ class BaseIOStream(object):
self._run_read_callback(self._read_buffer_size, False)
return future
self._read_until_close = True
try:
self._try_inline_read()
except:
future.add_done_callback(lambda f: f.exception())
raise
return future
def write(self, data, callback=None):
@ -331,7 +371,7 @@ class BaseIOStream(object):
if data:
if (self.max_write_buffer_size is not None and
self._write_buffer_size + len(data) > self.max_write_buffer_size):
raise StreamBufferFullError("Reached maximum read buffer size")
raise StreamBufferFullError("Reached maximum write buffer size")
# Break up large contiguous strings before inserting them in the
# write buffer, so we don't have to recopy the entire thing
# as we slice off pieces to send to the socket.
@ -344,6 +384,7 @@ class BaseIOStream(object):
future = None
else:
future = self._write_future = TracebackFuture()
future.add_done_callback(lambda f: f.exception())
if not self._connecting:
self._handle_write()
if self._write_buffer:
@ -401,9 +442,11 @@ class BaseIOStream(object):
if self._connect_future is not None:
futures.append(self._connect_future)
self._connect_future = None
if self._ssl_connect_future is not None:
futures.append(self._ssl_connect_future)
self._ssl_connect_future = None
for future in futures:
if (isinstance(self.error, (socket.error, IOError)) and
errno_from_exception(self.error) in _ERRNO_CONNRESET):
if self._is_connreset(self.error):
# Treat connection resets as closed connections so
# clients only have to catch one kind of exception
# to avoid logging.
@ -601,8 +644,7 @@ class BaseIOStream(object):
pos = self._read_to_buffer_loop()
except UnsatisfiableReadError:
raise
except Exception as e:
if 1 != e.errno:
except Exception:
gen_log.warning("error on read", exc_info=True)
self.close(exc_info=True)
return
@ -633,7 +675,7 @@ class BaseIOStream(object):
self._read_future = None
future.set_result(self._consume(size))
if callback is not None:
assert self._read_future is None
assert (self._read_future is None) or streaming
self._run_callback(callback, self._consume(size))
else:
# If we scheduled a callback, we will add the error listener
@ -684,7 +726,7 @@ class BaseIOStream(object):
chunk = self.read_from_fd()
except (socket.error, IOError, OSError) as e:
# ssl.SSLError is a subclass of socket.error
if e.args[0] in _ERRNO_CONNRESET:
if self._is_connreset(e):
# Treat ECONNRESET as a connection close rather than
# an error to minimize log spam (the exception will
# be available on self.error for apps that care).
@ -806,7 +848,7 @@ class BaseIOStream(object):
self._write_buffer_frozen = True
break
else:
if e.args[0] not in _ERRNO_CONNRESET:
if not self._is_connreset(e):
# Broken pipe errors are usually caused by connection
# reset, and its better to not log EPIPE errors to
# minimize log spam
@ -884,6 +926,14 @@ class BaseIOStream(object):
self._state = self._state | state
self.io_loop.update_handler(self.fileno(), self._state)
def _is_connreset(self, exc):
"""Return true if exc is ECONNRESET or equivalent.
May be overridden in subclasses.
"""
return (isinstance(exc, (socket.error, IOError)) and
errno_from_exception(exc) in _ERRNO_CONNRESET)
class IOStream(BaseIOStream):
r"""Socket-based `IOStream` implementation.
@ -898,7 +948,9 @@ class IOStream(BaseIOStream):
connected before passing it to the `IOStream` or connected with
`IOStream.connect`.
A very simple (and broken) HTTP client using this class::
A very simple (and broken) HTTP client using this class:
.. testcode::
import tornado.ioloop
import tornado.iostream
@ -917,14 +969,19 @@ class IOStream(BaseIOStream):
stream.read_bytes(int(headers[b"Content-Length"]), on_body)
def on_body(data):
print data
print(data)
stream.close()
tornado.ioloop.IOLoop.instance().stop()
tornado.ioloop.IOLoop.current().stop()
if __name__ == '__main__':
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
stream = tornado.iostream.IOStream(s)
stream.connect(("friendfeed.com", 80), send_request)
tornado.ioloop.IOLoop.instance().start()
tornado.ioloop.IOLoop.current().start()
.. testoutput::
:hide:
"""
def __init__(self, socket, *args, **kwargs):
self.socket = socket
@ -978,10 +1035,10 @@ class IOStream(BaseIOStream):
returns a `.Future` (whose result after a successful
connection will be the stream itself).
If specified, the ``server_hostname`` parameter will be used
in SSL connections for certificate validation (if requested in
the ``ssl_options``) and SNI (if supported; requires
Python 3.2+).
In SSL mode, the ``server_hostname`` parameter will be used
for certificate validation (unless disabled in the
``ssl_options``) and SNI (if supported; requires Python
2.7.9+).
Note that it is safe to call `IOStream.write
<BaseIOStream.write>` while the connection is pending, in
@ -992,6 +1049,11 @@ class IOStream(BaseIOStream):
.. versionchanged:: 4.0
If no callback is given, returns a `.Future`.
.. versionchanged:: 4.2
SSL certificates are validated by default; pass
``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a
suitably-configured `ssl.SSLContext` to the
`SSLIOStream` constructor to disable.
"""
self._connecting = True
if callback is not None:
@ -1011,6 +1073,7 @@ class IOStream(BaseIOStream):
# reported later in _handle_connect.
if (errno_from_exception(e) not in _ERRNO_INPROGRESS and
errno_from_exception(e) not in _ERRNO_WOULDBLOCK):
if future is None:
gen_log.warning("Connect error on fd %s: %s",
self.socket.fileno(), e)
self.close(exc_info=True)
@ -1033,10 +1096,11 @@ class IOStream(BaseIOStream):
data. It can also be used immediately after connecting,
before any reads or writes.
The ``ssl_options`` argument may be either a dictionary
of options or an `ssl.SSLContext`. If a ``server_hostname``
is given, it will be used for certificate verification
(as configured in the ``ssl_options``).
The ``ssl_options`` argument may be either an `ssl.SSLContext`
object or a dictionary of keyword arguments for the
`ssl.wrap_socket` function. The ``server_hostname`` argument
will be used for certificate validation unless disabled
in the ``ssl_options``.
This method returns a `.Future` whose result is the new
`SSLIOStream`. After this method has been called,
@ -1046,6 +1110,11 @@ class IOStream(BaseIOStream):
transferred to the new stream.
.. versionadded:: 4.0
.. versionchanged:: 4.2
SSL certificates are validated by default; pass
``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a
suitably-configured `ssl.SSLContext` to disable.
"""
if (self._read_callback or self._read_future or
self._write_callback or self._write_future or
@ -1054,12 +1123,17 @@ class IOStream(BaseIOStream):
self._read_buffer or self._write_buffer):
raise ValueError("IOStream is not idle; cannot convert to SSL")
if ssl_options is None:
ssl_options = {}
if server_side:
ssl_options = _server_ssl_defaults
else:
ssl_options = _client_ssl_defaults
socket = self.socket
self.io_loop.remove_handler(socket)
self.socket = None
socket = ssl_wrap_socket(socket, ssl_options, server_side=server_side,
socket = ssl_wrap_socket(socket, ssl_options,
server_hostname=server_hostname,
server_side=server_side,
do_handshake_on_connect=False)
orig_close_callback = self._close_callback
self._close_callback = None
@ -1071,6 +1145,7 @@ class IOStream(BaseIOStream):
# If we had an "unwrap" counterpart to this method we would need
# to restore the original callback after our Future resolves
# so that repeated wrap/unwrap calls don't build up layers.
def close_callback():
if not future.done():
future.set_exception(ssl_stream.error or StreamClosedError())
@ -1115,7 +1190,7 @@ class IOStream(BaseIOStream):
# Sometimes setsockopt will fail if the socket is closed
# at the wrong time. This can happen with HTTPServer
# resetting the value to false between requests.
if e.errno not in (errno.EINVAL, errno.ECONNRESET):
if e.errno != errno.EINVAL and not self._is_connreset(e):
raise
@ -1131,11 +1206,11 @@ class SSLIOStream(IOStream):
wrapped when `IOStream.connect` is finished.
"""
def __init__(self, *args, **kwargs):
"""The ``ssl_options`` keyword argument may either be a dictionary
of keywords arguments for `ssl.wrap_socket`, or an `ssl.SSLContext`
object.
"""The ``ssl_options`` keyword argument may either be an
`ssl.SSLContext` object or a dictionary of keywords arguments
for `ssl.wrap_socket`
"""
self._ssl_options = kwargs.pop('ssl_options', {})
self._ssl_options = kwargs.pop('ssl_options', _client_ssl_defaults)
super(SSLIOStream, self).__init__(*args, **kwargs)
self._ssl_accepting = True
self._handshake_reading = False
@ -1190,8 +1265,7 @@ class SSLIOStream(IOStream):
# to cause do_handshake to raise EBADF, so make that error
# quiet as well.
# https://groups.google.com/forum/?fromgroups#!topic/python-tornado/ApucKJat1_0
if (err.args[0] in _ERRNO_CONNRESET or
err.args[0] == errno.EBADF):
if self._is_connreset(err) or err.args[0] == errno.EBADF:
return self.close(exc_info=True)
raise
except AttributeError:
@ -1204,10 +1278,17 @@ class SSLIOStream(IOStream):
if not self._verify_cert(self.socket.getpeercert()):
self.close()
return
self._run_ssl_connect_callback()
def _run_ssl_connect_callback(self):
if self._ssl_connect_callback is not None:
callback = self._ssl_connect_callback
self._ssl_connect_callback = None
self._run_callback(callback)
if self._ssl_connect_future is not None:
future = self._ssl_connect_future
self._ssl_connect_future = None
future.set_result(self)
def _verify_cert(self, peercert):
"""Returns True if peercert is valid according to the configured
@ -1249,14 +1330,11 @@ class SSLIOStream(IOStream):
super(SSLIOStream, self)._handle_write()
def connect(self, address, callback=None, server_hostname=None):
# Save the user's callback and run it after the ssl handshake
# has completed.
self._ssl_connect_callback = stack_context.wrap(callback)
self._server_hostname = server_hostname
# Note: Since we don't pass our callback argument along to
# super.connect(), this will always return a Future.
# This is harmless, but a bit less efficient than it could be.
return super(SSLIOStream, self).connect(address, callback=None)
# Pass a dummy callback to super.connect(), which is slightly
# more efficient than letting it return a Future we ignore.
super(SSLIOStream, self).connect(address, callback=lambda: None)
return self.wait_for_handshake(callback)
def _handle_connect(self):
# Call the superclass method to check for errors.
@ -1281,6 +1359,51 @@ class SSLIOStream(IOStream):
do_handshake_on_connect=False)
self._add_io_state(old_state)
def wait_for_handshake(self, callback=None):
"""Wait for the initial SSL handshake to complete.
If a ``callback`` is given, it will be called with no
arguments once the handshake is complete; otherwise this
method returns a `.Future` which will resolve to the
stream itself after the handshake is complete.
Once the handshake is complete, information such as
the peer's certificate and NPN/ALPN selections may be
accessed on ``self.socket``.
This method is intended for use on server-side streams
or after using `IOStream.start_tls`; it should not be used
with `IOStream.connect` (which already waits for the
handshake to complete). It may only be called once per stream.
.. versionadded:: 4.2
"""
if (self._ssl_connect_callback is not None or
self._ssl_connect_future is not None):
raise RuntimeError("Already waiting")
if callback is not None:
self._ssl_connect_callback = stack_context.wrap(callback)
future = None
else:
future = self._ssl_connect_future = TracebackFuture()
if not self._ssl_accepting:
self._run_ssl_connect_callback()
return future
def write_to_fd(self, data):
try:
return self.socket.send(data)
except ssl.SSLError as e:
if e.args[0] == ssl.SSL_ERROR_WANT_WRITE:
# In Python 3.5+, SSLSocket.send raises a WANT_WRITE error if
# the socket is not writeable; we need to transform this into
# an EWOULDBLOCK socket.error or a zero return value,
# either of which will be recognized by the caller of this
# method. Prior to Python 3.5, an unwriteable socket would
# simply return 0 bytes written.
return 0
raise
def read_from_fd(self):
if self._ssl_accepting:
# If the handshake hasn't finished yet, there can't be anything
@ -1311,6 +1434,11 @@ class SSLIOStream(IOStream):
return None
return chunk
def _is_connreset(self, e):
if isinstance(e, ssl.SSLError) and e.args[0] == ssl.SSL_ERROR_EOF:
return True
return super(SSLIOStream, self)._is_connreset(e)
class PipeIOStream(BaseIOStream):
"""Pipe-based `IOStream` implementation.

View file

@ -55,6 +55,7 @@ _default_locale = "en_US"
_translations = {}
_supported_locales = frozenset([_default_locale])
_use_gettext = False
CONTEXT_SEPARATOR = "\x04"
def get(*locale_codes):
@ -273,6 +274,9 @@ class Locale(object):
"""
raise NotImplementedError()
def pgettext(self, context, message, plural_message=None, count=None):
raise NotImplementedError()
def format_date(self, date, gmt_offset=0, relative=True, shorter=False,
full_format=False):
"""Formats the given date (which should be GMT).
@ -422,6 +426,11 @@ class CSVLocale(Locale):
message_dict = self.translations.get("unknown", {})
return message_dict.get(message, message)
def pgettext(self, context, message, plural_message=None, count=None):
if self.translations:
gen_log.warning('pgettext is not supported by CSVLocale')
return self.translate(message, plural_message, count)
class GettextLocale(Locale):
"""Locale implementation using the `gettext` module."""
@ -445,6 +454,44 @@ class GettextLocale(Locale):
else:
return self.gettext(message)
def pgettext(self, context, message, plural_message=None, count=None):
"""Allows to set context for translation, accepts plural forms.
Usage example::
pgettext("law", "right")
pgettext("good", "right")
Plural message example::
pgettext("organization", "club", "clubs", len(clubs))
pgettext("stick", "club", "clubs", len(clubs))
To generate POT file with context, add following options to step 1
of `load_gettext_translations` sequence::
xgettext [basic options] --keyword=pgettext:1c,2 --keyword=pgettext:1c,2,3
.. versionadded:: 4.2
"""
if plural_message is not None:
assert count is not None
msgs_with_ctxt = ("%s%s%s" % (context, CONTEXT_SEPARATOR, message),
"%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message),
count)
result = self.ngettext(*msgs_with_ctxt)
if CONTEXT_SEPARATOR in result:
# Translation not found
result = self.ngettext(message, plural_message, count)
return result
else:
msg_with_ctxt = "%s%s%s" % (context, CONTEXT_SEPARATOR, message)
result = self.gettext(msg_with_ctxt)
if CONTEXT_SEPARATOR in result:
# Translation not found
result = message
return result
LOCALE_NAMES = {
"af_ZA": {"name_en": u("Afrikaans"), "name": u("Afrikaans")},
"am_ET": {"name_en": u("Amharic"), "name": u('\u12a0\u121b\u122d\u129b')},

460
tornado/locks.py Normal file
View file

@ -0,0 +1,460 @@
# Copyright 2015 The Tornado Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
.. testsetup:: *
from tornado import ioloop, gen, locks
io_loop = ioloop.IOLoop.current()
"""
from __future__ import absolute_import, division, print_function, with_statement
__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock']
import collections
from tornado import gen, ioloop
from tornado.concurrent import Future
class _TimeoutGarbageCollector(object):
"""Base class for objects that periodically clean up timed-out waiters.
Avoids memory leak in a common pattern like:
while True:
yield condition.wait(short_timeout)
print('looping....')
"""
def __init__(self):
self._waiters = collections.deque() # Futures.
self._timeouts = 0
def _garbage_collect(self):
# Occasionally clear timed-out waiters.
self._timeouts += 1
if self._timeouts > 100:
self._timeouts = 0
self._waiters = collections.deque(
w for w in self._waiters if not w.done())
class Condition(_TimeoutGarbageCollector):
"""A condition allows one or more coroutines to wait until notified.
Like a standard `threading.Condition`, but does not need an underlying lock
that is acquired and released.
With a `Condition`, coroutines can wait to be notified by other coroutines:
.. testcode::
condition = locks.Condition()
@gen.coroutine
def waiter():
print("I'll wait right here")
yield condition.wait() # Yield a Future.
print("I'm done waiting")
@gen.coroutine
def notifier():
print("About to notify")
condition.notify()
print("Done notifying")
@gen.coroutine
def runner():
# Yield two Futures; wait for waiter() and notifier() to finish.
yield [waiter(), notifier()]
io_loop.run_sync(runner)
.. testoutput::
I'll wait right here
About to notify
Done notifying
I'm done waiting
`wait` takes an optional ``timeout`` argument, which is either an absolute
timestamp::
io_loop = ioloop.IOLoop.current()
# Wait up to 1 second for a notification.
yield condition.wait(timeout=io_loop.time() + 1)
...or a `datetime.timedelta` for a timeout relative to the current time::
# Wait up to 1 second.
yield condition.wait(timeout=datetime.timedelta(seconds=1))
The method raises `tornado.gen.TimeoutError` if there's no notification
before the deadline.
"""
def __init__(self):
super(Condition, self).__init__()
self.io_loop = ioloop.IOLoop.current()
def __repr__(self):
result = '<%s' % (self.__class__.__name__, )
if self._waiters:
result += ' waiters[%s]' % len(self._waiters)
return result + '>'
def wait(self, timeout=None):
"""Wait for `.notify`.
Returns a `.Future` that resolves ``True`` if the condition is notified,
or ``False`` after a timeout.
"""
waiter = Future()
self._waiters.append(waiter)
if timeout:
def on_timeout():
waiter.set_result(False)
self._garbage_collect()
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
waiter.add_done_callback(
lambda _: io_loop.remove_timeout(timeout_handle))
return waiter
def notify(self, n=1):
"""Wake ``n`` waiters."""
waiters = [] # Waiters we plan to run right now.
while n and self._waiters:
waiter = self._waiters.popleft()
if not waiter.done(): # Might have timed out.
n -= 1
waiters.append(waiter)
for waiter in waiters:
waiter.set_result(True)
def notify_all(self):
"""Wake all waiters."""
self.notify(len(self._waiters))
class Event(object):
"""An event blocks coroutines until its internal flag is set to True.
Similar to `threading.Event`.
A coroutine can wait for an event to be set. Once it is set, calls to
``yield event.wait()`` will not block unless the event has been cleared:
.. testcode::
event = locks.Event()
@gen.coroutine
def waiter():
print("Waiting for event")
yield event.wait()
print("Not waiting this time")
yield event.wait()
print("Done")
@gen.coroutine
def setter():
print("About to set the event")
event.set()
@gen.coroutine
def runner():
yield [waiter(), setter()]
io_loop.run_sync(runner)
.. testoutput::
Waiting for event
About to set the event
Not waiting this time
Done
"""
def __init__(self):
self._future = Future()
def __repr__(self):
return '<%s %s>' % (
self.__class__.__name__, 'set' if self.is_set() else 'clear')
def is_set(self):
"""Return ``True`` if the internal flag is true."""
return self._future.done()
def set(self):
"""Set the internal flag to ``True``. All waiters are awakened.
Calling `.wait` once the flag is set will not block.
"""
if not self._future.done():
self._future.set_result(None)
def clear(self):
"""Reset the internal flag to ``False``.
Calls to `.wait` will block until `.set` is called.
"""
if self._future.done():
self._future = Future()
def wait(self, timeout=None):
"""Block until the internal flag is true.
Returns a Future, which raises `tornado.gen.TimeoutError` after a
timeout.
"""
if timeout is None:
return self._future
else:
return gen.with_timeout(timeout, self._future)
class _ReleasingContextManager(object):
"""Releases a Lock or Semaphore at the end of a "with" statement.
with (yield semaphore.acquire()):
pass
# Now semaphore.release() has been called.
"""
def __init__(self, obj):
self._obj = obj
def __enter__(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
self._obj.release()
class Semaphore(_TimeoutGarbageCollector):
"""A lock that can be acquired a fixed number of times before blocking.
A Semaphore manages a counter representing the number of `.release` calls
minus the number of `.acquire` calls, plus an initial value. The `.acquire`
method blocks if necessary until it can return without making the counter
negative.
Semaphores limit access to a shared resource. To allow access for two
workers at a time:
.. testsetup:: semaphore
from collections import deque
from tornado import gen, ioloop
from tornado.concurrent import Future
# Ensure reliable doctest output: resolve Futures one at a time.
futures_q = deque([Future() for _ in range(3)])
@gen.coroutine
def simulator(futures):
for f in futures:
yield gen.moment
f.set_result(None)
ioloop.IOLoop.current().add_callback(simulator, list(futures_q))
def use_some_resource():
return futures_q.popleft()
.. testcode:: semaphore
sem = locks.Semaphore(2)
@gen.coroutine
def worker(worker_id):
yield sem.acquire()
try:
print("Worker %d is working" % worker_id)
yield use_some_resource()
finally:
print("Worker %d is done" % worker_id)
sem.release()
@gen.coroutine
def runner():
# Join all workers.
yield [worker(i) for i in range(3)]
io_loop.run_sync(runner)
.. testoutput:: semaphore
Worker 0 is working
Worker 1 is working
Worker 0 is done
Worker 2 is working
Worker 1 is done
Worker 2 is done
Workers 0 and 1 are allowed to run concurrently, but worker 2 waits until
the semaphore has been released once, by worker 0.
`.acquire` is a context manager, so ``worker`` could be written as::
@gen.coroutine
def worker(worker_id):
with (yield sem.acquire()):
print("Worker %d is working" % worker_id)
yield use_some_resource()
# Now the semaphore has been released.
print("Worker %d is done" % worker_id)
"""
def __init__(self, value=1):
super(Semaphore, self).__init__()
if value < 0:
raise ValueError('semaphore initial value must be >= 0')
self._value = value
def __repr__(self):
res = super(Semaphore, self).__repr__()
extra = 'locked' if self._value == 0 else 'unlocked,value:{0}'.format(
self._value)
if self._waiters:
extra = '{0},waiters:{1}'.format(extra, len(self._waiters))
return '<{0} [{1}]>'.format(res[1:-1], extra)
def release(self):
"""Increment the counter and wake one waiter."""
self._value += 1
while self._waiters:
waiter = self._waiters.popleft()
if not waiter.done():
self._value -= 1
# If the waiter is a coroutine paused at
#
# with (yield semaphore.acquire()):
#
# then the context manager's __exit__ calls release() at the end
# of the "with" block.
waiter.set_result(_ReleasingContextManager(self))
break
def acquire(self, timeout=None):
"""Decrement the counter. Returns a Future.
Block if the counter is zero and wait for a `.release`. The Future
raises `.TimeoutError` after the deadline.
"""
waiter = Future()
if self._value > 0:
self._value -= 1
waiter.set_result(_ReleasingContextManager(self))
else:
self._waiters.append(waiter)
if timeout:
def on_timeout():
waiter.set_exception(gen.TimeoutError())
self._garbage_collect()
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
waiter.add_done_callback(
lambda _: io_loop.remove_timeout(timeout_handle))
return waiter
def __enter__(self):
raise RuntimeError(
"Use Semaphore like 'with (yield semaphore.acquire())', not like"
" 'with semaphore'")
__exit__ = __enter__
class BoundedSemaphore(Semaphore):
"""A semaphore that prevents release() being called too many times.
If `.release` would increment the semaphore's value past the initial
value, it raises `ValueError`. Semaphores are mostly used to guard
resources with limited capacity, so a semaphore released too many times
is a sign of a bug.
"""
def __init__(self, value=1):
super(BoundedSemaphore, self).__init__(value=value)
self._initial_value = value
def release(self):
"""Increment the counter and wake one waiter."""
if self._value >= self._initial_value:
raise ValueError("Semaphore released too many times")
super(BoundedSemaphore, self).release()
class Lock(object):
"""A lock for coroutines.
A Lock begins unlocked, and `acquire` locks it immediately. While it is
locked, a coroutine that yields `acquire` waits until another coroutine
calls `release`.
Releasing an unlocked lock raises `RuntimeError`.
`acquire` supports the context manager protocol:
>>> from tornado import gen, locks
>>> lock = locks.Lock()
>>>
>>> @gen.coroutine
... def f():
... with (yield lock.acquire()):
... # Do something holding the lock.
... pass
...
... # Now the lock is released.
"""
def __init__(self):
self._block = BoundedSemaphore(value=1)
def __repr__(self):
return "<%s _block=%s>" % (
self.__class__.__name__,
self._block)
def acquire(self, timeout=None):
"""Attempt to lock. Returns a Future.
Returns a Future, which raises `tornado.gen.TimeoutError` after a
timeout.
"""
return self._block.acquire(timeout)
def release(self):
"""Unlock.
The first coroutine in line waiting for `acquire` gets the lock.
If not locked, raise a `RuntimeError`.
"""
try:
self._block.release()
except ValueError:
raise RuntimeError('release unlocked lock')
def __enter__(self):
raise RuntimeError(
"Use Lock like 'with (yield lock)', not like 'with lock'")
__exit__ = __enter__

View file

@ -206,6 +206,14 @@ def enable_pretty_logging(options=None, logger=None):
def define_logging_options(options=None):
"""Add logging-related flags to ``options``.
These options are present automatically on the default options instance;
this method is only necessary if you have created your own `.OptionParser`.
.. versionadded:: 4.2
This function existed in prior versions but was broken and undocumented until 4.2.
"""
if options is None:
# late import to prevent cycle
from tornado.options import options
@ -227,4 +235,4 @@ def define_logging_options(options=None):
options.define("log_file_num_backups", type=int, default=10,
help="number of log files to keep")
options.add_parse_callback(enable_pretty_logging)
options.add_parse_callback(lambda: enable_pretty_logging(options))

View file

@ -20,7 +20,7 @@ from __future__ import absolute_import, division, print_function, with_statement
import errno
import os
import platform
import sys
import socket
import stat
@ -35,6 +35,15 @@ except ImportError:
# ssl is not available on Google App Engine
ssl = None
try:
import certifi
except ImportError:
# certifi is optional as long as we have ssl.create_default_context.
if ssl is None or hasattr(ssl, 'create_default_context'):
certifi = None
else:
raise
try:
xrange # py2
except NameError:
@ -50,6 +59,38 @@ else:
ssl_match_hostname = backports.ssl_match_hostname.match_hostname
SSLCertificateError = backports.ssl_match_hostname.CertificateError
if hasattr(ssl, 'SSLContext'):
if hasattr(ssl, 'create_default_context'):
# Python 2.7.9+, 3.4+
# Note that the naming of ssl.Purpose is confusing; the purpose
# of a context is to authentiate the opposite side of the connection.
_client_ssl_defaults = ssl.create_default_context(
ssl.Purpose.SERVER_AUTH)
_server_ssl_defaults = ssl.create_default_context(
ssl.Purpose.CLIENT_AUTH)
else:
# Python 3.2-3.3
_client_ssl_defaults = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
_client_ssl_defaults.verify_mode = ssl.CERT_REQUIRED
_client_ssl_defaults.load_verify_locations(certifi.where())
_server_ssl_defaults = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
if hasattr(ssl, 'OP_NO_COMPRESSION'):
# Disable TLS compression to avoid CRIME and related attacks.
# This constant wasn't added until python 3.3.
_client_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
_server_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
elif ssl:
# Python 2.6-2.7.8
_client_ssl_defaults = dict(cert_reqs=ssl.CERT_REQUIRED,
ca_certs=certifi.where())
_server_ssl_defaults = {}
else:
# Google App Engine
_client_ssl_defaults = dict(cert_reqs=None,
ca_certs=None)
_server_ssl_defaults = {}
# ThreadedResolver runs getaddrinfo on a thread. If the hostname is unicode,
# getaddrinfo attempts to import encodings.idna. If this is done at
# module-import time, the import lock is already held by the main thread,
@ -68,6 +109,7 @@ if hasattr(errno, "WSAEWOULDBLOCK"):
# Default backlog used when calling sock.listen()
_DEFAULT_BACKLOG = 128
def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
backlog=_DEFAULT_BACKLOG, flags=None):
"""Creates listening sockets bound to the given port and address.
@ -105,7 +147,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM,
0, flags)):
af, socktype, proto, canonname, sockaddr = res
if (platform.system() == 'Darwin' and address == 'localhost' and
if (sys.platform == 'darwin' and address == 'localhost' and
af == socket.AF_INET6 and sockaddr[3] != 0):
# Mac OS X includes a link-local address fe80::1%lo0 in the
# getaddrinfo results for 'localhost'. However, the firewall
@ -187,6 +229,9 @@ def add_accept_handler(sock, callback, io_loop=None):
address of the other end of the connection). Note that this signature
is different from the ``callback(fd, events)`` signature used for
`.IOLoop` handlers.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
if io_loop is None:
io_loop = IOLoop.current()
@ -301,6 +346,9 @@ class ExecutorResolver(Resolver):
The executor will be shut down when the resolver is closed unless
``close_resolver=False``; use this if you want to reuse the same
executor elsewhere.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
def initialize(self, io_loop=None, executor=None, close_executor=True):
self.io_loop = io_loop or IOLoop.current()
@ -413,7 +461,7 @@ def ssl_options_to_context(ssl_options):
`~ssl.SSLContext` object.
The ``ssl_options`` dictionary contains keywords to be passed to
`ssl.wrap_socket`. In Python 3.2+, `ssl.SSLContext` objects can
`ssl.wrap_socket`. In Python 2.7.9+, `ssl.SSLContext` objects can
be used instead. This function converts the dict form to its
`~ssl.SSLContext` equivalent, and may be used when a component which
accepts both forms needs to upgrade to the `~ssl.SSLContext` version
@ -444,11 +492,11 @@ def ssl_options_to_context(ssl_options):
def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs):
"""Returns an ``ssl.SSLSocket`` wrapping the given socket.
``ssl_options`` may be either a dictionary (as accepted by
`ssl_options_to_context`) or an `ssl.SSLContext` object.
Additional keyword arguments are passed to ``wrap_socket``
(either the `~ssl.SSLContext` method or the `ssl` module function
as appropriate).
``ssl_options`` may be either an `ssl.SSLContext` object or a
dictionary (as accepted by `ssl_options_to_context`). Additional
keyword arguments are passed to ``wrap_socket`` (either the
`~ssl.SSLContext` method or the `ssl` module function as
appropriate).
"""
context = ssl_options_to_context(ssl_options)
if hasattr(ssl, 'SSLContext') and isinstance(context, ssl.SSLContext):

View file

@ -204,6 +204,13 @@ class OptionParser(object):
(name, self._options[name].file_name))
frame = sys._getframe(0)
options_file = frame.f_code.co_filename
# Can be called directly, or through top level define() fn, in which
# case, step up above that frame to look for real caller.
if (frame.f_back.f_code.co_filename == options_file and
frame.f_back.f_code.co_name == 'define'):
frame = frame.f_back
file_name = frame.f_back.f_code.co_filename
if file_name == options_file:
file_name = ""
@ -249,7 +256,7 @@ class OptionParser(object):
arg = args[i].lstrip("-")
name, equals, value = arg.partition("=")
name = name.replace('-', '_')
if not name in self._options:
if name not in self._options:
self.print_help()
raise Error('Unrecognized command line option: %r' % name)
option = self._options[name]

View file

@ -12,6 +12,8 @@ unfinished callbacks on the event loop that fail when it resumes)
from __future__ import absolute_import, division, print_function, with_statement
import functools
import tornado.concurrent
from tornado.gen import convert_yielded
from tornado.ioloop import IOLoop
from tornado import stack_context
@ -27,8 +29,10 @@ except ImportError as e:
# Re-raise the original asyncio error, not the trollius one.
raise e
class BaseAsyncIOLoop(IOLoop):
def initialize(self, asyncio_loop, close_loop=False):
def initialize(self, asyncio_loop, close_loop=False, **kwargs):
super(BaseAsyncIOLoop, self).initialize(**kwargs)
self.asyncio_loop = asyncio_loop
self.close_loop = close_loop
self.asyncio_loop.call_soon(self.make_current)
@ -129,12 +133,29 @@ class BaseAsyncIOLoop(IOLoop):
class AsyncIOMainLoop(BaseAsyncIOLoop):
def initialize(self):
def initialize(self, **kwargs):
super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(),
close_loop=False)
close_loop=False, **kwargs)
class AsyncIOLoop(BaseAsyncIOLoop):
def initialize(self):
def initialize(self, **kwargs):
super(AsyncIOLoop, self).initialize(asyncio.new_event_loop(),
close_loop=True)
close_loop=True, **kwargs)
def to_tornado_future(asyncio_future):
"""Convert an ``asyncio.Future`` to a `tornado.concurrent.Future`."""
tf = tornado.concurrent.Future()
tornado.concurrent.chain_future(asyncio_future, tf)
return tf
def to_asyncio_future(tornado_future):
"""Convert a `tornado.concurrent.Future` to an ``asyncio.Future``."""
af = asyncio.Future()
tornado.concurrent.chain_future(tornado_future, af)
return af
if hasattr(convert_yielded, 'register'):
convert_yielded.register(asyncio.Future, to_tornado_future)

View file

@ -27,13 +27,14 @@ from __future__ import absolute_import, division, print_function, with_statement
import os
if os.name == 'nt':
from tornado.platform.common import Waker
from tornado.platform.windows import set_close_exec
elif 'APPENGINE_RUNTIME' in os.environ:
if 'APPENGINE_RUNTIME' in os.environ:
from tornado.platform.common import Waker
def set_close_exec(fd):
pass
elif os.name == 'nt':
from tornado.platform.common import Waker
from tornado.platform.windows import set_close_exec
else:
from tornado.platform.posix import set_close_exec, Waker
@ -41,9 +42,13 @@ try:
# monotime monkey-patches the time module to have a monotonic function
# in versions of python before 3.3.
import monotime
# Silence pyflakes warning about this unused import
monotime
except ImportError:
pass
try:
from time import monotonic as monotonic_time
except ImportError:
monotonic_time = None
__all__ = ['Waker', 'set_close_exec', 'monotonic_time']

View file

@ -18,6 +18,9 @@ class CaresResolver(Resolver):
so it is only recommended for use in ``AF_INET`` (i.e. IPv4). This is
the default for ``tornado.simple_httpclient``, but other libraries
may default to ``AF_UNSPEC``.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
def initialize(self, io_loop=None):
self.io_loop = io_loop or IOLoop.current()

View file

@ -54,8 +54,7 @@ class _KQueue(object):
if events & IOLoop.WRITE:
kevents.append(select.kevent(
fd, filter=select.KQ_FILTER_WRITE, flags=flags))
if events & IOLoop.READ or not kevents:
# Always read when there is not a write
if events & IOLoop.READ:
kevents.append(select.kevent(
fd, filter=select.KQ_FILTER_READ, flags=flags))
# Even though control() takes a list, it seems to return EINVAL

View file

@ -47,7 +47,7 @@ class _Select(object):
# Closed connections are reported as errors by epoll and kqueue,
# but as zero-byte reads by select, so when errors are requested
# we need to listen for both read and error.
self.read_fds.add(fd)
# self.read_fds.add(fd)
def modify(self, fd, events):
self.unregister(fd)

View file

@ -35,7 +35,7 @@ of the application::
tornado.platform.twisted.install()
from twisted.internet import reactor
When the app is ready to start, call `IOLoop.instance().start()`
When the app is ready to start, call `IOLoop.current().start()`
instead of `reactor.run()`.
It is also possible to create a non-global reactor by calling
@ -70,8 +70,10 @@ import datetime
import functools
import numbers
import socket
import sys
import twisted.internet.abstract
from twisted.internet.defer import Deferred
from twisted.internet.posixbase import PosixReactorBase
from twisted.internet.interfaces import \
IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor
@ -84,6 +86,7 @@ import twisted.names.resolve
from zope.interface import implementer
from tornado.concurrent import Future
from tornado.escape import utf8
from tornado import gen
import tornado.ioloop
@ -147,6 +150,9 @@ class TornadoReactor(PosixReactorBase):
We override `mainLoop` instead of `doIteration` and must implement
timed call functionality on top of `IOLoop.add_timeout` rather than
using the implementation in `PosixReactorBase`.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
def __init__(self, io_loop=None):
if not io_loop:
@ -356,7 +362,11 @@ class _TestReactor(TornadoReactor):
def install(io_loop=None):
"""Install this package as the default Twisted reactor."""
"""Install this package as the default Twisted reactor.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
if not io_loop:
io_loop = tornado.ioloop.IOLoop.current()
reactor = TornadoReactor(io_loop)
@ -406,7 +416,8 @@ class TwistedIOLoop(tornado.ioloop.IOLoop):
because the ``SIGCHLD`` handlers used by Tornado and Twisted conflict
with each other.
"""
def initialize(self, reactor=None):
def initialize(self, reactor=None, **kwargs):
super(TwistedIOLoop, self).initialize(**kwargs)
if reactor is None:
import twisted.internet.reactor
reactor = twisted.internet.reactor
@ -512,6 +523,9 @@ class TwistedResolver(Resolver):
``socket.AF_UNSPEC``.
Requires Twisted 12.1 or newer.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
def initialize(self, io_loop=None):
self.io_loop = io_loop or IOLoop.current()
@ -554,3 +568,18 @@ class TwistedResolver(Resolver):
(resolved_family, (resolved, port)),
]
raise gen.Return(result)
if hasattr(gen.convert_yielded, 'register'):
@gen.convert_yielded.register(Deferred)
def _(d):
f = Future()
def errback(failure):
try:
failure.raiseException()
# Should never happen, but just in case
raise Exception("errback called without error")
except:
f.set_exc_info(sys.exc_info())
d.addCallbacks(f.set_result, errback)
return f

View file

@ -29,6 +29,7 @@ import time
from binascii import hexlify
from tornado.concurrent import Future
from tornado import ioloop
from tornado.iostream import PipeIOStream
from tornado.log import gen_log
@ -48,6 +49,10 @@ except NameError:
long = int # py3
# Re-export this exception for convenience.
CalledProcessError = subprocess.CalledProcessError
def cpu_count():
"""Returns the number of processors on this machine."""
if multiprocessing is None:
@ -191,6 +196,9 @@ class Subprocess(object):
``tornado.process.Subprocess.STREAM``, which will make the corresponding
attribute of the resulting Subprocess a `.PipeIOStream`.
* A new keyword argument ``io_loop`` may be used to pass in an IOLoop.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
STREAM = object()
@ -255,6 +263,33 @@ class Subprocess(object):
Subprocess._waiting[self.pid] = self
Subprocess._try_cleanup_process(self.pid)
def wait_for_exit(self, raise_error=True):
"""Returns a `.Future` which resolves when the process exits.
Usage::
ret = yield proc.wait_for_exit()
This is a coroutine-friendly alternative to `set_exit_callback`
(and a replacement for the blocking `subprocess.Popen.wait`).
By default, raises `subprocess.CalledProcessError` if the process
has a non-zero exit status. Use ``wait_for_exit(raise_error=False)``
to suppress this behavior and return the exit status without raising.
.. versionadded:: 4.2
"""
future = Future()
def callback(ret):
if ret != 0 and raise_error:
# Unfortunately we don't have the original args any more.
future.set_exception(CalledProcessError(ret, None))
else:
future.set_result(ret)
self.set_exit_callback(callback)
return future
@classmethod
def initialize(cls, io_loop=None):
"""Initializes the ``SIGCHLD`` handler.
@ -263,6 +298,9 @@ class Subprocess(object):
Note that the `.IOLoop` used for signal handling need not be the
same one used by individual Subprocess objects (as long as the
``IOLoops`` are each running in separate threads).
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
if cls._initialized:
return

321
tornado/queues.py Normal file
View file

@ -0,0 +1,321 @@
# Copyright 2015 The Tornado Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
import collections
import heapq
from tornado import gen, ioloop
from tornado.concurrent import Future
from tornado.locks import Event
class QueueEmpty(Exception):
"""Raised by `.Queue.get_nowait` when the queue has no items."""
pass
class QueueFull(Exception):
"""Raised by `.Queue.put_nowait` when a queue is at its maximum size."""
pass
def _set_timeout(future, timeout):
if timeout:
def on_timeout():
future.set_exception(gen.TimeoutError())
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
future.add_done_callback(
lambda _: io_loop.remove_timeout(timeout_handle))
class Queue(object):
"""Coordinate producer and consumer coroutines.
If maxsize is 0 (the default) the queue size is unbounded.
.. testcode::
q = queues.Queue(maxsize=2)
@gen.coroutine
def consumer():
while True:
item = yield q.get()
try:
print('Doing work on %s' % item)
yield gen.sleep(0.01)
finally:
q.task_done()
@gen.coroutine
def producer():
for item in range(5):
yield q.put(item)
print('Put %s' % item)
@gen.coroutine
def main():
consumer() # Start consumer.
yield producer() # Wait for producer to put all tasks.
yield q.join() # Wait for consumer to finish all tasks.
print('Done')
io_loop.run_sync(main)
.. testoutput::
Put 0
Put 1
Put 2
Doing work on 0
Doing work on 1
Put 3
Doing work on 2
Put 4
Doing work on 3
Doing work on 4
Done
"""
def __init__(self, maxsize=0):
if maxsize is None:
raise TypeError("maxsize can't be None")
if maxsize < 0:
raise ValueError("maxsize can't be negative")
self._maxsize = maxsize
self._init()
self._getters = collections.deque([]) # Futures.
self._putters = collections.deque([]) # Pairs of (item, Future).
self._unfinished_tasks = 0
self._finished = Event()
self._finished.set()
@property
def maxsize(self):
"""Number of items allowed in the queue."""
return self._maxsize
def qsize(self):
"""Number of items in the queue."""
return len(self._queue)
def empty(self):
return not self._queue
def full(self):
if self.maxsize == 0:
return False
else:
return self.qsize() >= self.maxsize
def put(self, item, timeout=None):
"""Put an item into the queue, perhaps waiting until there is room.
Returns a Future, which raises `tornado.gen.TimeoutError` after a
timeout.
"""
try:
self.put_nowait(item)
except QueueFull:
future = Future()
self._putters.append((item, future))
_set_timeout(future, timeout)
return future
else:
return gen._null_future
def put_nowait(self, item):
"""Put an item into the queue without blocking.
If no free slot is immediately available, raise `QueueFull`.
"""
self._consume_expired()
if self._getters:
assert self.empty(), "queue non-empty, why are getters waiting?"
getter = self._getters.popleft()
self.__put_internal(item)
getter.set_result(self._get())
elif self.full():
raise QueueFull
else:
self.__put_internal(item)
def get(self, timeout=None):
"""Remove and return an item from the queue.
Returns a Future which resolves once an item is available, or raises
`tornado.gen.TimeoutError` after a timeout.
"""
future = Future()
try:
future.set_result(self.get_nowait())
except QueueEmpty:
self._getters.append(future)
_set_timeout(future, timeout)
return future
def get_nowait(self):
"""Remove and return an item from the queue without blocking.
Return an item if one is immediately available, else raise
`QueueEmpty`.
"""
self._consume_expired()
if self._putters:
assert self.full(), "queue not full, why are putters waiting?"
item, putter = self._putters.popleft()
self.__put_internal(item)
putter.set_result(None)
return self._get()
elif self.qsize():
return self._get()
else:
raise QueueEmpty
def task_done(self):
"""Indicate that a formerly enqueued task is complete.
Used by queue consumers. For each `.get` used to fetch a task, a
subsequent call to `.task_done` tells the queue that the processing
on the task is complete.
If a `.join` is blocking, it resumes when all items have been
processed; that is, when every `.put` is matched by a `.task_done`.
Raises `ValueError` if called more times than `.put`.
"""
if self._unfinished_tasks <= 0:
raise ValueError('task_done() called too many times')
self._unfinished_tasks -= 1
if self._unfinished_tasks == 0:
self._finished.set()
def join(self, timeout=None):
"""Block until all items in the queue are processed.
Returns a Future, which raises `tornado.gen.TimeoutError` after a
timeout.
"""
return self._finished.wait(timeout)
# These three are overridable in subclasses.
def _init(self):
self._queue = collections.deque()
def _get(self):
return self._queue.popleft()
def _put(self, item):
self._queue.append(item)
# End of the overridable methods.
def __put_internal(self, item):
self._unfinished_tasks += 1
self._finished.clear()
self._put(item)
def _consume_expired(self):
# Remove timed-out waiters.
while self._putters and self._putters[0][1].done():
self._putters.popleft()
while self._getters and self._getters[0].done():
self._getters.popleft()
def __repr__(self):
return '<%s at %s %s>' % (
type(self).__name__, hex(id(self)), self._format())
def __str__(self):
return '<%s %s>' % (type(self).__name__, self._format())
def _format(self):
result = 'maxsize=%r' % (self.maxsize, )
if getattr(self, '_queue', None):
result += ' queue=%r' % self._queue
if self._getters:
result += ' getters[%s]' % len(self._getters)
if self._putters:
result += ' putters[%s]' % len(self._putters)
if self._unfinished_tasks:
result += ' tasks=%s' % self._unfinished_tasks
return result
class PriorityQueue(Queue):
"""A `.Queue` that retrieves entries in priority order, lowest first.
Entries are typically tuples like ``(priority number, data)``.
.. testcode::
q = queues.PriorityQueue()
q.put((1, 'medium-priority item'))
q.put((0, 'high-priority item'))
q.put((10, 'low-priority item'))
print(q.get_nowait())
print(q.get_nowait())
print(q.get_nowait())
.. testoutput::
(0, 'high-priority item')
(1, 'medium-priority item')
(10, 'low-priority item')
"""
def _init(self):
self._queue = []
def _put(self, item):
heapq.heappush(self._queue, item)
def _get(self):
return heapq.heappop(self._queue)
class LifoQueue(Queue):
"""A `.Queue` that retrieves the most recently put items first.
.. testcode::
q = queues.LifoQueue()
q.put(3)
q.put(2)
q.put(1)
print(q.get_nowait())
print(q.get_nowait())
print(q.get_nowait())
.. testoutput::
1
2
3
"""
def _init(self):
self._queue = []
def _put(self, item):
self._queue.append(item)
def _get(self):
return self._queue.pop()

View file

@ -7,7 +7,7 @@ from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _
from tornado import httputil
from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters
from tornado.iostream import StreamClosedError
from tornado.netutil import Resolver, OverrideResolver
from tornado.netutil import Resolver, OverrideResolver, _client_ssl_defaults
from tornado.log import gen_log
from tornado import stack_context
from tornado.tcpclient import TCPClient
@ -34,7 +34,7 @@ except ImportError:
ssl = None
try:
import lib.certifi
import certifi
except ImportError:
certifi = None
@ -50,9 +50,6 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
"""Non-blocking HTTP client with no external dependencies.
This class implements an HTTP 1.1 client on top of Tornado's IOStreams.
It does not currently implement all applicable parts of the HTTP
specification, but it does enough to work with major web service APIs.
Some features found in the curl-based AsyncHTTPClient are not yet
supported. In particular, proxies are not supported, connections
are not reused, and callers cannot select the network interface to be
@ -60,25 +57,39 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
"""
def initialize(self, io_loop, max_clients=10,
hostname_mapping=None, max_buffer_size=104857600,
resolver=None, defaults=None, max_header_size=None):
resolver=None, defaults=None, max_header_size=None,
max_body_size=None):
"""Creates a AsyncHTTPClient.
Only a single AsyncHTTPClient instance exists per IOLoop
in order to provide limitations on the number of pending connections.
force_instance=True may be used to suppress this behavior.
``force_instance=True`` may be used to suppress this behavior.
max_clients is the number of concurrent requests that can be
in progress. Note that this arguments are only used when the
client is first created, and will be ignored when an existing
client is reused.
Note that because of this implicit reuse, unless ``force_instance``
is used, only the first call to the constructor actually uses
its arguments. It is recommended to use the ``configure`` method
instead of the constructor to ensure that arguments take effect.
hostname_mapping is a dictionary mapping hostnames to IP addresses.
``max_clients`` is the number of concurrent requests that can be
in progress; when this limit is reached additional requests will be
queued. Note that time spent waiting in this queue still counts
against the ``request_timeout``.
``hostname_mapping`` is a dictionary mapping hostnames to IP addresses.
It can be used to make local DNS changes when modifying system-wide
settings like /etc/hosts is not possible or desirable (e.g. in
settings like ``/etc/hosts`` is not possible or desirable (e.g. in
unittests).
max_buffer_size is the number of bytes that can be read by IOStream. It
defaults to 100mb.
``max_buffer_size`` (default 100MB) is the number of bytes
that can be read into memory at once. ``max_body_size``
(defaults to ``max_buffer_size``) is the largest response body
that the client will accept. Without a
``streaming_callback``, the smaller of these two limits
applies; with a ``streaming_callback`` only ``max_body_size``
does.
.. versionchanged:: 4.2
Added the ``max_body_size`` argument.
"""
super(SimpleAsyncHTTPClient, self).initialize(io_loop,
defaults=defaults)
@ -88,6 +99,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
self.waiting = {}
self.max_buffer_size = max_buffer_size
self.max_header_size = max_header_size
self.max_body_size = max_body_size
# TCPClient could create a Resolver for us, but we have to do it
# ourselves to support hostname_mapping.
if resolver:
@ -135,10 +147,14 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
release_callback = functools.partial(self._release_fetch, key)
self._handle_request(request, release_callback, callback)
def _connection_class(self):
return _HTTPConnection
def _handle_request(self, request, release_callback, final_callback):
_HTTPConnection(self.io_loop, self, request, release_callback,
self._connection_class()(
self.io_loop, self, request, release_callback,
final_callback, self.max_buffer_size, self.tcp_client,
self.max_header_size)
self.max_header_size, self.max_body_size)
def _release_fetch(self, key):
del self.active[key]
@ -166,7 +182,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
def __init__(self, io_loop, client, request, release_callback,
final_callback, max_buffer_size, tcp_client,
max_header_size):
max_header_size, max_body_size):
self.start_time = io_loop.time()
self.io_loop = io_loop
self.client = client
@ -176,6 +192,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
self.max_buffer_size = max_buffer_size
self.tcp_client = tcp_client
self.max_header_size = max_header_size
self.max_body_size = max_body_size
self.code = None
self.headers = None
self.chunks = []
@ -193,12 +210,8 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
netloc = self.parsed.netloc
if "@" in netloc:
userpass, _, netloc = netloc.rpartition("@")
match = re.match(r'^(.+):(\d+)$', netloc)
if match:
host = match.group(1)
port = int(match.group(2))
else:
host = netloc
host, port = httputil.split_host_and_port(netloc)
if port is None:
port = 443 if self.parsed.scheme == "https" else 80
if re.match(r'^\[.*\]$', host):
# raw ipv6 addresses in urls are enclosed in brackets
@ -224,12 +237,24 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
def _get_ssl_options(self, scheme):
if scheme == "https":
if self.request.ssl_options is not None:
return self.request.ssl_options
# If we are using the defaults, don't construct a
# new SSLContext.
if (self.request.validate_cert and
self.request.ca_certs is None and
self.request.client_cert is None and
self.request.client_key is None):
return _client_ssl_defaults
ssl_options = {}
if self.request.validate_cert:
ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
if self.request.ca_certs is not None:
ssl_options["ca_certs"] = self.request.ca_certs
else:
elif not hasattr(ssl, 'create_default_context'):
# When create_default_context is present,
# we can omit the "ca_certs" parameter entirely,
# which avoids the dependency on "certifi" for py34.
ssl_options["ca_certs"] = _default_ca_certs()
if self.request.client_key is not None:
ssl_options["keyfile"] = self.request.client_key
@ -323,7 +348,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
if ((body_expected and not body_present) or
(body_present and not body_expected)):
raise ValueError(
'Body must %sbe None for method %s (unelss '
'Body must %sbe None for method %s (unless '
'allow_nonstandard_methods is true)' %
('not ' if body_expected else '', self.request.method))
if self.request.expect_100_continue:
@ -340,26 +365,30 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
self.request.headers["Accept-Encoding"] = "gzip"
req_path = ((self.parsed.path or '/') +
(('?' + self.parsed.query) if self.parsed.query else ''))
self.stream.set_nodelay(True)
self.connection = HTTP1Connection(
self.stream, True,
HTTP1ConnectionParameters(
no_keep_alive=True,
max_header_size=self.max_header_size,
decompress=self.request.decompress_response),
self._sockaddr)
self.connection = self._create_connection(stream)
start_line = httputil.RequestStartLine(self.request.method,
req_path, 'HTTP/1.1')
req_path, '')
self.connection.write_headers(start_line, self.request.headers)
if self.request.expect_100_continue:
self._read_response()
else:
self._write_body(True)
def _create_connection(self, stream):
stream.set_nodelay(True)
connection = HTTP1Connection(
stream, True,
HTTP1ConnectionParameters(
no_keep_alive=True,
max_header_size=self.max_header_size,
max_body_size=self.max_body_size,
decompress=self.request.decompress_response),
self._sockaddr)
return connection
def _write_body(self, start_read):
if self.request.body is not None:
self.connection.write(self.request.body)
self.connection.finish()
elif self.request.body_producer is not None:
fut = self.request.body_producer(self.connection.write)
if is_future(fut):

View file

@ -111,6 +111,7 @@ class _Connector(object):
if self.timeout is not None:
# If the first attempt failed, don't wait for the
# timeout to try an address from the secondary queue.
self.io_loop.remove_timeout(self.timeout)
self.on_timeout()
return
self.clear_timeout()
@ -135,6 +136,9 @@ class _Connector(object):
class TCPClient(object):
"""A non-blocking TCP connection factory.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
def __init__(self, resolver=None, io_loop=None):
self.io_loop = io_loop or IOLoop.current()

View file

@ -41,14 +41,15 @@ class TCPServer(object):
To use `TCPServer`, define a subclass which overrides the `handle_stream`
method.
To make this server serve SSL traffic, send the ssl_options dictionary
argument with the arguments required for the `ssl.wrap_socket` method,
including "certfile" and "keyfile"::
To make this server serve SSL traffic, send the ``ssl_options`` keyword
argument with an `ssl.SSLContext` object. For compatibility with older
versions of Python ``ssl_options`` may also be a dictionary of keyword
arguments for the `ssl.wrap_socket` method.::
TCPServer(ssl_options={
"certfile": os.path.join(data_dir, "mydomain.crt"),
"keyfile": os.path.join(data_dir, "mydomain.key"),
})
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"),
os.path.join(data_dir, "mydomain.key"))
TCPServer(ssl_options=ssl_ctx)
`TCPServer` initialization follows one of three patterns:
@ -56,14 +57,14 @@ class TCPServer(object):
server = TCPServer()
server.listen(8888)
IOLoop.instance().start()
IOLoop.current().start()
2. `bind`/`start`: simple multi-process::
server = TCPServer()
server.bind(8888)
server.start(0) # Forks multiple sub-processes
IOLoop.instance().start()
IOLoop.current().start()
When using this interface, an `.IOLoop` must *not* be passed
to the `TCPServer` constructor. `start` will always start
@ -75,7 +76,7 @@ class TCPServer(object):
tornado.process.fork_processes(0)
server = TCPServer()
server.add_sockets(sockets)
IOLoop.instance().start()
IOLoop.current().start()
The `add_sockets` interface is more complicated, but it can be
used with `tornado.process.fork_processes` to give you more
@ -95,7 +96,7 @@ class TCPServer(object):
self._pending_sockets = []
self._started = False
self.max_buffer_size = max_buffer_size
self.read_chunk_size = None
self.read_chunk_size = read_chunk_size
# Verify the SSL options. Otherwise we don't get errors until clients
# connect. This doesn't verify that the keys are legitimate, but
@ -212,7 +213,20 @@ class TCPServer(object):
sock.close()
def handle_stream(self, stream, address):
"""Override to handle a new `.IOStream` from an incoming connection."""
"""Override to handle a new `.IOStream` from an incoming connection.
This method may be a coroutine; if so any exceptions it raises
asynchronously will be logged. Accepting of incoming connections
will not be blocked by this coroutine.
If this `TCPServer` is configured for SSL, ``handle_stream``
may be called before the SSL handshake has completed. Use
`.SSLIOStream.wait_for_handshake` if you need to verify the client's
certificate or use NPN/ALPN.
.. versionchanged:: 4.2
Added the option for this method to be a coroutine.
"""
raise NotImplementedError()
def _handle_connection(self, connection, address):
@ -252,6 +266,8 @@ class TCPServer(object):
stream = IOStream(connection, io_loop=self.io_loop,
max_buffer_size=self.max_buffer_size,
read_chunk_size=self.read_chunk_size)
self.handle_stream(stream, address)
future = self.handle_stream(stream, address)
if future is not None:
self.io_loop.add_future(future, lambda f: f.result())
except Exception:
app_log.error("Error in connection callback", exc_info=True)

View file

@ -1,4 +0,0 @@
Test coverage is almost non-existent, but it's a start. Be sure to
set PYTHONPATH appropriately (generally to the root directory of your
tornado checkout) when running tests to make sure you're getting the
version of the tornado package that you expect.

View file

@ -0,0 +1,69 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
import sys
import textwrap
from tornado import gen
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import unittest
try:
from tornado.platform.asyncio import asyncio, AsyncIOLoop
except ImportError:
asyncio = None
skipIfNoSingleDispatch = unittest.skipIf(
gen.singledispatch is None, "singledispatch module not present")
@unittest.skipIf(asyncio is None, "asyncio module not present")
class AsyncIOLoopTest(AsyncTestCase):
def get_new_ioloop(self):
io_loop = AsyncIOLoop()
asyncio.set_event_loop(io_loop.asyncio_loop)
return io_loop
def test_asyncio_callback(self):
# Basic test that the asyncio loop is set up correctly.
asyncio.get_event_loop().call_soon(self.stop)
self.wait()
@skipIfNoSingleDispatch
@gen_test
def test_asyncio_future(self):
# Test that we can yield an asyncio future from a tornado coroutine.
# Without 'yield from', we must wrap coroutines in asyncio.async.
x = yield asyncio.async(
asyncio.get_event_loop().run_in_executor(None, lambda: 42))
self.assertEqual(x, 42)
@unittest.skipIf(sys.version_info < (3, 3),
'PEP 380 not available')
@skipIfNoSingleDispatch
@gen_test
def test_asyncio_yield_from(self):
# Test that we can use asyncio coroutines with 'yield from'
# instead of asyncio.async(). This requires python 3.3 syntax.
global_namespace = dict(globals(), **locals())
local_namespace = {}
exec(textwrap.dedent("""
@gen.coroutine
def f():
event_loop = asyncio.get_event_loop()
x = yield from event_loop.run_in_executor(None, lambda: 42)
return x
"""), global_namespace, local_namespace)
result = yield local_namespace['f']()
self.assertEqual(result, 42)

View file

@ -5,7 +5,7 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, GoogleMixin, AuthError
from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, AuthError
from tornado.concurrent import Future
from tornado.escape import json_decode
from tornado import gen
@ -238,28 +238,6 @@ class TwitterServerVerifyCredentialsHandler(RequestHandler):
self.write(dict(screen_name='foo', name='Foo'))
class GoogleOpenIdClientLoginHandler(RequestHandler, GoogleMixin):
def initialize(self, test):
self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
@asynchronous
def get(self):
if self.get_argument("openid.mode", None):
self.get_authenticated_user(self.on_user)
return
res = self.authenticate_redirect()
assert isinstance(res, Future)
assert res.done()
def on_user(self, user):
if user is None:
raise Exception("user is None")
self.finish(user)
def get_auth_http_client(self):
return self.settings['http_client']
class AuthTest(AsyncHTTPTestCase):
def get_app(self):
return Application(
@ -286,7 +264,6 @@ class AuthTest(AsyncHTTPTestCase):
('/twitter/client/login_gen_coroutine', TwitterClientLoginGenCoroutineHandler, dict(test=self)),
('/twitter/client/show_user', TwitterClientShowUserHandler, dict(test=self)),
('/twitter/client/show_user_future', TwitterClientShowUserFutureHandler, dict(test=self)),
('/google/client/openid_login', GoogleOpenIdClientLoginHandler, dict(test=self)),
# simulated servers
('/openid/server/authenticate', OpenIdServerAuthenticateHandler),
@ -436,16 +413,3 @@ class AuthTest(AsyncHTTPTestCase):
response = self.fetch('/twitter/client/show_user_future?name=error')
self.assertEqual(response.code, 500)
self.assertIn(b'Error response HTTP 500', response.body)
def test_google_redirect(self):
# same as test_openid_redirect
response = self.fetch('/google/client/openid_login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(
'/openid/server/authenticate?' in response.headers['Location'])
def test_google_get_user(self):
response = self.fetch('/google/client/openid_login?openid.mode=blah&openid.ns.ax=http://openid.net/srv/ax/1.0&openid.ax.type.email=http://axschema.org/contact/email&openid.ax.value.email=foo@example.com', follow_redirects=False)
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["email"], "foo@example.com")

View file

@ -21,13 +21,14 @@ import socket
import sys
import traceback
from tornado.concurrent import Future, return_future, ReturnValueIgnoredError
from tornado.concurrent import Future, return_future, ReturnValueIgnoredError, run_on_executor
from tornado.escape import utf8, to_unicode
from tornado import gen
from tornado.iostream import IOStream
from tornado import stack_context
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, LogTrapTestCase, bind_unused_port, gen_test
from tornado.test.util import unittest
try:
@ -334,3 +335,81 @@ class DecoratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
class GeneratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
client_class = GeneratorCapClient
@unittest.skipIf(futures is None, "concurrent.futures module not present")
class RunOnExecutorTest(AsyncTestCase):
@gen_test
def test_no_calling(self):
class Object(object):
def __init__(self, io_loop):
self.io_loop = io_loop
self.executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor
def f(self):
return 42
o = Object(io_loop=self.io_loop)
answer = yield o.f()
self.assertEqual(answer, 42)
@gen_test
def test_call_with_no_args(self):
class Object(object):
def __init__(self, io_loop):
self.io_loop = io_loop
self.executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor()
def f(self):
return 42
o = Object(io_loop=self.io_loop)
answer = yield o.f()
self.assertEqual(answer, 42)
@gen_test
def test_call_with_io_loop(self):
class Object(object):
def __init__(self, io_loop):
self._io_loop = io_loop
self.executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor(io_loop='_io_loop')
def f(self):
return 42
o = Object(io_loop=self.io_loop)
answer = yield o.f()
self.assertEqual(answer, 42)
@gen_test
def test_call_with_executor(self):
class Object(object):
def __init__(self, io_loop):
self.io_loop = io_loop
self.__executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor(executor='_Object__executor')
def f(self):
return 42
o = Object(io_loop=self.io_loop)
answer = yield o.f()
self.assertEqual(answer, 42)
@gen_test
def test_call_with_both(self):
class Object(object):
def __init__(self, io_loop):
self._io_loop = io_loop
self.__executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor(io_loop='_io_loop', executor='_Object__executor')
def f(self):
return 42
o = Object(io_loop=self.io_loop)
answer = yield o.f()
self.assertEqual(answer, 42)

View file

@ -8,7 +8,7 @@ from tornado.stack_context import ExceptionStackContext
from tornado.testing import AsyncHTTPTestCase
from tornado.test import httpclient_test
from tornado.test.util import unittest
from tornado.web import Application, RequestHandler, URLSpec
from tornado.web import Application, RequestHandler
try:

View file

@ -217,8 +217,7 @@ class EscapeTestCase(unittest.TestCase):
self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9")
def test_squeeze(self):
self.assertEqual(squeeze(u('sequences of whitespace chars'))
, u('sequences of whitespace chars'))
self.assertEqual(squeeze(u('sequences of whitespace chars')), u('sequences of whitespace chars'))
def test_recursive_unicode(self):
tests = {

View file

@ -62,6 +62,11 @@ class GenEngineTest(AsyncTestCase):
def async_future(self, result, callback):
self.io_loop.add_callback(callback, result)
@gen.coroutine
def async_exception(self, e):
yield gen.moment
raise e
def test_no_yield(self):
@gen.engine
def f():
@ -385,11 +390,56 @@ class GenEngineTest(AsyncTestCase):
results = yield [self.async_future(1), self.async_future(2)]
self.assertEqual(results, [1, 2])
@gen_test
def test_multi_future_duplicate(self):
f = self.async_future(2)
results = yield [self.async_future(1), f, self.async_future(3), f]
self.assertEqual(results, [1, 2, 3, 2])
@gen_test
def test_multi_dict_future(self):
results = yield dict(foo=self.async_future(1), bar=self.async_future(2))
self.assertEqual(results, dict(foo=1, bar=2))
@gen_test
def test_multi_exceptions(self):
with ExpectLog(app_log, "Multiple exceptions in yield list"):
with self.assertRaises(RuntimeError) as cm:
yield gen.Multi([self.async_exception(RuntimeError("error 1")),
self.async_exception(RuntimeError("error 2"))])
self.assertEqual(str(cm.exception), "error 1")
# With only one exception, no error is logged.
with self.assertRaises(RuntimeError):
yield gen.Multi([self.async_exception(RuntimeError("error 1")),
self.async_future(2)])
# Exception logging may be explicitly quieted.
with self.assertRaises(RuntimeError):
yield gen.Multi([self.async_exception(RuntimeError("error 1")),
self.async_exception(RuntimeError("error 2"))],
quiet_exceptions=RuntimeError)
@gen_test
def test_multi_future_exceptions(self):
with ExpectLog(app_log, "Multiple exceptions in yield list"):
with self.assertRaises(RuntimeError) as cm:
yield [self.async_exception(RuntimeError("error 1")),
self.async_exception(RuntimeError("error 2"))]
self.assertEqual(str(cm.exception), "error 1")
# With only one exception, no error is logged.
with self.assertRaises(RuntimeError):
yield [self.async_exception(RuntimeError("error 1")),
self.async_future(2)]
# Exception logging may be explicitly quieted.
with self.assertRaises(RuntimeError):
yield gen.multi_future(
[self.async_exception(RuntimeError("error 1")),
self.async_exception(RuntimeError("error 2"))],
quiet_exceptions=RuntimeError)
def test_arguments(self):
@gen.engine
def f():
@ -816,6 +866,7 @@ class GenCoroutineTest(AsyncTestCase):
@gen_test
def test_moment(self):
calls = []
@gen.coroutine
def f(name, yieldable):
for i in range(5):
@ -838,6 +889,34 @@ class GenCoroutineTest(AsyncTestCase):
yield [f('a', gen.moment), f('b', immediate)]
self.assertEqual(''.join(calls), 'abbbbbaaaa')
@gen_test
def test_sleep(self):
yield gen.sleep(0.01)
self.finished = True
@skipBefore33
@gen_test
def test_py3_leak_exception_context(self):
class LeakedException(Exception):
pass
@gen.coroutine
def inner(iteration):
raise LeakedException(iteration)
try:
yield inner(1)
except LeakedException as e:
self.assertEqual(str(e), "1")
self.assertIsNone(e.__context__)
try:
yield inner(2)
except LeakedException as e:
self.assertEqual(str(e), "2")
self.assertIsNone(e.__context__)
self.finished = True
class GenSequenceHandler(RequestHandler):
@asynchronous
@ -1031,7 +1110,7 @@ class WithTimeoutTest(AsyncTestCase):
self.io_loop.add_timeout(datetime.timedelta(seconds=0.1),
lambda: future.set_result('asdf'))
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
future)
future, io_loop=self.io_loop)
self.assertEqual(result, 'asdf')
@gen_test
@ -1039,16 +1118,17 @@ class WithTimeoutTest(AsyncTestCase):
future = Future()
self.io_loop.add_timeout(
datetime.timedelta(seconds=0.1),
lambda: future.set_exception(ZeroDivisionError))
lambda: future.set_exception(ZeroDivisionError()))
with self.assertRaises(ZeroDivisionError):
yield gen.with_timeout(datetime.timedelta(seconds=3600), future)
yield gen.with_timeout(datetime.timedelta(seconds=3600),
future, io_loop=self.io_loop)
@gen_test
def test_already_resolved(self):
future = Future()
future.set_result('asdf')
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
future)
future, io_loop=self.io_loop)
self.assertEqual(result, 'asdf')
@unittest.skipIf(futures is None, 'futures module not present')
@ -1067,5 +1147,117 @@ class WithTimeoutTest(AsyncTestCase):
executor.submit(lambda: None))
class WaitIteratorTest(AsyncTestCase):
@gen_test
def test_empty_iterator(self):
g = gen.WaitIterator()
self.assertTrue(g.done(), 'empty generator iterated')
with self.assertRaises(ValueError):
g = gen.WaitIterator(False, bar=False)
self.assertEqual(g.current_index, None, "bad nil current index")
self.assertEqual(g.current_future, None, "bad nil current future")
@gen_test
def test_already_done(self):
f1 = Future()
f2 = Future()
f3 = Future()
f1.set_result(24)
f2.set_result(42)
f3.set_result(84)
g = gen.WaitIterator(f1, f2, f3)
i = 0
while not g.done():
r = yield g.next()
# Order is not guaranteed, but the current implementation
# preserves ordering of already-done Futures.
if i == 0:
self.assertEqual(g.current_index, 0)
self.assertIs(g.current_future, f1)
self.assertEqual(r, 24)
elif i == 1:
self.assertEqual(g.current_index, 1)
self.assertIs(g.current_future, f2)
self.assertEqual(r, 42)
elif i == 2:
self.assertEqual(g.current_index, 2)
self.assertIs(g.current_future, f3)
self.assertEqual(r, 84)
i += 1
self.assertEqual(g.current_index, None, "bad nil current index")
self.assertEqual(g.current_future, None, "bad nil current future")
dg = gen.WaitIterator(f1=f1, f2=f2)
while not dg.done():
dr = yield dg.next()
if dg.current_index == "f1":
self.assertTrue(dg.current_future == f1 and dr == 24,
"WaitIterator dict status incorrect")
elif dg.current_index == "f2":
self.assertTrue(dg.current_future == f2 and dr == 42,
"WaitIterator dict status incorrect")
else:
self.fail("got bad WaitIterator index {}".format(
dg.current_index))
i += 1
self.assertEqual(dg.current_index, None, "bad nil current index")
self.assertEqual(dg.current_future, None, "bad nil current future")
def finish_coroutines(self, iteration, futures):
if iteration == 3:
futures[2].set_result(24)
elif iteration == 5:
futures[0].set_exception(ZeroDivisionError())
elif iteration == 8:
futures[1].set_result(42)
futures[3].set_result(84)
if iteration < 8:
self.io_loop.add_callback(self.finish_coroutines, iteration + 1, futures)
@gen_test
def test_iterator(self):
futures = [Future(), Future(), Future(), Future()]
self.finish_coroutines(0, futures)
g = gen.WaitIterator(*futures)
i = 0
while not g.done():
try:
r = yield g.next()
except ZeroDivisionError:
self.assertIs(g.current_future, futures[0],
'exception future invalid')
else:
if i == 0:
self.assertEqual(r, 24, 'iterator value incorrect')
self.assertEqual(g.current_index, 2, 'wrong index')
elif i == 2:
self.assertEqual(r, 42, 'iterator value incorrect')
self.assertEqual(g.current_index, 1, 'wrong index')
elif i == 3:
self.assertEqual(r, 84, 'iterator value incorrect')
self.assertEqual(g.current_index, 3, 'wrong index')
i += 1
@gen_test
def test_no_ref(self):
# In this usage, there is no direct hard reference to the
# WaitIterator itself, only the Future it returns. Since
# WaitIterator uses weak references internally to improve GC
# performance, this used to cause problems.
yield gen.with_timeout(datetime.timedelta(seconds=0.1),
gen.WaitIterator(gen.sleep(0)).next())
if __name__ == '__main__':
unittest.main()

View file

@ -1,11 +1,16 @@
# flake8: noqa
# Dummy source file to allow creation of the initial .po file in the
# same way as a real project. I'm not entirely sure about the real
# workflow here, but this seems to work.
#
# 1) xgettext --language=Python --keyword=_:1,2 -d tornado_test extract_me.py -o tornado_test.po
# 2) Edit tornado_test.po, setting CHARSET and setting msgstr
# 1) xgettext --language=Python --keyword=_:1,2 --keyword=pgettext:1c,2 --keyword=pgettext:1c,2,3 extract_me.py -o tornado_test.po
# 2) Edit tornado_test.po, setting CHARSET, Plural-Forms and setting msgstr
# 3) msgfmt tornado_test.po -o tornado_test.mo
# 4) Put the file in the proper location: $LANG/LC_MESSAGES
from __future__ import absolute_import, division, print_function, with_statement
_("school")
pgettext("law", "right")
pgettext("good", "right")
pgettext("organization", "club", "clubs", 1)
pgettext("stick", "club", "clubs", 1)

View file

@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2012-06-14 01:10-0700\n"
"POT-Creation-Date: 2015-01-27 11:05+0300\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
@ -16,7 +16,32 @@ msgstr ""
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Plural-Forms: nplurals=2; plural=(n > 1);\n"
#: extract_me.py:1
#: extract_me.py:11
msgid "school"
msgstr "école"
#: extract_me.py:12
msgctxt "law"
msgid "right"
msgstr "le droit"
#: extract_me.py:13
msgctxt "good"
msgid "right"
msgstr "le bien"
#: extract_me.py:14
msgctxt "organization"
msgid "club"
msgid_plural "clubs"
msgstr[0] "le club"
msgstr[1] "les clubs"
#: extract_me.py:15
msgctxt "stick"
msgid "club"
msgid_plural "clubs"
msgstr[0] "le bâton"
msgstr[1] "les bâtons"

View file

@ -12,6 +12,7 @@ import datetime
from io import BytesIO
from tornado.escape import utf8
from tornado import gen
from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
@ -52,9 +53,12 @@ class RedirectHandler(RequestHandler):
class ChunkHandler(RequestHandler):
@gen.coroutine
def get(self):
self.write("asdf")
self.flush()
# Wait a bit to ensure the chunks are sent and received separately.
yield gen.sleep(0.01)
self.write("qwer")
@ -178,6 +182,8 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
sock, port = bind_unused_port()
with closing(sock):
def write_response(stream, request_data):
if b"HTTP/1." not in request_data:
self.skipTest("requires HTTP/1.x")
stream.write(b"""\
HTTP/1.1 200 OK
Transfer-Encoding: chunked
@ -300,23 +306,26 @@ Transfer-Encoding: chunked
chunks = []
def header_callback(header_line):
if header_line.startswith('HTTP/'):
if header_line.startswith('HTTP/1.1 101'):
# Upgrading to HTTP/2
pass
elif header_line.startswith('HTTP/'):
first_line.append(header_line)
elif header_line != '\r\n':
k, v = header_line.split(':', 1)
headers[k] = v.strip()
headers[k.lower()] = v.strip()
def streaming_callback(chunk):
# All header callbacks are run before any streaming callbacks,
# so the header data is available to process the data as it
# comes in.
self.assertEqual(headers['Content-Type'], 'text/html; charset=UTF-8')
self.assertEqual(headers['content-type'], 'text/html; charset=UTF-8')
chunks.append(chunk)
self.fetch('/chunk', header_callback=header_callback,
streaming_callback=streaming_callback)
self.assertEqual(len(first_line), 1)
self.assertRegexpMatches(first_line[0], 'HTTP/1.[01] 200 OK\r\n')
self.assertEqual(len(first_line), 1, first_line)
self.assertRegexpMatches(first_line[0], 'HTTP/[0-9]\\.[0-9] 200.*\r\n')
self.assertEqual(chunks, [b'asdf', b'qwer'])
def test_header_callback_stack_context(self):
@ -327,7 +336,7 @@ Transfer-Encoding: chunked
return True
def header_callback(header_line):
if header_line.startswith('Content-Type:'):
if header_line.lower().startswith('content-type:'):
1 / 0
with ExceptionStackContext(error_handler):
@ -404,6 +413,11 @@ Transfer-Encoding: chunked
self.assertEqual(context.exception.code, 404)
self.assertEqual(context.exception.response.code, 404)
@gen_test
def test_future_http_error_no_raise(self):
response = yield self.http_client.fetch(self.get_url('/notfound'), raise_error=False)
self.assertEqual(response.code, 404)
@gen_test
def test_reuse_request_from_response(self):
# The response.request attribute should be an HTTPRequest, not
@ -454,7 +468,7 @@ Transfer-Encoding: chunked
# Twisted's reactor does not. The removeReader call fails and so
# do all future removeAll calls (which our tests do at cleanup).
#
#def test_post_307(self):
# def test_post_307(self):
# response = self.fetch("/redirect?status=307&url=/post",
# method="POST", body=b"arg1=foo&arg2=bar")
# self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
@ -536,14 +550,19 @@ class SyncHTTPClientTest(unittest.TestCase):
def tearDown(self):
def stop_server():
self.server.stop()
self.server_ioloop.stop()
# Delay the shutdown of the IOLoop by one iteration because
# the server may still have some cleanup work left when
# the client finishes with the response (this is noticable
# with http/2, which leaves a Future with an unexamined
# StreamClosedError on the loop).
self.server_ioloop.add_callback(self.server_ioloop.stop)
self.server_ioloop.add_callback(stop_server)
self.server_thread.join()
self.http_client.close()
self.server_ioloop.close(all_fds=True)
def get_url(self, path):
return 'http://localhost:%d%s' % (self.port, path)
return 'http://127.0.0.1:%d%s' % (self.port, path)
def test_sync_client(self):
response = self.http_client.fetch(self.get_url('/'))

View file

@ -32,6 +32,7 @@ def read_stream_body(stream, callback):
"""Reads an HTTP response from `stream` and runs callback with its
headers and body."""
chunks = []
class Delegate(HTTPMessageDelegate):
def headers_received(self, start_line, headers):
self.headers = headers
@ -161,11 +162,14 @@ class BadSSLOptionsTest(unittest.TestCase):
application = Application()
module_dir = os.path.dirname(__file__)
existing_certificate = os.path.join(module_dir, 'test.crt')
existing_key = os.path.join(module_dir, 'test.key')
self.assertRaises(ValueError, HTTPServer, application, ssl_options={
self.assertRaises((ValueError, IOError),
HTTPServer, application, ssl_options={
"certfile": "/__mising__.crt",
})
self.assertRaises(ValueError, HTTPServer, application, ssl_options={
self.assertRaises((ValueError, IOError),
HTTPServer, application, ssl_options={
"certfile": existing_certificate,
"keyfile": "/__missing__.key"
})
@ -173,7 +177,7 @@ class BadSSLOptionsTest(unittest.TestCase):
# This actually works because both files exist
HTTPServer(application, ssl_options={
"certfile": existing_certificate,
"keyfile": existing_certificate
"keyfile": existing_key,
})
@ -195,14 +199,14 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
def get_app(self):
return Application(self.get_handlers())
def raw_fetch(self, headers, body):
def raw_fetch(self, headers, body, newline=b"\r\n"):
with closing(IOStream(socket.socket())) as stream:
stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
stream.write(
b"\r\n".join(headers +
[utf8("Content-Length: %d\r\n" % len(body))]) +
b"\r\n" + body)
newline.join(headers +
[utf8("Content-Length: %d" % len(body))]) +
newline + newline + body)
read_stream_body(stream, self.stop)
headers, body = self.wait()
return body
@ -232,12 +236,19 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
self.assertEqual(u("\u00f3"), data["filename"])
self.assertEqual(u("\u00fa"), data["filebody"])
def test_newlines(self):
# We support both CRLF and bare LF as line separators.
for newline in (b"\r\n", b"\n"):
response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"",
newline=newline)
self.assertEqual(response, b'Hello world')
def test_100_continue(self):
# Run through a 100-continue interaction by hand:
# When given Expect: 100-continue, we get a 100 response after the
# headers, and then the real response after the body.
stream = IOStream(socket.socket(), io_loop=self.io_loop)
stream.connect(("localhost", self.get_http_port()), callback=self.stop)
stream.connect(("127.0.0.1", self.get_http_port()), callback=self.stop)
self.wait()
stream.write(b"\r\n".join([b"POST /hello HTTP/1.1",
b"Content-Length: 1024",
@ -374,7 +385,7 @@ class HTTPServerRawTest(AsyncHTTPTestCase):
def setUp(self):
super(HTTPServerRawTest, self).setUp()
self.stream = IOStream(socket.socket())
self.stream.connect(('localhost', self.get_http_port()), self.stop)
self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
def tearDown(self):
@ -555,7 +566,7 @@ class UnixSocketTest(AsyncTestCase):
self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
self.stream.read_until(b"\r\n", self.stop)
response = self.wait()
self.assertEqual(response, b"HTTP/1.0 200 OK\r\n")
self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
self.stream.read_until(b"\r\n\r\n", self.stop)
headers = HTTPHeaders.parse(self.wait().decode('latin1'))
self.stream.read_bytes(int(headers["Content-Length"]), self.stop)
@ -582,6 +593,7 @@ class KeepAliveTest(AsyncHTTPTestCase):
class HelloHandler(RequestHandler):
def get(self):
self.finish('Hello world')
def post(self):
self.finish('Hello world')
@ -623,13 +635,13 @@ class KeepAliveTest(AsyncHTTPTestCase):
# The next few methods are a crude manual http client
def connect(self):
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
self.stream.connect(('localhost', self.get_http_port()), self.stop)
self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
def read_headers(self):
self.stream.read_until(b'\r\n', self.stop)
first_line = self.wait()
self.assertTrue(first_line.startswith(self.http_version + b' 200'), first_line)
self.assertTrue(first_line.startswith(b'HTTP/1.1 200'), first_line)
self.stream.read_until(b'\r\n\r\n', self.stop)
header_bytes = self.wait()
headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
@ -808,8 +820,8 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase):
def get_app(self):
class App(HTTPServerConnectionDelegate):
def start_request(self, connection):
return StreamingChunkSizeTest.MessageDelegate(connection)
def start_request(self, server_conn, request_conn):
return StreamingChunkSizeTest.MessageDelegate(request_conn)
return App()
def fetch_chunk_sizes(self, **kwargs):
@ -856,6 +868,7 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase):
def test_chunked_compressed(self):
compressed = self.compress(self.BODY)
self.assertGreater(len(compressed), 20)
def body_producer(write):
write(compressed[:20])
write(compressed[20:])
@ -900,7 +913,7 @@ class IdleTimeoutTest(AsyncHTTPTestCase):
def connect(self):
stream = IOStream(socket.socket())
stream.connect(('localhost', self.get_http_port()), self.stop)
stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
self.streams.append(stream)
return stream
@ -1045,6 +1058,15 @@ class LegacyInterfaceTest(AsyncHTTPTestCase):
# delegate interface, and writes its response via request.write
# instead of request.connection.write_headers.
def handle_request(request):
self.http1 = request.version.startswith("HTTP/1.")
if not self.http1:
# This test will be skipped if we're using HTTP/2,
# so just close it out cleanly using the modern interface.
request.connection.write_headers(
ResponseStartLine('', 200, 'OK'),
HTTPHeaders())
request.connection.finish()
return
message = b"Hello world"
request.write(utf8("HTTP/1.1 200 OK\r\n"
"Content-Length: %d\r\n\r\n" % len(message)))
@ -1054,4 +1076,6 @@ class LegacyInterfaceTest(AsyncHTTPTestCase):
def test_legacy_interface(self):
response = self.fetch('/')
if not self.http1:
self.skipTest("requires HTTP/1.x")
self.assertEqual(response.body, b"Hello world")

View file

@ -3,11 +3,13 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado.httputil import url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp, HTTPServerRequest, parse_request_start_line
from tornado.escape import utf8
from tornado.escape import utf8, native_str
from tornado.log import gen_log
from tornado.testing import ExpectLog
from tornado.test.util import unittest
from tornado.util import u
import copy
import datetime
import logging
import time
@ -228,6 +230,75 @@ Foo: even
("Foo", "bar baz"),
("Foo", "even more lines")])
def test_unicode_newlines(self):
# Ensure that only \r\n is recognized as a header separator, and not
# the other newline-like unicode characters.
# Characters that are likely to be problematic can be found in
# http://unicode.org/standard/reports/tr13/tr13-5.html
# and cpython's unicodeobject.c (which defines the implementation
# of unicode_type.splitlines(), and uses a different list than TR13).
newlines = [
u('\u001b'), # VERTICAL TAB
u('\u001c'), # FILE SEPARATOR
u('\u001d'), # GROUP SEPARATOR
u('\u001e'), # RECORD SEPARATOR
u('\u0085'), # NEXT LINE
u('\u2028'), # LINE SEPARATOR
u('\u2029'), # PARAGRAPH SEPARATOR
]
for newline in newlines:
# Try the utf8 and latin1 representations of each newline
for encoding in ['utf8', 'latin1']:
try:
try:
encoded = newline.encode(encoding)
except UnicodeEncodeError:
# Some chars cannot be represented in latin1
continue
data = b'Cookie: foo=' + encoded + b'bar'
# parse() wants a native_str, so decode through latin1
# in the same way the real parser does.
headers = HTTPHeaders.parse(
native_str(data.decode('latin1')))
expected = [('Cookie', 'foo=' +
native_str(encoded.decode('latin1')) + 'bar')]
self.assertEqual(
expected, list(headers.get_all()))
except Exception:
gen_log.warning("failed while trying %r in %s",
newline, encoding)
raise
def test_optional_cr(self):
# Both CRLF and LF should be accepted as separators. CR should not be
# part of the data when followed by LF, but it is a normal char
# otherwise (or should bare CR be an error?)
headers = HTTPHeaders.parse(
'CRLF: crlf\r\nLF: lf\nCR: cr\rMore: more\r\n')
self.assertEqual(sorted(headers.get_all()),
[('Cr', 'cr\rMore: more'),
('Crlf', 'crlf'),
('Lf', 'lf'),
])
def test_copy(self):
all_pairs = [('A', '1'), ('A', '2'), ('B', 'c')]
h1 = HTTPHeaders()
for k, v in all_pairs:
h1.add(k, v)
h2 = h1.copy()
h3 = copy.copy(h1)
h4 = copy.deepcopy(h1)
for headers in [h1, h2, h3, h4]:
# All the copies are identical, no matter how they were
# constructed.
self.assertEqual(list(sorted(headers.get_all())), all_pairs)
for headers in [h2, h3, h4]:
# Neither the dict or its member lists are reused.
self.assertIsNot(headers, h1)
self.assertIsNot(headers.get_list('A'), h1.get_list('A'))
class FormatTimestampTest(unittest.TestCase):
# Make sure that all the input types are supported.
@ -264,6 +335,10 @@ class HTTPServerRequestTest(unittest.TestCase):
# more required parameters slip in.
HTTPServerRequest(uri='/')
def test_body_is_a_byte_string(self):
requets = HTTPServerRequest(uri='/')
self.assertIsInstance(requets.body, bytes)
class ParseRequestStartLineTest(unittest.TestCase):
METHOD = "GET"

View file

@ -1,3 +1,4 @@
# flake8: noqa
from __future__ import absolute_import, division, print_function, with_statement
from tornado.test.util import unittest

View file

@ -11,8 +11,9 @@ import threading
import time
from tornado import gen
from tornado.ioloop import IOLoop, TimeoutError
from tornado.ioloop import IOLoop, TimeoutError, PollIOLoop, PeriodicCallback
from tornado.log import app_log
from tornado.platform.select import _Select
from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext
from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog
from tornado.test.util import unittest, skipIfNonUnix, skipOnTravis
@ -23,6 +24,42 @@ except ImportError:
futures = None
class FakeTimeSelect(_Select):
def __init__(self):
self._time = 1000
super(FakeTimeSelect, self).__init__()
def time(self):
return self._time
def sleep(self, t):
self._time += t
def poll(self, timeout):
events = super(FakeTimeSelect, self).poll(0)
if events:
return events
self._time += timeout
return []
class FakeTimeIOLoop(PollIOLoop):
"""IOLoop implementation with a fake and deterministic clock.
The clock advances as needed to trigger timeouts immediately.
For use when testing code that involves the passage of time
and no external dependencies.
"""
def initialize(self):
self.fts = FakeTimeSelect()
super(FakeTimeIOLoop, self).initialize(impl=self.fts,
time_func=self.fts.time)
def sleep(self, t):
"""Simulate a blocking sleep by advancing the clock."""
self.fts.sleep(t)
class TestIOLoop(AsyncTestCase):
@skipOnTravis
def test_add_callback_wakeup(self):
@ -180,10 +217,12 @@ class TestIOLoop(AsyncTestCase):
# t2 should be cancelled by t1, even though it is already scheduled to
# be run before the ioloop even looks at it.
now = self.io_loop.time()
def t1():
calls[0] = True
self.io_loop.remove_timeout(t2_handle)
self.io_loop.add_timeout(now + 0.01, t1)
def t2():
calls[1] = True
t2_handle = self.io_loop.add_timeout(now + 0.02, t2)
@ -252,6 +291,7 @@ class TestIOLoop(AsyncTestCase):
"""The handler callback receives the same fd object it passed in."""
server_sock, port = bind_unused_port()
fds = []
def handle_connection(fd, events):
fds.append(fd)
conn, addr = server_sock.accept()
@ -274,6 +314,7 @@ class TestIOLoop(AsyncTestCase):
def test_mixed_fd_fileobj(self):
server_sock, port = bind_unused_port()
def f(fd, events):
pass
self.io_loop.add_handler(server_sock, f, IOLoop.READ)
@ -288,6 +329,7 @@ class TestIOLoop(AsyncTestCase):
"""Calling start() twice should raise an error, not deadlock."""
returned_from_start = [False]
got_exception = [False]
def callback():
try:
self.io_loop.start()
@ -305,7 +347,7 @@ class TestIOLoop(AsyncTestCase):
# Use a NullContext to keep the exception from being caught by
# AsyncTestCase.
with NullContext():
self.io_loop.add_callback(lambda: 1/0)
self.io_loop.add_callback(lambda: 1 / 0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
@ -316,7 +358,7 @@ class TestIOLoop(AsyncTestCase):
@gen.coroutine
def callback():
self.io_loop.add_callback(self.stop)
1/0
1 / 0
self.io_loop.add_callback(callback)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
@ -324,12 +366,12 @@ class TestIOLoop(AsyncTestCase):
def test_spawn_callback(self):
# An added callback runs in the test's stack_context, so will be
# re-arised in wait().
self.io_loop.add_callback(lambda: 1/0)
self.io_loop.add_callback(lambda: 1 / 0)
with self.assertRaises(ZeroDivisionError):
self.wait()
# A spawned callback is run directly on the IOLoop, so it will be
# logged without stopping the test.
self.io_loop.spawn_callback(lambda: 1/0)
self.io_loop.spawn_callback(lambda: 1 / 0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
@ -344,6 +386,7 @@ class TestIOLoop(AsyncTestCase):
# After reading from one fd, remove the other from the IOLoop.
chunks = []
def handle_read(fd, events):
chunks.append(fd.recv(1024))
if fd is client:
@ -352,7 +395,7 @@ class TestIOLoop(AsyncTestCase):
self.io_loop.remove_handler(client)
self.io_loop.add_handler(client, handle_read, self.io_loop.READ)
self.io_loop.add_handler(server, handle_read, self.io_loop.READ)
self.io_loop.call_later(0.01, self.stop)
self.io_loop.call_later(0.03, self.stop)
self.wait()
# Only one fd was read; the other was cleanly removed.
@ -520,5 +563,47 @@ class TestIOLoopRunSync(unittest.TestCase):
self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01)
class TestPeriodicCallback(unittest.TestCase):
def setUp(self):
self.io_loop = FakeTimeIOLoop()
self.io_loop.make_current()
def tearDown(self):
self.io_loop.close()
def test_basic(self):
calls = []
def cb():
calls.append(self.io_loop.time())
pc = PeriodicCallback(cb, 10000)
pc.start()
self.io_loop.call_later(50, self.io_loop.stop)
self.io_loop.start()
self.assertEqual(calls, [1010, 1020, 1030, 1040, 1050])
def test_overrun(self):
sleep_durations = [9, 9, 10, 11, 20, 20, 35, 35, 0, 0]
expected = [
1010, 1020, 1030, # first 3 calls on schedule
1050, 1070, # next 2 delayed one cycle
1100, 1130, # next 2 delayed 2 cycles
1170, 1210, # next 2 delayed 3 cycles
1220, 1230, # then back on schedule.
]
calls = []
def cb():
calls.append(self.io_loop.time())
if not sleep_durations:
self.io_loop.stop()
return
self.io_loop.sleep(sleep_durations.pop(0))
pc = PeriodicCallback(cb, 10000)
pc.start()
self.io_loop.start()
self.assertEqual(calls, expected)
if __name__ == "__main__":
unittest.main()

View file

@ -7,10 +7,10 @@ from tornado.httputil import HTTPHeaders
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket
from tornado.stack_context import NullContext
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
from tornado.test.util import unittest, skipIfNonUnix
from tornado.test.util import unittest, skipIfNonUnix, refusing_port
from tornado.web import RequestHandler, Application
import lib.certifi
import errno
import logging
import os
@ -51,18 +51,18 @@ class TestIOStreamWebMixin(object):
def test_read_until_close(self):
stream = self._make_client_iostream()
stream.connect(('localhost', self.get_http_port()), callback=self.stop)
stream.connect(('127.0.0.1', self.get_http_port()), callback=self.stop)
self.wait()
stream.write(b"GET / HTTP/1.0\r\n\r\n")
stream.read_until_close(self.stop)
data = self.wait()
self.assertTrue(data.startswith(b"HTTP/1.0 200"))
self.assertTrue(data.startswith(b"HTTP/1.1 200"))
self.assertTrue(data.endswith(b"Hello"))
def test_read_zero_bytes(self):
self.stream = self._make_client_iostream()
self.stream.connect(("localhost", self.get_http_port()),
self.stream.connect(("127.0.0.1", self.get_http_port()),
callback=self.stop)
self.wait()
self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
@ -70,7 +70,7 @@ class TestIOStreamWebMixin(object):
# normal read
self.stream.read_bytes(9, self.stop)
data = self.wait()
self.assertEqual(data, b"HTTP/1.0 ")
self.assertEqual(data, b"HTTP/1.1 ")
# zero bytes
self.stream.read_bytes(0, self.stop)
@ -91,7 +91,7 @@ class TestIOStreamWebMixin(object):
def connected_callback():
connected[0] = True
self.stop()
stream.connect(("localhost", self.get_http_port()),
stream.connect(("127.0.0.1", self.get_http_port()),
callback=connected_callback)
# unlike the previous tests, try to write before the connection
# is complete.
@ -121,11 +121,11 @@ class TestIOStreamWebMixin(object):
"""Basic test of IOStream's ability to return Futures."""
stream = self._make_client_iostream()
connect_result = yield stream.connect(
("localhost", self.get_http_port()))
("127.0.0.1", self.get_http_port()))
self.assertIs(connect_result, stream)
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
first_line = yield stream.read_until(b"\r\n")
self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n")
self.assertEqual(first_line, b"HTTP/1.1 200 OK\r\n")
# callback=None is equivalent to no callback.
header_data = yield stream.read_until(b"\r\n\r\n", callback=None)
headers = HTTPHeaders.parse(header_data.decode('latin1'))
@ -137,7 +137,7 @@ class TestIOStreamWebMixin(object):
@gen_test
def test_future_close_while_reading(self):
stream = self._make_client_iostream()
yield stream.connect(("localhost", self.get_http_port()))
yield stream.connect(("127.0.0.1", self.get_http_port()))
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
with self.assertRaises(StreamClosedError):
yield stream.read_bytes(1024 * 1024)
@ -147,7 +147,7 @@ class TestIOStreamWebMixin(object):
def test_future_read_until_close(self):
# Ensure that the data comes through before the StreamClosedError.
stream = self._make_client_iostream()
yield stream.connect(("localhost", self.get_http_port()))
yield stream.connect(("127.0.0.1", self.get_http_port()))
yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n")
yield stream.read_until(b"\r\n\r\n")
body = yield stream.read_until_close()
@ -217,17 +217,18 @@ class TestIOStreamMixin(object):
# When a connection is refused, the connect callback should not
# be run. (The kqueue IOLoop used to behave differently from the
# epoll IOLoop in this respect)
server_socket, port = bind_unused_port()
server_socket.close()
cleanup_func, port = refusing_port()
self.addCleanup(cleanup_func)
stream = IOStream(socket.socket(), self.io_loop)
self.connect_called = False
def connect_callback():
self.connect_called = True
self.stop()
stream.set_close_callback(self.stop)
# log messages vary by platform and ioloop implementation
with ExpectLog(gen_log, ".*", required=False):
stream.connect(("localhost", port), connect_callback)
stream.connect(("127.0.0.1", port), connect_callback)
self.wait()
self.assertFalse(self.connect_called)
self.assertTrue(isinstance(stream.error, socket.error), stream.error)
@ -248,7 +249,8 @@ class TestIOStreamMixin(object):
# opendns and some ISPs return bogus addresses for nonexistent
# domains instead of the proper error codes).
with ExpectLog(gen_log, "Connect error"):
stream.connect(('an invalid domain', 54321))
stream.connect(('an invalid domain', 54321), callback=self.stop)
self.wait()
self.assertTrue(isinstance(stream.error, socket.gaierror), stream.error)
def test_read_callback_error(self):
@ -308,6 +310,7 @@ class TestIOStreamMixin(object):
def streaming_callback(data):
chunks.append(data)
self.stop()
def close_callback(data):
assert not data, data
closed[0] = True
@ -325,6 +328,31 @@ class TestIOStreamMixin(object):
server.close()
client.close()
def test_streaming_until_close_future(self):
server, client = self.make_iostream_pair()
try:
chunks = []
@gen.coroutine
def client_task():
yield client.read_until_close(streaming_callback=chunks.append)
@gen.coroutine
def server_task():
yield server.write(b"1234")
yield gen.sleep(0.01)
yield server.write(b"5678")
server.close()
@gen.coroutine
def f():
yield [client_task(), server_task()]
self.io_loop.run_sync(f)
self.assertEqual(chunks, [b"1234", b"5678"])
finally:
server.close()
client.close()
def test_delayed_close_callback(self):
# The scenario: Server closes the connection while there is a pending
# read that can be served out of buffered data. The client does not
@ -353,6 +381,7 @@ class TestIOStreamMixin(object):
def test_future_delayed_close_callback(self):
# Same as test_delayed_close_callback, but with the future interface.
server, client = self.make_iostream_pair()
# We can't call make_iostream_pair inside a gen_test function
# because the ioloop is not reentrant.
@gen_test
@ -532,6 +561,7 @@ class TestIOStreamMixin(object):
# and IOStream._maybe_add_error_listener.
server, client = self.make_iostream_pair()
closed = [False]
def close_callback():
closed[0] = True
self.stop()
@ -724,6 +754,26 @@ class TestIOStreamMixin(object):
server.close()
client.close()
def test_flow_control(self):
MB = 1024 * 1024
server, client = self.make_iostream_pair(max_buffer_size=5 * MB)
try:
# Client writes more than the server will accept.
client.write(b"a" * 10 * MB)
# The server pauses while reading.
server.read_bytes(MB, self.stop)
self.wait()
self.io_loop.call_later(0.1, self.stop)
self.wait()
# The client's writes have been blocked; the server can
# continue to read gradually.
for i in range(9):
server.read_bytes(MB, self.stop)
self.wait()
finally:
server.close()
client.close()
class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
def _make_client_iostream(self):
@ -732,7 +782,8 @@ class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
class TestIOStreamWebHTTPS(TestIOStreamWebMixin, AsyncHTTPSTestCase):
def _make_client_iostream(self):
return SSLIOStream(socket.socket(), io_loop=self.io_loop)
return SSLIOStream(socket.socket(), io_loop=self.io_loop,
ssl_options=dict(cert_reqs=ssl.CERT_NONE))
class TestIOStream(TestIOStreamMixin, AsyncTestCase):
@ -752,7 +803,9 @@ class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase):
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
def _make_client_iostream(self, connection, **kwargs):
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
return SSLIOStream(connection, io_loop=self.io_loop,
ssl_options=dict(cert_reqs=ssl.CERT_NONE),
**kwargs)
# This will run some tests that are basically redundant but it's the
@ -820,10 +873,10 @@ class TestIOStreamStartTLS(AsyncTestCase):
recv_line = yield self.client_stream.read_until(b"\r\n")
self.assertEqual(line, recv_line)
def client_start_tls(self, ssl_options=None):
def client_start_tls(self, ssl_options=None, server_hostname=None):
client_stream = self.client_stream
self.client_stream = None
return client_stream.start_tls(False, ssl_options)
return client_stream.start_tls(False, ssl_options, server_hostname)
def server_start_tls(self, ssl_options=None):
server_stream = self.server_stream
@ -842,7 +895,7 @@ class TestIOStreamStartTLS(AsyncTestCase):
yield self.server_send_line(b"250 STARTTLS\r\n")
yield self.client_send_line(b"STARTTLS\r\n")
yield self.server_send_line(b"220 Go ahead\r\n")
client_future = self.client_start_tls()
client_future = self.client_start_tls(dict(cert_reqs=ssl.CERT_NONE))
server_future = self.server_start_tls(_server_ssl_options())
self.client_stream = yield client_future
self.server_stream = yield server_future
@ -853,12 +906,123 @@ class TestIOStreamStartTLS(AsyncTestCase):
@gen_test
def test_handshake_fail(self):
self.server_start_tls(_server_ssl_options())
client_future = self.client_start_tls(
dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where()))
server_future = self.server_start_tls(_server_ssl_options())
# Certificates are verified with the default configuration.
client_future = self.client_start_tls(server_hostname="localhost")
with ExpectLog(gen_log, "SSL Error"):
with self.assertRaises(ssl.SSLError):
yield client_future
with self.assertRaises((ssl.SSLError, socket.error)):
yield server_future
@unittest.skipIf(not hasattr(ssl, 'create_default_context'),
'ssl.create_default_context not present')
@gen_test
def test_check_hostname(self):
# Test that server_hostname parameter to start_tls is being used.
# The check_hostname functionality is only available in python 2.7 and
# up and in python 3.4 and up.
server_future = self.server_start_tls(_server_ssl_options())
client_future = self.client_start_tls(
ssl.create_default_context(),
server_hostname=b'127.0.0.1')
with ExpectLog(gen_log, "SSL Error"):
with self.assertRaises(ssl.SSLError):
yield client_future
with self.assertRaises((ssl.SSLError, socket.error)):
yield server_future
class WaitForHandshakeTest(AsyncTestCase):
@gen.coroutine
def connect_to_server(self, server_cls):
server = client = None
try:
sock, port = bind_unused_port()
server = server_cls(ssl_options=_server_ssl_options())
server.add_socket(sock)
client = SSLIOStream(socket.socket(),
ssl_options=dict(cert_reqs=ssl.CERT_NONE))
yield client.connect(('127.0.0.1', port))
self.assertIsNotNone(client.socket.cipher())
finally:
if server is not None:
server.stop()
if client is not None:
client.close()
@gen_test
def test_wait_for_handshake_callback(self):
test = self
handshake_future = Future()
class TestServer(TCPServer):
def handle_stream(self, stream, address):
# The handshake has not yet completed.
test.assertIsNone(stream.socket.cipher())
self.stream = stream
stream.wait_for_handshake(self.handshake_done)
def handshake_done(self):
# Now the handshake is done and ssl information is available.
test.assertIsNotNone(self.stream.socket.cipher())
handshake_future.set_result(None)
yield self.connect_to_server(TestServer)
yield handshake_future
@gen_test
def test_wait_for_handshake_future(self):
test = self
handshake_future = Future()
class TestServer(TCPServer):
def handle_stream(self, stream, address):
test.assertIsNone(stream.socket.cipher())
test.io_loop.spawn_callback(self.handle_connection, stream)
@gen.coroutine
def handle_connection(self, stream):
yield stream.wait_for_handshake()
handshake_future.set_result(None)
yield self.connect_to_server(TestServer)
yield handshake_future
@gen_test
def test_wait_for_handshake_already_waiting_error(self):
test = self
handshake_future = Future()
class TestServer(TCPServer):
def handle_stream(self, stream, address):
stream.wait_for_handshake(self.handshake_done)
test.assertRaises(RuntimeError, stream.wait_for_handshake)
def handshake_done(self):
handshake_future.set_result(None)
yield self.connect_to_server(TestServer)
yield handshake_future
@gen_test
def test_wait_for_handshake_already_connected(self):
handshake_future = Future()
class TestServer(TCPServer):
def handle_stream(self, stream, address):
self.stream = stream
stream.wait_for_handshake(self.handshake_done)
def handshake_done(self):
self.stream.wait_for_handshake(self.handshake2_done)
def handshake2_done(self):
handshake_future.set_result(None)
yield self.connect_to_server(TestServer)
yield handshake_future
@skipIfNonUnix

View file

@ -41,6 +41,12 @@ class TranslationLoaderTest(unittest.TestCase):
locale = tornado.locale.get("fr_FR")
self.assertTrue(isinstance(locale, tornado.locale.GettextLocale))
self.assertEqual(locale.translate("school"), u("\u00e9cole"))
self.assertEqual(locale.pgettext("law", "right"), u("le droit"))
self.assertEqual(locale.pgettext("good", "right"), u("le bien"))
self.assertEqual(locale.pgettext("organization", "club", "clubs", 1), u("le club"))
self.assertEqual(locale.pgettext("organization", "club", "clubs", 2), u("les clubs"))
self.assertEqual(locale.pgettext("stick", "club", "clubs", 1), u("le b\xe2ton"))
self.assertEqual(locale.pgettext("stick", "club", "clubs", 2), u("les b\xe2tons"))
class LocaleDataTest(unittest.TestCase):

480
tornado/test/locks_test.py Normal file
View file

@ -0,0 +1,480 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from datetime import timedelta
from tornado import gen, locks
from tornado.gen import TimeoutError
from tornado.testing import gen_test, AsyncTestCase
from tornado.test.util import unittest
class ConditionTest(AsyncTestCase):
def setUp(self):
super(ConditionTest, self).setUp()
self.history = []
def record_done(self, future, key):
"""Record the resolution of a Future returned by Condition.wait."""
def callback(_):
if not future.result():
# wait() resolved to False, meaning it timed out.
self.history.append('timeout')
else:
self.history.append(key)
future.add_done_callback(callback)
def test_repr(self):
c = locks.Condition()
self.assertIn('Condition', repr(c))
self.assertNotIn('waiters', repr(c))
c.wait()
self.assertIn('waiters', repr(c))
@gen_test
def test_notify(self):
c = locks.Condition()
self.io_loop.call_later(0.01, c.notify)
yield c.wait()
def test_notify_1(self):
c = locks.Condition()
self.record_done(c.wait(), 'wait1')
self.record_done(c.wait(), 'wait2')
c.notify(1)
self.history.append('notify1')
c.notify(1)
self.history.append('notify2')
self.assertEqual(['wait1', 'notify1', 'wait2', 'notify2'],
self.history)
def test_notify_n(self):
c = locks.Condition()
for i in range(6):
self.record_done(c.wait(), i)
c.notify(3)
# Callbacks execute in the order they were registered.
self.assertEqual(list(range(3)), self.history)
c.notify(1)
self.assertEqual(list(range(4)), self.history)
c.notify(2)
self.assertEqual(list(range(6)), self.history)
def test_notify_all(self):
c = locks.Condition()
for i in range(4):
self.record_done(c.wait(), i)
c.notify_all()
self.history.append('notify_all')
# Callbacks execute in the order they were registered.
self.assertEqual(
list(range(4)) + ['notify_all'],
self.history)
@gen_test
def test_wait_timeout(self):
c = locks.Condition()
wait = c.wait(timedelta(seconds=0.01))
self.io_loop.call_later(0.02, c.notify) # Too late.
yield gen.sleep(0.03)
self.assertFalse((yield wait))
@gen_test
def test_wait_timeout_preempted(self):
c = locks.Condition()
# This fires before the wait times out.
self.io_loop.call_later(0.01, c.notify)
wait = c.wait(timedelta(seconds=0.02))
yield gen.sleep(0.03)
yield wait # No TimeoutError.
@gen_test
def test_notify_n_with_timeout(self):
# Register callbacks 0, 1, 2, and 3. Callback 1 has a timeout.
# Wait for that timeout to expire, then do notify(2) and make
# sure everyone runs. Verifies that a timed-out callback does
# not count against the 'n' argument to notify().
c = locks.Condition()
self.record_done(c.wait(), 0)
self.record_done(c.wait(timedelta(seconds=0.01)), 1)
self.record_done(c.wait(), 2)
self.record_done(c.wait(), 3)
# Wait for callback 1 to time out.
yield gen.sleep(0.02)
self.assertEqual(['timeout'], self.history)
c.notify(2)
yield gen.sleep(0.01)
self.assertEqual(['timeout', 0, 2], self.history)
self.assertEqual(['timeout', 0, 2], self.history)
c.notify()
self.assertEqual(['timeout', 0, 2, 3], self.history)
@gen_test
def test_notify_all_with_timeout(self):
c = locks.Condition()
self.record_done(c.wait(), 0)
self.record_done(c.wait(timedelta(seconds=0.01)), 1)
self.record_done(c.wait(), 2)
# Wait for callback 1 to time out.
yield gen.sleep(0.02)
self.assertEqual(['timeout'], self.history)
c.notify_all()
self.assertEqual(['timeout', 0, 2], self.history)
@gen_test
def test_nested_notify(self):
# Ensure no notifications lost, even if notify() is reentered by a
# waiter calling notify().
c = locks.Condition()
# Three waiters.
futures = [c.wait() for _ in range(3)]
# First and second futures resolved. Second future reenters notify(),
# resolving third future.
futures[1].add_done_callback(lambda _: c.notify())
c.notify(2)
self.assertTrue(all(f.done() for f in futures))
@gen_test
def test_garbage_collection(self):
# Test that timed-out waiters are occasionally cleaned from the queue.
c = locks.Condition()
for _ in range(101):
c.wait(timedelta(seconds=0.01))
future = c.wait()
self.assertEqual(102, len(c._waiters))
# Let first 101 waiters time out, triggering a collection.
yield gen.sleep(0.02)
self.assertEqual(1, len(c._waiters))
# Final waiter is still active.
self.assertFalse(future.done())
c.notify()
self.assertTrue(future.done())
class EventTest(AsyncTestCase):
def test_repr(self):
event = locks.Event()
self.assertTrue('clear' in str(event))
self.assertFalse('set' in str(event))
event.set()
self.assertFalse('clear' in str(event))
self.assertTrue('set' in str(event))
def test_event(self):
e = locks.Event()
future_0 = e.wait()
e.set()
future_1 = e.wait()
e.clear()
future_2 = e.wait()
self.assertTrue(future_0.done())
self.assertTrue(future_1.done())
self.assertFalse(future_2.done())
@gen_test
def test_event_timeout(self):
e = locks.Event()
with self.assertRaises(TimeoutError):
yield e.wait(timedelta(seconds=0.01))
# After a timed-out waiter, normal operation works.
self.io_loop.add_timeout(timedelta(seconds=0.01), e.set)
yield e.wait(timedelta(seconds=1))
def test_event_set_multiple(self):
e = locks.Event()
e.set()
e.set()
self.assertTrue(e.is_set())
def test_event_wait_clear(self):
e = locks.Event()
f0 = e.wait()
e.clear()
f1 = e.wait()
e.set()
self.assertTrue(f0.done())
self.assertTrue(f1.done())
class SemaphoreTest(AsyncTestCase):
def test_negative_value(self):
self.assertRaises(ValueError, locks.Semaphore, value=-1)
def test_repr(self):
sem = locks.Semaphore()
self.assertIn('Semaphore', repr(sem))
self.assertIn('unlocked,value:1', repr(sem))
sem.acquire()
self.assertIn('locked', repr(sem))
self.assertNotIn('waiters', repr(sem))
sem.acquire()
self.assertIn('waiters', repr(sem))
def test_acquire(self):
sem = locks.Semaphore()
f0 = sem.acquire()
self.assertTrue(f0.done())
# Wait for release().
f1 = sem.acquire()
self.assertFalse(f1.done())
f2 = sem.acquire()
sem.release()
self.assertTrue(f1.done())
self.assertFalse(f2.done())
sem.release()
self.assertTrue(f2.done())
sem.release()
# Now acquire() is instant.
self.assertTrue(sem.acquire().done())
self.assertEqual(0, len(sem._waiters))
@gen_test
def test_acquire_timeout(self):
sem = locks.Semaphore(2)
yield sem.acquire()
yield sem.acquire()
acquire = sem.acquire(timedelta(seconds=0.01))
self.io_loop.call_later(0.02, sem.release) # Too late.
yield gen.sleep(0.3)
with self.assertRaises(gen.TimeoutError):
yield acquire
sem.acquire()
f = sem.acquire()
self.assertFalse(f.done())
sem.release()
self.assertTrue(f.done())
@gen_test
def test_acquire_timeout_preempted(self):
sem = locks.Semaphore(1)
yield sem.acquire()
# This fires before the wait times out.
self.io_loop.call_later(0.01, sem.release)
acquire = sem.acquire(timedelta(seconds=0.02))
yield gen.sleep(0.03)
yield acquire # No TimeoutError.
def test_release_unacquired(self):
# Unbounded releases are allowed, and increment the semaphore's value.
sem = locks.Semaphore()
sem.release()
sem.release()
# Now the counter is 3. We can acquire three times before blocking.
self.assertTrue(sem.acquire().done())
self.assertTrue(sem.acquire().done())
self.assertTrue(sem.acquire().done())
self.assertFalse(sem.acquire().done())
@gen_test
def test_garbage_collection(self):
# Test that timed-out waiters are occasionally cleaned from the queue.
sem = locks.Semaphore(value=0)
futures = [sem.acquire(timedelta(seconds=0.01)) for _ in range(101)]
future = sem.acquire()
self.assertEqual(102, len(sem._waiters))
# Let first 101 waiters time out, triggering a collection.
yield gen.sleep(0.02)
self.assertEqual(1, len(sem._waiters))
# Final waiter is still active.
self.assertFalse(future.done())
sem.release()
self.assertTrue(future.done())
# Prevent "Future exception was never retrieved" messages.
for future in futures:
self.assertRaises(TimeoutError, future.result)
class SemaphoreContextManagerTest(AsyncTestCase):
@gen_test
def test_context_manager(self):
sem = locks.Semaphore()
with (yield sem.acquire()) as yielded:
self.assertTrue(yielded is None)
# Semaphore was released and can be acquired again.
self.assertTrue(sem.acquire().done())
@gen_test
def test_context_manager_exception(self):
sem = locks.Semaphore()
with self.assertRaises(ZeroDivisionError):
with (yield sem.acquire()):
1 / 0
# Semaphore was released and can be acquired again.
self.assertTrue(sem.acquire().done())
@gen_test
def test_context_manager_timeout(self):
sem = locks.Semaphore()
with (yield sem.acquire(timedelta(seconds=0.01))):
pass
# Semaphore was released and can be acquired again.
self.assertTrue(sem.acquire().done())
@gen_test
def test_context_manager_timeout_error(self):
sem = locks.Semaphore(value=0)
with self.assertRaises(gen.TimeoutError):
with (yield sem.acquire(timedelta(seconds=0.01))):
pass
# Counter is still 0.
self.assertFalse(sem.acquire().done())
@gen_test
def test_context_manager_contended(self):
sem = locks.Semaphore()
history = []
@gen.coroutine
def f(index):
with (yield sem.acquire()):
history.append('acquired %d' % index)
yield gen.sleep(0.01)
history.append('release %d' % index)
yield [f(i) for i in range(2)]
expected_history = []
for i in range(2):
expected_history.extend(['acquired %d' % i, 'release %d' % i])
self.assertEqual(expected_history, history)
@gen_test
def test_yield_sem(self):
# Ensure we catch a "with (yield sem)", which should be
# "with (yield sem.acquire())".
with self.assertRaises(gen.BadYieldError):
with (yield locks.Semaphore()):
pass
def test_context_manager_misuse(self):
# Ensure we catch a "with sem", which should be
# "with (yield sem.acquire())".
with self.assertRaises(RuntimeError):
with locks.Semaphore():
pass
class BoundedSemaphoreTest(AsyncTestCase):
def test_release_unacquired(self):
sem = locks.BoundedSemaphore()
self.assertRaises(ValueError, sem.release)
# Value is 0.
sem.acquire()
# Block on acquire().
future = sem.acquire()
self.assertFalse(future.done())
sem.release()
self.assertTrue(future.done())
# Value is 1.
sem.release()
self.assertRaises(ValueError, sem.release)
class LockTests(AsyncTestCase):
def test_repr(self):
lock = locks.Lock()
# No errors.
repr(lock)
lock.acquire()
repr(lock)
def test_acquire_release(self):
lock = locks.Lock()
self.assertTrue(lock.acquire().done())
future = lock.acquire()
self.assertFalse(future.done())
lock.release()
self.assertTrue(future.done())
@gen_test
def test_acquire_fifo(self):
lock = locks.Lock()
self.assertTrue(lock.acquire().done())
N = 5
history = []
@gen.coroutine
def f(idx):
with (yield lock.acquire()):
history.append(idx)
futures = [f(i) for i in range(N)]
self.assertFalse(any(future.done() for future in futures))
lock.release()
yield futures
self.assertEqual(list(range(N)), history)
@gen_test
def test_acquire_timeout(self):
lock = locks.Lock()
lock.acquire()
with self.assertRaises(gen.TimeoutError):
yield lock.acquire(timeout=timedelta(seconds=0.01))
# Still locked.
self.assertFalse(lock.acquire().done())
def test_multi_release(self):
lock = locks.Lock()
self.assertRaises(RuntimeError, lock.release)
lock.acquire()
lock.release()
self.assertRaises(RuntimeError, lock.release)
@gen_test
def test_yield_lock(self):
# Ensure we catch a "with (yield lock)", which should be
# "with (yield lock.acquire())".
with self.assertRaises(gen.BadYieldError):
with (yield locks.Lock()):
pass
def test_context_manager_misuse(self):
# Ensure we catch a "with lock", which should be
# "with (yield lock.acquire())".
with self.assertRaises(RuntimeError):
with locks.Lock():
pass
if __name__ == '__main__':
unittest.main()

View file

@ -67,10 +67,12 @@ class _ResolverErrorTestMixin(object):
yield self.resolver.resolve('an invalid domain', 80,
socket.AF_UNSPEC)
def _failing_getaddrinfo(*args):
"""Dummy implementation of getaddrinfo for use in mocks"""
raise socket.gaierror("mock: lookup failed")
@skipIfNoNetwork
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):

View file

@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop
from tornado.log import gen_log
from tornado.process import fork_processes, task_id, Subprocess
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase
from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase, gen_test
from tornado.test.util import unittest, skipIfNonUnix
from tornado.web import RequestHandler, Application
@ -85,7 +85,7 @@ class ProcessTest(unittest.TestCase):
self.assertEqual(id, task_id())
server = HTTPServer(self.get_app())
server.add_sockets([sock])
IOLoop.instance().start()
IOLoop.current().start()
elif id == 2:
self.assertEqual(id, task_id())
sock.close()
@ -200,6 +200,16 @@ class SubprocessTest(AsyncTestCase):
self.assertEqual(ret, 0)
self.assertEqual(subproc.returncode, ret)
@gen_test
def test_sigchild_future(self):
skip_if_twisted()
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c', 'pass'])
ret = yield subproc.wait_for_exit()
self.assertEqual(ret, 0)
self.assertEqual(subproc.returncode, ret)
def test_sigchild_signal(self):
skip_if_twisted()
Subprocess.initialize(io_loop=self.io_loop)
@ -212,3 +222,22 @@ class SubprocessTest(AsyncTestCase):
ret = self.wait()
self.assertEqual(subproc.returncode, ret)
self.assertEqual(ret, -signal.SIGTERM)
@gen_test
def test_wait_for_exit_raise(self):
skip_if_twisted()
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)'])
with self.assertRaises(subprocess.CalledProcessError) as cm:
yield subproc.wait_for_exit()
self.assertEqual(cm.exception.returncode, 1)
@gen_test
def test_wait_for_exit_raise_disabled(self):
skip_if_twisted()
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)'])
ret = yield subproc.wait_for_exit(raise_error=False)
self.assertEqual(ret, 1)

403
tornado/test/queues_test.py Normal file
View file

@ -0,0 +1,403 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from datetime import timedelta
from random import random
from tornado import gen, queues
from tornado.gen import TimeoutError
from tornado.testing import gen_test, AsyncTestCase
from tornado.test.util import unittest
class QueueBasicTest(AsyncTestCase):
def test_repr_and_str(self):
q = queues.Queue(maxsize=1)
self.assertIn(hex(id(q)), repr(q))
self.assertNotIn(hex(id(q)), str(q))
q.get()
for q_str in repr(q), str(q):
self.assertTrue(q_str.startswith('<Queue'))
self.assertIn('maxsize=1', q_str)
self.assertIn('getters[1]', q_str)
self.assertNotIn('putters', q_str)
self.assertNotIn('tasks', q_str)
q.put(None)
q.put(None)
# Now the queue is full, this putter blocks.
q.put(None)
for q_str in repr(q), str(q):
self.assertNotIn('getters', q_str)
self.assertIn('putters[1]', q_str)
self.assertIn('tasks=2', q_str)
def test_order(self):
q = queues.Queue()
for i in [1, 3, 2]:
q.put_nowait(i)
items = [q.get_nowait() for _ in range(3)]
self.assertEqual([1, 3, 2], items)
@gen_test
def test_maxsize(self):
self.assertRaises(TypeError, queues.Queue, maxsize=None)
self.assertRaises(ValueError, queues.Queue, maxsize=-1)
q = queues.Queue(maxsize=2)
self.assertTrue(q.empty())
self.assertFalse(q.full())
self.assertEqual(2, q.maxsize)
self.assertTrue(q.put(0).done())
self.assertTrue(q.put(1).done())
self.assertFalse(q.empty())
self.assertTrue(q.full())
put2 = q.put(2)
self.assertFalse(put2.done())
self.assertEqual(0, (yield q.get())) # Make room.
self.assertTrue(put2.done())
self.assertFalse(q.empty())
self.assertTrue(q.full())
class QueueGetTest(AsyncTestCase):
@gen_test
def test_blocking_get(self):
q = queues.Queue()
q.put_nowait(0)
self.assertEqual(0, (yield q.get()))
def test_nonblocking_get(self):
q = queues.Queue()
q.put_nowait(0)
self.assertEqual(0, q.get_nowait())
def test_nonblocking_get_exception(self):
q = queues.Queue()
self.assertRaises(queues.QueueEmpty, q.get_nowait)
@gen_test
def test_get_with_putters(self):
q = queues.Queue(1)
q.put_nowait(0)
put = q.put(1)
self.assertEqual(0, (yield q.get()))
self.assertIsNone((yield put))
@gen_test
def test_blocking_get_wait(self):
q = queues.Queue()
q.put(0)
self.io_loop.call_later(0.01, q.put, 1)
self.io_loop.call_later(0.02, q.put, 2)
self.assertEqual(0, (yield q.get(timeout=timedelta(seconds=1))))
self.assertEqual(1, (yield q.get(timeout=timedelta(seconds=1))))
@gen_test
def test_get_timeout(self):
q = queues.Queue()
get_timeout = q.get(timeout=timedelta(seconds=0.01))
get = q.get()
with self.assertRaises(TimeoutError):
yield get_timeout
q.put_nowait(0)
self.assertEqual(0, (yield get))
@gen_test
def test_get_timeout_preempted(self):
q = queues.Queue()
get = q.get(timeout=timedelta(seconds=0.01))
q.put(0)
yield gen.sleep(0.02)
self.assertEqual(0, (yield get))
@gen_test
def test_get_clears_timed_out_putters(self):
q = queues.Queue(1)
# First putter succeeds, remainder block.
putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)]
put = q.put(10)
self.assertEqual(10, len(q._putters))
yield gen.sleep(0.02)
self.assertEqual(10, len(q._putters))
self.assertFalse(put.done()) # Final waiter is still active.
q.put(11)
self.assertEqual(0, (yield q.get())) # get() clears the waiters.
self.assertEqual(1, len(q._putters))
for putter in putters[1:]:
self.assertRaises(TimeoutError, putter.result)
@gen_test
def test_get_clears_timed_out_getters(self):
q = queues.Queue()
getters = [q.get(timedelta(seconds=0.01)) for _ in range(10)]
get = q.get()
self.assertEqual(11, len(q._getters))
yield gen.sleep(0.02)
self.assertEqual(11, len(q._getters))
self.assertFalse(get.done()) # Final waiter is still active.
q.get() # get() clears the waiters.
self.assertEqual(2, len(q._getters))
for getter in getters:
self.assertRaises(TimeoutError, getter.result)
class QueuePutTest(AsyncTestCase):
@gen_test
def test_blocking_put(self):
q = queues.Queue()
q.put(0)
self.assertEqual(0, q.get_nowait())
def test_nonblocking_put_exception(self):
q = queues.Queue(1)
q.put(0)
self.assertRaises(queues.QueueFull, q.put_nowait, 1)
@gen_test
def test_put_with_getters(self):
q = queues.Queue()
get0 = q.get()
get1 = q.get()
yield q.put(0)
self.assertEqual(0, (yield get0))
yield q.put(1)
self.assertEqual(1, (yield get1))
@gen_test
def test_nonblocking_put_with_getters(self):
q = queues.Queue()
get0 = q.get()
get1 = q.get()
q.put_nowait(0)
# put_nowait does *not* immediately unblock getters.
yield gen.moment
self.assertEqual(0, (yield get0))
q.put_nowait(1)
yield gen.moment
self.assertEqual(1, (yield get1))
@gen_test
def test_blocking_put_wait(self):
q = queues.Queue(1)
q.put_nowait(0)
self.io_loop.call_later(0.01, q.get)
self.io_loop.call_later(0.02, q.get)
futures = [q.put(0), q.put(1)]
self.assertFalse(any(f.done() for f in futures))
yield futures
@gen_test
def test_put_timeout(self):
q = queues.Queue(1)
q.put_nowait(0) # Now it's full.
put_timeout = q.put(1, timeout=timedelta(seconds=0.01))
put = q.put(2)
with self.assertRaises(TimeoutError):
yield put_timeout
self.assertEqual(0, q.get_nowait())
# 1 was never put in the queue.
self.assertEqual(2, (yield q.get()))
# Final get() unblocked this putter.
yield put
@gen_test
def test_put_timeout_preempted(self):
q = queues.Queue(1)
q.put_nowait(0)
put = q.put(1, timeout=timedelta(seconds=0.01))
q.get()
yield gen.sleep(0.02)
yield put # No TimeoutError.
@gen_test
def test_put_clears_timed_out_putters(self):
q = queues.Queue(1)
# First putter succeeds, remainder block.
putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)]
put = q.put(10)
self.assertEqual(10, len(q._putters))
yield gen.sleep(0.02)
self.assertEqual(10, len(q._putters))
self.assertFalse(put.done()) # Final waiter is still active.
q.put(11) # put() clears the waiters.
self.assertEqual(2, len(q._putters))
for putter in putters[1:]:
self.assertRaises(TimeoutError, putter.result)
@gen_test
def test_put_clears_timed_out_getters(self):
q = queues.Queue()
getters = [q.get(timedelta(seconds=0.01)) for _ in range(10)]
get = q.get()
q.get()
self.assertEqual(12, len(q._getters))
yield gen.sleep(0.02)
self.assertEqual(12, len(q._getters))
self.assertFalse(get.done()) # Final waiters still active.
q.put(0) # put() clears the waiters.
self.assertEqual(1, len(q._getters))
self.assertEqual(0, (yield get))
for getter in getters:
self.assertRaises(TimeoutError, getter.result)
@gen_test
def test_float_maxsize(self):
# Non-int maxsize must round down: http://bugs.python.org/issue21723
q = queues.Queue(maxsize=1.3)
self.assertTrue(q.empty())
self.assertFalse(q.full())
q.put_nowait(0)
q.put_nowait(1)
self.assertFalse(q.empty())
self.assertTrue(q.full())
self.assertRaises(queues.QueueFull, q.put_nowait, 2)
self.assertEqual(0, q.get_nowait())
self.assertFalse(q.empty())
self.assertFalse(q.full())
yield q.put(2)
put = q.put(3)
self.assertFalse(put.done())
self.assertEqual(1, (yield q.get()))
yield put
self.assertTrue(q.full())
class QueueJoinTest(AsyncTestCase):
queue_class = queues.Queue
def test_task_done_underflow(self):
q = self.queue_class()
self.assertRaises(ValueError, q.task_done)
@gen_test
def test_task_done(self):
q = self.queue_class()
for i in range(100):
q.put_nowait(i)
self.accumulator = 0
@gen.coroutine
def worker():
while True:
item = yield q.get()
self.accumulator += item
q.task_done()
yield gen.sleep(random() * 0.01)
# Two coroutines share work.
worker()
worker()
yield q.join()
self.assertEqual(sum(range(100)), self.accumulator)
@gen_test
def test_task_done_delay(self):
# Verify it is task_done(), not get(), that unblocks join().
q = self.queue_class()
q.put_nowait(0)
join = q.join()
self.assertFalse(join.done())
yield q.get()
self.assertFalse(join.done())
yield gen.moment
self.assertFalse(join.done())
q.task_done()
self.assertTrue(join.done())
@gen_test
def test_join_empty_queue(self):
q = self.queue_class()
yield q.join()
yield q.join()
@gen_test
def test_join_timeout(self):
q = self.queue_class()
q.put(0)
with self.assertRaises(TimeoutError):
yield q.join(timeout=timedelta(seconds=0.01))
class PriorityQueueJoinTest(QueueJoinTest):
queue_class = queues.PriorityQueue
@gen_test
def test_order(self):
q = self.queue_class(maxsize=2)
q.put_nowait((1, 'a'))
q.put_nowait((0, 'b'))
self.assertTrue(q.full())
q.put((3, 'c'))
q.put((2, 'd'))
self.assertEqual((0, 'b'), q.get_nowait())
self.assertEqual((1, 'a'), (yield q.get()))
self.assertEqual((2, 'd'), q.get_nowait())
self.assertEqual((3, 'c'), (yield q.get()))
self.assertTrue(q.empty())
class LifoQueueJoinTest(QueueJoinTest):
queue_class = queues.LifoQueue
@gen_test
def test_order(self):
q = self.queue_class(maxsize=2)
q.put_nowait(1)
q.put_nowait(0)
self.assertTrue(q.full())
q.put(3)
q.put(2)
self.assertEqual(3, q.get_nowait())
self.assertEqual(2, (yield q.get()))
self.assertEqual(0, q.get_nowait())
self.assertEqual(1, (yield q.get()))
self.assertTrue(q.empty())
class ProducerConsumerTest(AsyncTestCase):
@gen_test
def test_producer_consumer(self):
q = queues.Queue(maxsize=3)
history = []
# We don't yield between get() and task_done(), so get() must wait for
# the next tick. Otherwise we'd immediately call task_done and unblock
# join() before q.put() resumes, and we'd only process the first four
# items.
@gen.coroutine
def consumer():
while True:
history.append((yield q.get()))
q.task_done()
@gen.coroutine
def producer():
for item in range(10):
yield q.put(item)
consumer()
yield producer()
yield q.join()
self.assertEqual(list(range(10)), history)
if __name__ == '__main__':
unittest.main()

View file

@ -8,6 +8,7 @@ import operator
import textwrap
import sys
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.netutil import Resolver
from tornado.options import define, options, add_parse_callback
@ -22,6 +23,7 @@ TEST_MODULES = [
'tornado.httputil.doctests',
'tornado.iostream.doctests',
'tornado.util.doctests',
'tornado.test.asyncio_test',
'tornado.test.auth_test',
'tornado.test.concurrent_test',
'tornado.test.curl_httpclient_test',
@ -34,13 +36,16 @@ TEST_MODULES = [
'tornado.test.ioloop_test',
'tornado.test.iostream_test',
'tornado.test.locale_test',
'tornado.test.locks_test',
'tornado.test.netutil_test',
'tornado.test.log_test',
'tornado.test.options_test',
'tornado.test.process_test',
'tornado.test.queues_test',
'tornado.test.simple_httpclient_test',
'tornado.test.stack_context_test',
'tornado.test.tcpclient_test',
'tornado.test.tcpserver_test',
'tornado.test.template_test',
'tornado.test.testing_test',
'tornado.test.twisted_test',
@ -67,6 +72,21 @@ class TornadoTextTestRunner(unittest.TextTestRunner):
return result
class LogCounter(logging.Filter):
"""Counts the number of WARNING or higher log records."""
def __init__(self, *args, **kwargs):
# Can't use super() because logging.Filter is an old-style class in py26
logging.Filter.__init__(self, *args, **kwargs)
self.warning_count = self.error_count = 0
def filter(self, record):
if record.levelno >= logging.ERROR:
self.error_count += 1
elif record.levelno >= logging.WARNING:
self.warning_count += 1
return True
def main():
# The -W command-line option does not work in a virtualenv with
# python 3 (as of virtualenv 1.7), so configure warnings
@ -92,12 +112,21 @@ def main():
# 2.7 and 3.2
warnings.filterwarnings("ignore", category=DeprecationWarning,
message="Please use assert.* instead")
# unittest2 0.6 on py26 reports these as PendingDeprecationWarnings
# instead of DeprecationWarnings.
warnings.filterwarnings("ignore", category=PendingDeprecationWarning,
message="Please use assert.* instead")
# Twisted 15.0.0 triggers some warnings on py3 with -bb.
warnings.filterwarnings("ignore", category=BytesWarning,
module=r"twisted\..*")
logging.getLogger("tornado.access").setLevel(logging.CRITICAL)
define('httpclient', type=str, default=None,
callback=lambda s: AsyncHTTPClient.configure(
s, defaults=dict(allow_ipv6=False)))
define('httpserver', type=str, default=None,
callback=HTTPServer.configure)
define('ioloop', type=str, default=None)
define('ioloop_time_monotonic', default=False)
define('resolver', type=str, default=None,
@ -121,6 +150,10 @@ def main():
IOLoop.configure(options.ioloop, **kwargs)
add_parse_callback(configure_ioloop)
log_counter = LogCounter()
add_parse_callback(
lambda: logging.getLogger().handlers[0].addFilter(log_counter))
import tornado.testing
kwargs = {}
if sys.version_info >= (3, 2):
@ -131,7 +164,16 @@ def main():
# detail. http://bugs.python.org/issue15626
kwargs['warnings'] = False
kwargs['testRunner'] = TornadoTextTestRunner
try:
tornado.testing.main(**kwargs)
finally:
# The tests should run clean; consider it a failure if they logged
# any warnings or errors. We'd like to ban info logs too, but
# we can't count them cleanly due to interactions with LogTrapTestCase.
if log_counter.warning_count > 0 or log_counter.error_count > 0:
logging.error("logged %d warnings and %d errors",
log_counter.warning_count, log_counter.error_count)
sys.exit(1)
if __name__ == '__main__':
main()

View file

@ -8,19 +8,20 @@ import logging
import os
import re
import socket
import ssl
import sys
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httputil import HTTPHeaders
from tornado.httputil import HTTPHeaders, ResponseStartLine
from tornado.ioloop import IOLoop
from tornado.log import gen_log
from tornado.netutil import Resolver, bind_sockets
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler
from tornado.test import httpclient_test
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
from tornado.test.util import skipOnTravis, skipIfNoIPv6
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port, unittest
from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body
@ -97,15 +98,18 @@ class HostEchoHandler(RequestHandler):
class NoContentLengthHandler(RequestHandler):
@gen.coroutine
@asynchronous
def get(self):
if self.request.version.startswith('HTTP/1'):
# Emulate the old HTTP/1.0 behavior of returning a body with no
# content-length. Tornado handles content-length at the framework
# level so we have to go around it.
stream = self.request.connection.stream
yield stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
stream = self.request.connection.detach()
stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
b"hello")
stream.close()
else:
self.finish('HTTP/1 required')
class EchoPostHandler(RequestHandler):
@ -191,9 +195,6 @@ class SimpleHTTPClientTestMixin(object):
response = self.wait()
response.rethrow()
def test_default_certificates_exist(self):
open(_default_ca_certs()).close()
def test_gzip(self):
# All the tests in this file should be using gzip, but this test
# ensures that it is in fact getting compressed.
@ -235,9 +236,16 @@ class SimpleHTTPClientTestMixin(object):
@skipOnTravis
def test_request_timeout(self):
response = self.fetch('/trigger?wake=false', request_timeout=0.1)
timeout = 0.1
timeout_min, timeout_max = 0.099, 0.15
if os.name == 'nt':
timeout = 0.5
timeout_min, timeout_max = 0.4, 0.6
response = self.fetch('/trigger?wake=false', request_timeout=timeout)
self.assertEqual(response.code, 599)
self.assertTrue(0.099 < response.request_time < 0.15, response.request_time)
self.assertTrue(timeout_min < response.request_time < timeout_max,
response.request_time)
self.assertEqual(str(response.error), "HTTP 599: Timeout")
# trigger the hanging request to let it clean up after itself
self.triggers.popleft()()
@ -315,10 +323,10 @@ class SimpleHTTPClientTestMixin(object):
self.assertTrue(host_re.match(response.body), response.body)
def test_connection_refused(self):
server_socket, port = bind_unused_port()
server_socket.close()
cleanup_func, port = refusing_port()
self.addCleanup(cleanup_func)
with ExpectLog(gen_log, ".*", required=False):
self.http_client.fetch("http://localhost:%d/" % port, self.stop)
self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
response = self.wait()
self.assertEqual(599, response.code)
@ -352,6 +360,9 @@ class SimpleHTTPClientTestMixin(object):
def test_no_content_length(self):
response = self.fetch("/no_content_length")
if response.body == b"HTTP/1 required":
self.skipTest("requires HTTP/1.x")
else:
self.assertEquals(b"hello", response.body)
def sync_body_producer(self, write):
@ -425,6 +436,33 @@ class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
defaults=dict(validate_cert=False),
**kwargs)
def test_ssl_options(self):
resp = self.fetch("/hello", ssl_options={})
self.assertEqual(resp.body, b"Hello world!")
@unittest.skipIf(not hasattr(ssl, 'SSLContext'),
'ssl.SSLContext not present')
def test_ssl_context(self):
resp = self.fetch("/hello",
ssl_options=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
self.assertEqual(resp.body, b"Hello world!")
def test_ssl_options_handshake_fail(self):
with ExpectLog(gen_log, "SSL Error|Uncaught exception",
required=False):
resp = self.fetch(
"/hello", ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED))
self.assertRaises(ssl.SSLError, resp.rethrow)
@unittest.skipIf(not hasattr(ssl, 'SSLContext'),
'ssl.SSLContext not present')
def test_ssl_context_handshake_fail(self):
with ExpectLog(gen_log, "SSL Error|Uncaught exception"):
ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
ctx.verify_mode = ssl.CERT_REQUIRED
resp = self.fetch("/hello", ssl_options=ctx)
self.assertRaises(ssl.SSLError, resp.rethrow)
class CreateAsyncHTTPClientTestCase(AsyncTestCase):
def setUp(self):
@ -460,6 +498,12 @@ class CreateAsyncHTTPClientTestCase(AsyncTestCase):
class HTTP100ContinueTestCase(AsyncHTTPTestCase):
def respond_100(self, request):
self.http1 = request.version.startswith('HTTP/1.')
if not self.http1:
request.connection.write_headers(ResponseStartLine('', 200, 'OK'),
HTTPHeaders())
request.connection.finish()
return
self.request = request
self.request.connection.stream.write(
b"HTTP/1.1 100 CONTINUE\r\n\r\n",
@ -476,11 +520,20 @@ class HTTP100ContinueTestCase(AsyncHTTPTestCase):
def test_100_continue(self):
res = self.fetch('/')
if not self.http1:
self.skipTest("requires HTTP/1.x")
self.assertEqual(res.body, b'A')
class HTTP204NoContentTestCase(AsyncHTTPTestCase):
def respond_204(self, request):
self.http1 = request.version.startswith('HTTP/1.')
if not self.http1:
# Close the request cleanly in HTTP/2; it will be skipped anyway.
request.connection.write_headers(ResponseStartLine('', 200, 'OK'),
HTTPHeaders())
request.connection.finish()
return
# A 204 response never has a body, even if doesn't have a content-length
# (which would otherwise mean read-until-close). Tornado always
# sends a content-length, so we simulate here a server that sends
@ -488,14 +541,18 @@ class HTTP204NoContentTestCase(AsyncHTTPTestCase):
#
# Tests of a 204 response with a Content-Length header are included
# in SimpleHTTPClientTestMixin.
request.connection.stream.write(
stream = request.connection.detach()
stream.write(
b"HTTP/1.1 204 No content\r\n\r\n")
stream.close()
def get_app(self):
return self.respond_204
def test_204_no_content(self):
resp = self.fetch('/')
if not self.http1:
self.skipTest("requires HTTP/1.x")
self.assertEqual(resp.code, 204)
self.assertEqual(resp.body, b'')
@ -574,3 +631,49 @@ class MaxHeaderSizeTest(AsyncHTTPTestCase):
with ExpectLog(gen_log, "Unsatisfiable read"):
response = self.fetch('/large')
self.assertEqual(response.code, 599)
class MaxBodySizeTest(AsyncHTTPTestCase):
def get_app(self):
class SmallBody(RequestHandler):
def get(self):
self.write("a"*1024*64)
class LargeBody(RequestHandler):
def get(self):
self.write("a"*1024*100)
return Application([('/small', SmallBody),
('/large', LargeBody)])
def get_http_client(self):
return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_body_size=1024*64)
def test_small_body(self):
response = self.fetch('/small')
response.rethrow()
self.assertEqual(response.body, b'a'*1024*64)
def test_large_body(self):
with ExpectLog(gen_log, "Malformed HTTP message from None: Content-Length too long"):
response = self.fetch('/large')
self.assertEqual(response.code, 599)
class MaxBufferSizeTest(AsyncHTTPTestCase):
def get_app(self):
class LargeBody(RequestHandler):
def get(self):
self.write("a"*1024*100)
return Application([('/large', LargeBody)])
def get_http_client(self):
# 100KB body with 64KB buffer
return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_body_size=1024*100, max_buffer_size=1024*64)
def test_large_body(self):
response = self.fetch('/large')
response.rethrow()
self.assertEqual(response.body, b'a'*1024*100)

View file

@ -24,8 +24,8 @@ from tornado.concurrent import Future
from tornado.netutil import bind_sockets, Resolver
from tornado.tcpclient import TCPClient, _Connector
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
from tornado.test.util import skipIfNoIPv6, unittest
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import skipIfNoIPv6, unittest, refusing_port
# Fake address families for testing. Used in place of AF_INET
# and AF_INET6 because some installations do not have AF_INET6.
@ -120,8 +120,8 @@ class TCPClientTest(AsyncTestCase):
@gen_test
def test_refused_ipv4(self):
sock, port = bind_unused_port()
sock.close()
cleanup_func, port = refusing_port()
self.addCleanup(cleanup_func)
with self.assertRaises(IOError):
yield self.client.connect('127.0.0.1', port)

View file

@ -0,0 +1,38 @@
import socket
from tornado import gen
from tornado.iostream import IOStream
from tornado.log import app_log
from tornado.stack_context import NullContext
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
class TCPServerTest(AsyncTestCase):
@gen_test
def test_handle_stream_coroutine_logging(self):
# handle_stream may be a coroutine and any exception in its
# Future will be logged.
class TestServer(TCPServer):
@gen.coroutine
def handle_stream(self, stream, address):
yield gen.moment
stream.close()
1/0
server = client = None
try:
sock, port = bind_unused_port()
with NullContext():
server = TestServer()
server.add_socket(sock)
client = IOStream(socket.socket())
with ExpectLog(app_log, "Exception in callback"):
yield client.connect(('localhost', port))
yield client.read_until_close()
yield gen.moment
finally:
if server is not None:
server.stop()
if client is not None:
client.close()

View file

@ -19,15 +19,18 @@ Unittest for the twisted-style reactor.
from __future__ import absolute_import, division, print_function, with_statement
import logging
import os
import shutil
import signal
import sys
import tempfile
import threading
import warnings
try:
import fcntl
from twisted.internet.defer import Deferred
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
from twisted.internet.interfaces import IReadDescriptor, IWriteDescriptor
from twisted.internet.protocol import Protocol
from twisted.python import log
@ -40,10 +43,12 @@ except ImportError:
# The core of Twisted 12.3.0 is available on python 3, but twisted.web is not
# so test for it separately.
try:
from twisted.web.client import Agent
from twisted.web.client import Agent, readBody
from twisted.web.resource import Resource
from twisted.web.server import Site
have_twisted_web = True
# As of Twisted 15.0.0, twisted.web is present but fails our
# tests due to internal str/bytes errors.
have_twisted_web = sys.version_info < (3,)
except ImportError:
have_twisted_web = False
@ -52,6 +57,8 @@ try:
except ImportError:
import _thread as thread # py3
from tornado.escape import utf8
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
@ -65,6 +72,9 @@ from tornado.web import RequestHandler, Application
skipIfNoTwisted = unittest.skipUnless(have_twisted,
"twisted module not present")
skipIfNoSingleDispatch = unittest.skipIf(
gen.singledispatch is None, "singledispatch module not present")
def save_signal_handlers():
saved = {}
@ -407,7 +417,7 @@ class CompatibilityTests(unittest.TestCase):
# http://twistedmatrix.com/documents/current/web/howto/client.html
chunks = []
client = Agent(self.reactor)
d = client.request('GET', url)
d = client.request(b'GET', utf8(url))
class Accumulator(Protocol):
def __init__(self, finished):
@ -425,37 +435,98 @@ class CompatibilityTests(unittest.TestCase):
return finished
d.addCallback(callback)
def shutdown(ignored):
def shutdown(failure):
if hasattr(self, 'stop_loop'):
self.stop_loop()
elif failure is not None:
# loop hasn't been initialized yet; try our best to
# get an error message out. (the runner() interaction
# should probably be refactored).
try:
failure.raiseException()
except:
logging.error('exception before starting loop', exc_info=True)
d.addBoth(shutdown)
runner()
self.assertTrue(chunks)
return ''.join(chunks)
def twisted_coroutine_fetch(self, url, runner):
body = [None]
@gen.coroutine
def f():
# This is simpler than the non-coroutine version, but it cheats
# by reading the body in one blob instead of streaming it with
# a Protocol.
client = Agent(self.reactor)
response = yield client.request(b'GET', utf8(url))
with warnings.catch_warnings():
# readBody has a buggy DeprecationWarning in Twisted 15.0:
# https://twistedmatrix.com/trac/changeset/43379
warnings.simplefilter('ignore', category=DeprecationWarning)
body[0] = yield readBody(response)
self.stop_loop()
self.io_loop.add_callback(f)
runner()
return body[0]
def testTwistedServerTornadoClientIOLoop(self):
self.start_twisted_server()
response = self.tornado_fetch(
'http://localhost:%d' % self.twisted_port, self.run_ioloop)
'http://127.0.0.1:%d' % self.twisted_port, self.run_ioloop)
self.assertEqual(response.body, 'Hello from twisted!')
def testTwistedServerTornadoClientReactor(self):
self.start_twisted_server()
response = self.tornado_fetch(
'http://localhost:%d' % self.twisted_port, self.run_reactor)
'http://127.0.0.1:%d' % self.twisted_port, self.run_reactor)
self.assertEqual(response.body, 'Hello from twisted!')
def testTornadoServerTwistedClientIOLoop(self):
self.start_tornado_server()
response = self.twisted_fetch(
'http://localhost:%d' % self.tornado_port, self.run_ioloop)
'http://127.0.0.1:%d' % self.tornado_port, self.run_ioloop)
self.assertEqual(response, 'Hello from tornado!')
def testTornadoServerTwistedClientReactor(self):
self.start_tornado_server()
response = self.twisted_fetch(
'http://localhost:%d' % self.tornado_port, self.run_reactor)
'http://127.0.0.1:%d' % self.tornado_port, self.run_reactor)
self.assertEqual(response, 'Hello from tornado!')
@skipIfNoSingleDispatch
def testTornadoServerTwistedCoroutineClientIOLoop(self):
self.start_tornado_server()
response = self.twisted_coroutine_fetch(
'http://127.0.0.1:%d' % self.tornado_port, self.run_ioloop)
self.assertEqual(response, 'Hello from tornado!')
@skipIfNoTwisted
@skipIfNoSingleDispatch
class ConvertDeferredTest(unittest.TestCase):
def test_success(self):
@inlineCallbacks
def fn():
if False:
# inlineCallbacks doesn't work with regular functions;
# must have a yield even if it's unreachable.
yield
returnValue(42)
f = gen.convert_yielded(fn())
self.assertEqual(f.result(), 42)
def test_failure(self):
@inlineCallbacks
def fn():
if False:
yield
1 / 0
f = gen.convert_yielded(fn())
with self.assertRaises(ZeroDivisionError):
f.result()
if have_twisted:
# Import and run as much of twisted's test suite as possible.
@ -483,7 +554,7 @@ if have_twisted:
'test_changeUID',
],
# Process tests appear to work on OSX 10.7, but not 10.6
#'twisted.internet.test.test_process.PTYProcessTestsBuilder': [
# 'twisted.internet.test.test_process.PTYProcessTestsBuilder': [
# 'test_systemCallUninterruptedByChildExit',
# ],
'twisted.internet.test.test_tcp.TCPClientTestsBuilder': [
@ -502,7 +573,7 @@ if have_twisted:
'twisted.internet.test.test_threads.ThreadTestsBuilder': [],
'twisted.internet.test.test_time.TimeTestsBuilder': [],
# Extra third-party dependencies (pyOpenSSL)
#'twisted.internet.test.test_tls.SSLClientTestsMixin': [],
# 'twisted.internet.test.test_tls.SSLClientTestsMixin': [],
'twisted.internet.test.test_udp.UDPServerTestsBuilder': [],
'twisted.internet.test.test_unix.UNIXTestsBuilder': [
# Platform-specific. These tests would be skipped automatically
@ -588,13 +659,13 @@ if have_twisted:
correctly. In some tests another TornadoReactor is layered on top
of the whole stack.
"""
def initialize(self):
def initialize(self, **kwargs):
# When configured to use LayeredTwistedIOLoop we can't easily
# get the next-best IOLoop implementation, so use the lowest common
# denominator.
self.real_io_loop = SelectIOLoop()
reactor = TornadoReactor(io_loop=self.real_io_loop)
super(LayeredTwistedIOLoop, self).initialize(reactor=reactor)
super(LayeredTwistedIOLoop, self).initialize(reactor=reactor, **kwargs)
self.add_callback(self.make_current)
def close(self, all_fds=False):

View file

@ -4,6 +4,8 @@ import os
import socket
import sys
from tornado.testing import bind_unused_port
# Encapsulate the choice of unittest or unittest2 here.
# To be used as 'from tornado.test.util import unittest'.
if sys.version_info < (2, 7):
@ -28,3 +30,23 @@ skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ,
'network access disabled')
skipIfNoIPv6 = unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present')
def refusing_port():
"""Returns a local port number that will refuse all connections.
Return value is (cleanup_func, port); the cleanup function
must be called to free the port to be reused.
"""
# On travis-ci, port numbers are reassigned frequently. To avoid
# collisions with other tests, we use an open client-side socket's
# ephemeral port number to ensure that nothing can listen on that
# port.
server_socket, port = bind_unused_port()
server_socket.setblocking(1)
client_socket = socket.socket()
client_socket.connect(("127.0.0.1", port))
conn, client_addr = server_socket.accept()
conn.close()
server_socket.close()
return (client_socket.close, client_addr[1])

View file

@ -3,8 +3,9 @@ from __future__ import absolute_import, division, print_function, with_statement
import sys
import datetime
import tornado.escape
from tornado.escape import utf8
from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer, timedelta_to_seconds
from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer, timedelta_to_seconds, import_object
from tornado.test.util import unittest
try:
@ -45,13 +46,15 @@ class TestConfigurable(Configurable):
class TestConfig1(TestConfigurable):
def initialize(self, a=None):
def initialize(self, pos_arg=None, a=None):
self.a = a
self.pos_arg = pos_arg
class TestConfig2(TestConfigurable):
def initialize(self, b=None):
def initialize(self, pos_arg=None, b=None):
self.b = b
self.pos_arg = pos_arg
class ConfigurableTest(unittest.TestCase):
@ -101,9 +104,10 @@ class ConfigurableTest(unittest.TestCase):
self.assertIsInstance(obj, TestConfig1)
self.assertEqual(obj.a, 3)
obj = TestConfigurable(a=4)
obj = TestConfigurable(42, a=4)
self.assertIsInstance(obj, TestConfig1)
self.assertEqual(obj.a, 4)
self.assertEqual(obj.pos_arg, 42)
self.checkSubclasses()
# args bound in configure don't apply when using the subclass directly
@ -116,9 +120,10 @@ class ConfigurableTest(unittest.TestCase):
self.assertIsInstance(obj, TestConfig2)
self.assertEqual(obj.b, 5)
obj = TestConfigurable(b=6)
obj = TestConfigurable(42, b=6)
self.assertIsInstance(obj, TestConfig2)
self.assertEqual(obj.b, 6)
self.assertEqual(obj.pos_arg, 42)
self.checkSubclasses()
# args bound in configure don't apply when using the subclass directly
@ -177,3 +182,20 @@ class TimedeltaToSecondsTest(unittest.TestCase):
def test_timedelta_to_seconds(self):
time_delta = datetime.timedelta(hours=1)
self.assertEqual(timedelta_to_seconds(time_delta), 3600.0)
class ImportObjectTest(unittest.TestCase):
def test_import_member(self):
self.assertIs(import_object('tornado.escape.utf8'), utf8)
def test_import_member_unicode(self):
self.assertIs(import_object(u('tornado.escape.utf8')), utf8)
def test_import_module(self):
self.assertIs(import_object('tornado.escape'), tornado.escape)
def test_import_module_unicode(self):
# The internal implementation of __import__ differs depending on
# whether the thing being imported is a module or not.
# This variant requires a byte string in python 2.
self.assertIs(import_object(u('tornado.escape')), tornado.escape)

View file

@ -11,7 +11,7 @@ from tornado.template import DictLoader
from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test
from tornado.test.util import unittest
from tornado.util import u, ObjectDict, unicode_type, timedelta_to_seconds
from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler
from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler, get_signature_key_version
import binascii
import contextlib
@ -71,10 +71,14 @@ class HelloHandler(RequestHandler):
class CookieTestRequestHandler(RequestHandler):
# stub out enough methods to make the secure_cookie functions work
def __init__(self):
def __init__(self, cookie_secret='0123456789', key_version=None):
# don't call super.__init__
self._cookies = {}
self.application = ObjectDict(settings=dict(cookie_secret='0123456789'))
if key_version is None:
self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret))
else:
self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret,
key_version=key_version))
def get_cookie(self, name):
return self._cookies.get(name)
@ -128,6 +132,51 @@ class SecureCookieV1Test(unittest.TestCase):
self.assertEqual(handler.get_secure_cookie('foo', min_version=1), b'\xe9')
# See SignedValueTest below for more.
class SecureCookieV2Test(unittest.TestCase):
KEY_VERSIONS = {
0: 'ajklasdf0ojaisdf',
1: 'aslkjasaolwkjsdf'
}
def test_round_trip(self):
handler = CookieTestRequestHandler()
handler.set_secure_cookie('foo', b'bar', version=2)
self.assertEqual(handler.get_secure_cookie('foo', min_version=2), b'bar')
def test_key_version_roundtrip(self):
handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
key_version=0)
handler.set_secure_cookie('foo', b'bar')
self.assertEqual(handler.get_secure_cookie('foo'), b'bar')
def test_key_version_roundtrip_differing_version(self):
handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
key_version=1)
handler.set_secure_cookie('foo', b'bar')
self.assertEqual(handler.get_secure_cookie('foo'), b'bar')
def test_key_version_increment_version(self):
handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
key_version=0)
handler.set_secure_cookie('foo', b'bar')
new_handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
key_version=1)
new_handler._cookies = handler._cookies
self.assertEqual(new_handler.get_secure_cookie('foo'), b'bar')
def test_key_version_invalidate_version(self):
handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
key_version=0)
handler.set_secure_cookie('foo', b'bar')
new_key_versions = self.KEY_VERSIONS.copy()
new_key_versions.pop(0)
new_handler = CookieTestRequestHandler(cookie_secret=new_key_versions,
key_version=1)
new_handler._cookies = handler._cookies
self.assertEqual(new_handler.get_secure_cookie('foo'), None)
class CookieTest(WebTestCase):
def get_handlers(self):
class SetCookieHandler(RequestHandler):
@ -171,6 +220,13 @@ class CookieTest(WebTestCase):
def get(self):
self.set_cookie("foo", "bar", expires_days=10)
class SetCookieFalsyFlags(RequestHandler):
def get(self):
self.set_cookie("a", "1", secure=True)
self.set_cookie("b", "1", secure=False)
self.set_cookie("c", "1", httponly=True)
self.set_cookie("d", "1", httponly=False)
return [("/set", SetCookieHandler),
("/get", GetCookieHandler),
("/set_domain", SetCookieDomainHandler),
@ -178,6 +234,7 @@ class CookieTest(WebTestCase):
("/set_overwrite", SetCookieOverwriteHandler),
("/set_max_age", SetCookieMaxAgeHandler),
("/set_expires_days", SetCookieExpiresDaysHandler),
("/set_falsy_flags", SetCookieFalsyFlags)
]
def test_set_cookie(self):
@ -249,6 +306,16 @@ class CookieTest(WebTestCase):
*email.utils.parsedate(match.groupdict()["expires"])[:6])
self.assertTrue(abs(timedelta_to_seconds(expires - header_expires)) < 10)
def test_set_cookie_false_flags(self):
response = self.fetch("/set_falsy_flags")
headers = sorted(response.headers.get_list("Set-Cookie"))
# The secure and httponly headers are capitalized in py35 and
# lowercase in older versions.
self.assertEqual(headers[0].lower(), 'a=1; path=/; secure')
self.assertEqual(headers[1].lower(), 'b=1; path=/')
self.assertEqual(headers[2].lower(), 'c=1; httponly; path=/')
self.assertEqual(headers[3].lower(), 'd=1; path=/')
class AuthRedirectRequestHandler(RequestHandler):
def initialize(self, login_url):
@ -305,7 +372,7 @@ class ConnectionCloseTest(WebTestCase):
def test_connection_close(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
s.connect(("localhost", self.get_http_port()))
s.connect(("127.0.0.1", self.get_http_port()))
self.stream = IOStream(s, io_loop=self.io_loop)
self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
self.wait()
@ -379,6 +446,12 @@ class RequestEncodingTest(WebTestCase):
path_args=["a/b", "c/d"],
args={}))
def test_error(self):
# Percent signs (encoded as %25) should not mess up printf-style
# messages in logs
with ExpectLog(gen_log, ".*Invalid unicode"):
self.fetch("/group/?arg=%25%e9")
class TypeCheckHandler(RequestHandler):
def prepare(self):
@ -579,6 +652,7 @@ class WSGISafeWebTest(WebTestCase):
url("/redirect", RedirectHandler),
url("/web_redirect_permanent", WebRedirectHandler, {"url": "/web_redirect_newpath"}),
url("/web_redirect", WebRedirectHandler, {"url": "/web_redirect_newpath", "permanent": False}),
url("//web_redirect_double_slash", WebRedirectHandler, {"url": '/web_redirect_newpath'}),
url("/header_injection", HeaderInjectionHandler),
url("/get_argument", GetArgumentHandler),
url("/get_arguments", GetArgumentsHandler),
@ -712,6 +786,11 @@ js_embed()
self.assertEqual(response.code, 302)
self.assertEqual(response.headers['Location'], '/web_redirect_newpath')
def test_web_redirect_double_slash(self):
response = self.fetch("//web_redirect_double_slash", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], '/web_redirect_newpath')
def test_header_injection(self):
response = self.fetch("/header_injection")
self.assertEqual(response.body, b"ok")
@ -1517,6 +1596,22 @@ class ExceptionHandlerTest(SimpleHandlerTestCase):
self.assertEqual(response.code, 403)
@wsgi_safe
class BuggyLoggingTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
1/0
def log_exception(self, typ, value, tb):
1/0
def test_buggy_log_exception(self):
# Something gets logged even though the application's
# logger is broken.
with ExpectLog(app_log, '.*'):
self.fetch('/')
@wsgi_safe
class UIMethodUIModuleTest(SimpleHandlerTestCase):
"""Test that UI methods and modules are created correctly and
@ -1533,6 +1628,7 @@ class UIMethodUIModuleTest(SimpleHandlerTestCase):
def my_ui_method(handler, x):
return "In my_ui_method(%s) with handler value %s." % (
x, handler.value())
class MyModule(UIModule):
def render(self, x):
return "In MyModule(%s) with handler value %s." % (
@ -1907,7 +2003,7 @@ class StreamingRequestBodyTest(WebTestCase):
def connect(self, url, connection_close):
# Use a raw connection so we can control the sending of data.
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
s.connect(("localhost", self.get_http_port()))
s.connect(("127.0.0.1", self.get_http_port()))
stream = IOStream(s, io_loop=self.io_loop)
stream.write(b"GET " + url + b" HTTP/1.1\r\n")
if connection_close:
@ -1988,7 +2084,9 @@ class StreamingRequestFlowControlTest(WebTestCase):
@gen.coroutine
def prepare(self):
with self.in_method('prepare'):
# Note that asynchronous prepare() does not block data_received,
# so we don't use in_method here.
self.methods.append('prepare')
yield gen.Task(IOLoop.current().add_callback)
@gen.coroutine
@ -2051,9 +2149,10 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase):
# When the content-length is too high, the connection is simply
# closed without completing the response. An error is logged on
# the server.
with ExpectLog(app_log, "Uncaught exception"):
with ExpectLog(app_log, "(Uncaught exception|Exception in callback)"):
with ExpectLog(gen_log,
"Cannot send error response after headers written"):
"(Cannot send error response after headers written"
"|Failed to flush partial response)"):
response = self.fetch("/high")
self.assertEqual(response.code, 599)
self.assertEqual(str(self.server_error),
@ -2063,9 +2162,10 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase):
# When the content-length is too low, the connection is closed
# without writing the last chunk, so the client never sees the request
# complete (which would be a framing error).
with ExpectLog(app_log, "Uncaught exception"):
with ExpectLog(app_log, "(Uncaught exception|Exception in callback)"):
with ExpectLog(gen_log,
"Cannot send error response after headers written"):
"(Cannot send error response after headers written"
"|Failed to flush partial response)"):
response = self.fetch("/low")
self.assertEqual(response.code, 599)
self.assertEqual(str(self.server_error),
@ -2075,6 +2175,7 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase):
class ClientCloseTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
if self.request.version.startswith('HTTP/1'):
# Simulate a connection closed by the client during
# request processing. The client will see an error, but the
# server should respond gracefully (without logging errors
@ -2082,14 +2183,20 @@ class ClientCloseTest(SimpleHandlerTestCase):
# Content-Length said we would)
self.request.connection.stream.close()
self.write('hello')
else:
# TODO: add a HTTP2-compatible version of this test.
self.write('requires HTTP/1.x')
def test_client_close(self):
response = self.fetch('/')
if response.body == b'requires HTTP/1.x':
self.skipTest('requires HTTP/1.x')
self.assertEqual(response.code, 599)
class SignedValueTest(unittest.TestCase):
SECRET = "It's a secret to everybody"
SECRET_DICT = {0: "asdfbasdf", 1: "12312312", 2: "2342342"}
def past(self):
return self.present() - 86400 * 32
@ -2151,6 +2258,7 @@ class SignedValueTest(unittest.TestCase):
def test_payload_tampering(self):
# These cookies are variants of the one in test_known_values.
sig = "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152"
def validate(prefix):
return (b'value' ==
decode_signed_value(SignedValueTest.SECRET, "key",
@ -2165,6 +2273,7 @@ class SignedValueTest(unittest.TestCase):
def test_signature_tampering(self):
prefix = "2|1:0|10:1300000000|3:key|8:dmFsdWU=|"
def validate(sig):
return (b'value' ==
decode_signed_value(SignedValueTest.SECRET, "key",
@ -2194,6 +2303,43 @@ class SignedValueTest(unittest.TestCase):
clock=self.present)
self.assertEqual(value, decoded)
def test_key_versioning_read_write_default_key(self):
value = b"\xe9"
signed = create_signed_value(SignedValueTest.SECRET_DICT,
"key", value, clock=self.present,
key_version=0)
decoded = decode_signed_value(SignedValueTest.SECRET_DICT,
"key", signed, clock=self.present)
self.assertEqual(value, decoded)
def test_key_versioning_read_write_non_default_key(self):
value = b"\xe9"
signed = create_signed_value(SignedValueTest.SECRET_DICT,
"key", value, clock=self.present,
key_version=1)
decoded = decode_signed_value(SignedValueTest.SECRET_DICT,
"key", signed, clock=self.present)
self.assertEqual(value, decoded)
def test_key_versioning_invalid_key(self):
value = b"\xe9"
signed = create_signed_value(SignedValueTest.SECRET_DICT,
"key", value, clock=self.present,
key_version=0)
newkeys = SignedValueTest.SECRET_DICT.copy()
newkeys.pop(0)
decoded = decode_signed_value(newkeys,
"key", signed, clock=self.present)
self.assertEqual(None, decoded)
def test_key_version_retrieval(self):
value = b"\xe9"
signed = create_signed_value(SignedValueTest.SECRET_DICT,
"key", value, clock=self.present,
key_version=1)
key_version = get_signature_key_version(signed)
self.assertEqual(1, key_version)
@wsgi_safe
class XSRFTest(SimpleHandlerTestCase):
@ -2372,6 +2518,7 @@ class FinishExceptionTest(SimpleHandlerTestCase):
self.assertEqual(b'authentication required', response.body)
@wsgi_safe
class DecoratorTest(WebTestCase):
def get_handlers(self):
class RemoveSlashHandler(RequestHandler):
@ -2405,3 +2552,85 @@ class DecoratorTest(WebTestCase):
response = self.fetch("/addslash?foo=bar", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], "/addslash/?foo=bar")
@wsgi_safe
class CacheTest(WebTestCase):
def get_handlers(self):
class EtagHandler(RequestHandler):
def get(self, computed_etag):
self.write(computed_etag)
def compute_etag(self):
return self._write_buffer[0]
return [
('/etag/(.*)', EtagHandler)
]
def test_wildcard_etag(self):
computed_etag = '"xyzzy"'
etags = '*'
self._test_etag(computed_etag, etags, 304)
def test_strong_etag_match(self):
computed_etag = '"xyzzy"'
etags = '"xyzzy"'
self._test_etag(computed_etag, etags, 304)
def test_multiple_strong_etag_match(self):
computed_etag = '"xyzzy1"'
etags = '"xyzzy1", "xyzzy2"'
self._test_etag(computed_etag, etags, 304)
def test_strong_etag_not_match(self):
computed_etag = '"xyzzy"'
etags = '"xyzzy1"'
self._test_etag(computed_etag, etags, 200)
def test_multiple_strong_etag_not_match(self):
computed_etag = '"xyzzy"'
etags = '"xyzzy1", "xyzzy2"'
self._test_etag(computed_etag, etags, 200)
def test_weak_etag_match(self):
computed_etag = '"xyzzy1"'
etags = 'W/"xyzzy1"'
self._test_etag(computed_etag, etags, 304)
def test_multiple_weak_etag_match(self):
computed_etag = '"xyzzy2"'
etags = 'W/"xyzzy1", W/"xyzzy2"'
self._test_etag(computed_etag, etags, 304)
def test_weak_etag_not_match(self):
computed_etag = '"xyzzy2"'
etags = 'W/"xyzzy1"'
self._test_etag(computed_etag, etags, 200)
def test_multiple_weak_etag_not_match(self):
computed_etag = '"xyzzy3"'
etags = 'W/"xyzzy1", W/"xyzzy2"'
self._test_etag(computed_etag, etags, 200)
def _test_etag(self, computed_etag, etags, status_code):
response = self.fetch(
'/etag/' + computed_etag,
headers={'If-None-Match': etags}
)
self.assertEqual(response.code, status_code)
@wsgi_safe
class RequestSummaryTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
# remote_ip is optional, although it's set by
# both HTTPServer and WSGIAdapter.
# Clobber it to make sure it doesn't break logging.
self.request.remote_ip = None
self.finish(self._request_summary())
def test_missing_remote_ip(self):
resp = self.fetch("/")
self.assertEqual(resp.body, b"GET / (None)")

View file

@ -12,7 +12,7 @@ from tornado.web import Application, RequestHandler
from tornado.util import u
try:
import tornado.websocket
import tornado.websocket # noqa
from tornado.util import _websocket_mask_python
except ImportError:
# The unittest module presents misleading errors on ImportError
@ -53,7 +53,7 @@ class EchoHandler(TestWebSocketHandler):
class ErrorInOnMessageHandler(TestWebSocketHandler):
def on_message(self, message):
1/0
1 / 0
class HeaderHandler(TestWebSocketHandler):
@ -75,6 +75,7 @@ class NonWebSocketHandler(RequestHandler):
class CloseReasonHandler(TestWebSocketHandler):
def open(self):
self.on_close_called = False
self.close(1001, "goodbye")
@ -91,7 +92,7 @@ class WebSocketBaseTestCase(AsyncHTTPTestCase):
@gen.coroutine
def ws_connect(self, path, compression_options=None):
ws = yield websocket_connect(
'ws://localhost:%d%s' % (self.get_http_port(), path),
'ws://127.0.0.1:%d%s' % (self.get_http_port(), path),
compression_options=compression_options)
raise gen.Return(ws)
@ -105,6 +106,7 @@ class WebSocketBaseTestCase(AsyncHTTPTestCase):
ws.close()
yield self.close_future
class WebSocketTest(WebSocketBaseTestCase):
def get_app(self):
self.close_future = Future()
@ -135,7 +137,7 @@ class WebSocketTest(WebSocketBaseTestCase):
def test_websocket_callbacks(self):
websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port(),
'ws://127.0.0.1:%d/echo' % self.get_http_port(),
io_loop=self.io_loop, callback=self.stop)
ws = self.wait().result()
ws.write_message('hello')
@ -189,14 +191,14 @@ class WebSocketTest(WebSocketBaseTestCase):
with self.assertRaises(IOError):
with ExpectLog(gen_log, ".*"):
yield websocket_connect(
'ws://localhost:%d/' % port,
'ws://127.0.0.1:%d/' % port,
io_loop=self.io_loop,
connect_timeout=3600)
@gen_test
def test_websocket_close_buffered_data(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port())
'ws://127.0.0.1:%d/echo' % self.get_http_port())
ws.write_message('hello')
ws.write_message('world')
# Close the underlying stream.
@ -207,7 +209,7 @@ class WebSocketTest(WebSocketBaseTestCase):
def test_websocket_headers(self):
# Ensure that arbitrary headers can be passed through websocket_connect.
ws = yield websocket_connect(
HTTPRequest('ws://localhost:%d/header' % self.get_http_port(),
HTTPRequest('ws://127.0.0.1:%d/header' % self.get_http_port(),
headers={'X-Test': 'hello'}))
response = yield ws.read_message()
self.assertEqual(response, 'hello')
@ -221,6 +223,8 @@ class WebSocketTest(WebSocketBaseTestCase):
self.assertIs(msg, None)
self.assertEqual(ws.close_code, 1001)
self.assertEqual(ws.close_reason, "goodbye")
# The on_close callback is called no matter which side closed.
yield self.close_future
@gen_test
def test_client_close_reason(self):
@ -243,8 +247,8 @@ class WebSocketTest(WebSocketBaseTestCase):
def test_check_origin_valid_no_path(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
headers = {'Origin': 'http://localhost:%d' % port}
url = 'ws://127.0.0.1:%d/echo' % port
headers = {'Origin': 'http://127.0.0.1:%d' % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
@ -257,8 +261,8 @@ class WebSocketTest(WebSocketBaseTestCase):
def test_check_origin_valid_with_path(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
headers = {'Origin': 'http://localhost:%d/something' % port}
url = 'ws://127.0.0.1:%d/echo' % port
headers = {'Origin': 'http://127.0.0.1:%d/something' % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
@ -271,8 +275,8 @@ class WebSocketTest(WebSocketBaseTestCase):
def test_check_origin_invalid_partial_url(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
headers = {'Origin': 'localhost:%d' % port}
url = 'ws://127.0.0.1:%d/echo' % port
headers = {'Origin': '127.0.0.1:%d' % port}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers),
@ -283,8 +287,8 @@ class WebSocketTest(WebSocketBaseTestCase):
def test_check_origin_invalid(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
# Host is localhost, which should not be accessible from some other
url = 'ws://127.0.0.1:%d/echo' % port
# Host is 127.0.0.1, which should not be accessible from some other
# domain
headers = {'Origin': 'http://somewhereelse.com'}

View file

@ -19,6 +19,7 @@ try:
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.ioloop import IOLoop, TimeoutError
from tornado import netutil
from tornado.process import Subprocess
except ImportError:
# These modules are not importable on app engine. Parts of this module
# won't work, but e.g. LogTrapTestCase and main() will.
@ -28,6 +29,7 @@ except ImportError:
IOLoop = None
netutil = None
SimpleAsyncHTTPClient = None
Subprocess = None
from tornado.log import gen_log, app_log
from tornado.stack_context import ExceptionStackContext
from tornado.util import raise_exc_info, basestring_type
@ -214,6 +216,8 @@ class AsyncTestCase(unittest.TestCase):
self.io_loop.make_current()
def tearDown(self):
# Clean up Subprocess, so it can be used again with a new ioloop.
Subprocess.uninitialize()
self.io_loop.clear_current()
if (not IOLoop.initialized() or
self.io_loop is not IOLoop.instance()):
@ -413,9 +417,7 @@ class AsyncHTTPSTestCase(AsyncHTTPTestCase):
Interface is generally the same as `AsyncHTTPTestCase`.
"""
def get_http_client(self):
# Some versions of libcurl have deadlock bugs with ssl,
# so always run these tests with SimpleAsyncHTTPClient.
return SimpleAsyncHTTPClient(io_loop=self.io_loop, force_instance=True,
return AsyncHTTPClient(io_loop=self.io_loop, force_instance=True,
defaults=dict(validate_cert=False))
def get_httpserver_options(self):
@ -539,6 +541,9 @@ class LogTrapTestCase(unittest.TestCase):
`logging.basicConfig` and the "pretty logging" configured by
`tornado.options`. It is not compatible with other log buffering
mechanisms, such as those provided by some test runners.
.. deprecated:: 4.1
Use the unittest module's ``--buffer`` option instead, or `.ExpectLog`.
"""
def run(self, result=None):
logger = logging.getLogger()

View file

@ -78,6 +78,25 @@ class GzipDecompressor(object):
return self.decompressobj.flush()
# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for
# literal strings, and alternative solutions like "from __future__ import
# unicode_literals" have other problems (see PEP 414). u() can be applied
# to ascii strings that include \u escapes (but they must not contain
# literal non-ascii characters).
if not isinstance(b'', type('')):
def u(s):
return s
unicode_type = str
basestring_type = str
else:
def u(s):
return s.decode('unicode_escape')
# These names don't exist in py3, so use noqa comments to disable
# warnings in flake8.
unicode_type = unicode # noqa
basestring_type = basestring # noqa
def import_object(name):
"""Imports an object by name.
@ -96,6 +115,9 @@ def import_object(name):
...
ImportError: No module named missing_module
"""
if isinstance(name, unicode_type) and str is not unicode_type:
# On python 2 a byte string is required.
name = name.encode('utf-8')
if name.count('.') == 0:
return __import__(name, None, None)
@ -107,22 +129,6 @@ def import_object(name):
raise ImportError("No module named %s" % parts[-1])
# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for
# literal strings, and alternative solutions like "from __future__ import
# unicode_literals" have other problems (see PEP 414). u() can be applied
# to ascii strings that include \u escapes (but they must not contain
# literal non-ascii characters).
if type('') is not type(b''):
def u(s):
return s
unicode_type = str
basestring_type = str
else:
def u(s):
return s.decode('unicode_escape')
unicode_type = unicode
basestring_type = basestring
# Deprecated alias that was used before we dropped py25 support.
# Left here in case anyone outside Tornado is using it.
bytes_type = bytes
@ -192,21 +198,21 @@ class Configurable(object):
__impl_class = None
__impl_kwargs = None
def __new__(cls, **kwargs):
def __new__(cls, *args, **kwargs):
base = cls.configurable_base()
args = {}
init_kwargs = {}
if cls is base:
impl = cls.configured_class()
if base.__impl_kwargs:
args.update(base.__impl_kwargs)
init_kwargs.update(base.__impl_kwargs)
else:
impl = cls
args.update(kwargs)
init_kwargs.update(kwargs)
instance = super(Configurable, cls).__new__(impl)
# initialize vs __init__ chosen for compatibility with AsyncHTTPClient
# singleton magic. If we get rid of that we can switch to __init__
# here too.
instance.initialize(**args)
instance.initialize(*args, **init_kwargs)
return instance
@classmethod
@ -227,6 +233,9 @@ class Configurable(object):
"""Initialize a `Configurable` subclass instance.
Configurable classes should use `initialize` instead of ``__init__``.
.. versionchanged:: 4.2
Now accepts positional arguments in addition to keyword arguments.
"""
@classmethod

View file

@ -19,7 +19,9 @@ features that allow it to scale to large numbers of open connections,
making it ideal for `long polling
<http://en.wikipedia.org/wiki/Push_technology#Long_polling>`_.
Here is a simple "Hello, world" example app::
Here is a simple "Hello, world" example app:
.. testcode::
import tornado.ioloop
import tornado.web
@ -33,7 +35,11 @@ Here is a simple "Hello, world" example app::
(r"/", MainHandler),
])
application.listen(8888)
tornado.ioloop.IOLoop.instance().start()
tornado.ioloop.IOLoop.current().start()
.. testoutput::
:hide:
See the :doc:`guide` for additional information.
@ -50,7 +56,8 @@ request.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import (absolute_import, division,
print_function, with_statement)
import base64
@ -84,7 +91,9 @@ from tornado.log import access_log, app_log, gen_log
from tornado import stack_context
from tornado import template
from tornado.escape import utf8, _unicode
from tornado.util import import_object, ObjectDict, raise_exc_info, unicode_type, _websocket_mask
from tornado.util import (import_object, ObjectDict, raise_exc_info,
unicode_type, _websocket_mask)
from tornado.httputil import split_host_and_port
try:
@ -130,12 +139,11 @@ May be overridden by passing a ``version`` keyword argument.
DEFAULT_SIGNED_VALUE_MIN_VERSION = 1
"""The oldest signed value accepted by `.RequestHandler.get_secure_cookie`.
May be overrided by passing a ``min_version`` keyword argument.
May be overridden by passing a ``min_version`` keyword argument.
.. versionadded:: 3.2.1
"""
class RequestHandler(object):
"""Subclass this class and define `get()` or `post()` to make a handler.
@ -267,6 +275,7 @@ class RequestHandler(object):
if _has_stream_request_body(self.__class__):
if not self.request.body.done():
self.request.body.set_exception(iostream.StreamClosedError())
self.request.body.exception()
def clear(self):
"""Resets all headers and content for this response."""
@ -382,6 +391,12 @@ class RequestHandler(object):
The returned values are always unicode.
"""
# Make sure `get_arguments` isn't accidentally being called with a
# positional argument that's assumed to be a default (like in
# `get_argument`.)
assert isinstance(strip, bool)
return self._get_arguments(name, self.request.arguments, strip)
def get_body_argument(self, name, default=_ARG_DEFAULT, strip=True):
@ -398,7 +413,8 @@ class RequestHandler(object):
.. versionadded:: 3.2
"""
return self._get_argument(name, default, self.request.body_arguments, strip)
return self._get_argument(name, default, self.request.body_arguments,
strip)
def get_body_arguments(self, name, strip=True):
"""Returns a list of the body arguments with the given name.
@ -425,7 +441,8 @@ class RequestHandler(object):
.. versionadded:: 3.2
"""
return self._get_argument(name, default, self.request.query_arguments, strip)
return self._get_argument(name, default,
self.request.query_arguments, strip)
def get_query_arguments(self, name, strip=True):
"""Returns a list of the query arguments with the given name.
@ -480,7 +497,8 @@ class RequestHandler(object):
@property
def cookies(self):
"""An alias for `self.request.cookies <.httputil.HTTPServerRequest.cookies>`."""
"""An alias for
`self.request.cookies <.httputil.HTTPServerRequest.cookies>`."""
return self.request.cookies
def get_cookie(self, name, default=None):
@ -522,6 +540,12 @@ class RequestHandler(object):
for k, v in kwargs.items():
if k == 'max_age':
k = 'max-age'
# skip falsy values for httponly and secure flags because
# SimpleCookie sets them regardless
if k in ['httponly', 'secure'] and not v:
continue
morsel[k] = v
def clear_cookie(self, name, path="/", domain=None):
@ -588,8 +612,15 @@ class RequestHandler(object):
and made it the default.
"""
self.require_setting("cookie_secret", "secure cookies")
return create_signed_value(self.application.settings["cookie_secret"],
name, value, version=version)
secret = self.application.settings["cookie_secret"]
key_version = None
if isinstance(secret, dict):
if self.application.settings.get("key_version") is None:
raise Exception("key_version setting must be used for secret_key dicts")
key_version = self.application.settings["key_version"]
return create_signed_value(secret, name, value, version=version,
key_version=key_version)
def get_secure_cookie(self, name, value=None, max_age_days=31,
min_version=None):
@ -610,6 +641,17 @@ class RequestHandler(object):
name, value, max_age_days=max_age_days,
min_version=min_version)
def get_secure_cookie_key_version(self, name, value=None):
"""Returns the signing key version of the secure cookie.
The version is returned as int.
"""
self.require_setting("cookie_secret", "secure cookies")
if value is None:
value = self.get_cookie(name)
return get_signature_key_version(value)
def redirect(self, url, permanent=False, status=None):
"""Sends a redirect to the given (optionally relative) URL.
@ -625,8 +667,7 @@ class RequestHandler(object):
else:
assert isinstance(status, int) and 300 <= status <= 399
self.set_status(status)
self.set_header("Location", urlparse.urljoin(utf8(self.request.uri),
utf8(url)))
self.set_header("Location", utf8(url))
self.finish()
def write(self, chunk):
@ -646,15 +687,13 @@ class RequestHandler(object):
https://github.com/facebook/tornado/issues/1009
"""
if self._finished:
raise RuntimeError("Cannot write() after finish(). May be caused "
"by using async operations without the "
"@asynchronous decorator.")
raise RuntimeError("Cannot write() after finish()")
if not isinstance(chunk, (bytes, unicode_type, dict)):
raise TypeError("write() only accepts bytes, unicode, and dict objects")
message = "write() only accepts bytes, unicode, and dict objects"
if isinstance(chunk, list):
message += ". Lists not accepted for security reasons; see http://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.write"
raise TypeError(message)
if isinstance(chunk, dict):
if 'unwrap_json' in chunk:
chunk = chunk['unwrap_json']
else:
chunk = escape.json_encode(chunk)
self.set_header("Content-Type", "application/json; charset=UTF-8")
chunk = utf8(chunk)
@ -786,6 +825,7 @@ class RequestHandler(object):
current_user=self.current_user,
locale=self.locale,
_=self.locale.translate,
pgettext=self.locale.pgettext,
static_url=self.static_url,
xsrf_form_html=self.xsrf_form_html,
reverse_url=self.reverse_url
@ -830,7 +870,8 @@ class RequestHandler(object):
for transform in self._transforms:
self._status_code, self._headers, chunk = \
transform.transform_first_chunk(
self._status_code, self._headers, chunk, include_footers)
self._status_code, self._headers,
chunk, include_footers)
# Ignore the chunk and only write the headers for HEAD requests
if self.request.method == "HEAD":
chunk = None
@ -842,7 +883,7 @@ class RequestHandler(object):
for cookie in self._new_cookie.values():
self.add_header("Set-Cookie", cookie.OutputString(None))
start_line = httputil.ResponseStartLine(self.request.version,
start_line = httputil.ResponseStartLine('',
self._status_code,
self._reason)
return self.request.connection.write_headers(
@ -861,9 +902,7 @@ class RequestHandler(object):
def finish(self, chunk=None):
"""Finishes this response, ending the HTTP request."""
if self._finished:
raise RuntimeError("finish() called twice. May be caused "
"by using async operations without the "
"@asynchronous decorator.")
raise RuntimeError("finish() called twice")
if chunk is not None:
self.write(chunk)
@ -915,7 +954,15 @@ class RequestHandler(object):
if self._headers_written:
gen_log.error("Cannot send error response after headers written")
if not self._finished:
# If we get an error between writing headers and finishing,
# we are unlikely to be able to finish due to a
# Content-Length mismatch. Try anyway to release the
# socket.
try:
self.finish()
except Exception:
gen_log.error("Failed to flush partial response",
exc_info=True)
return
self.clear()
@ -1122,11 +1169,15 @@ class RequestHandler(object):
"""Convert a cookie string into a the tuple form returned by
_get_raw_xsrf_token.
"""
try:
m = _signed_value_version_re.match(utf8(cookie))
if m:
version = int(m.group(1))
if version == 2:
_, mask, masked_token, timestamp = cookie.split("|")
mask = binascii.a2b_hex(utf8(mask))
token = _websocket_mask(
mask, binascii.a2b_hex(utf8(masked_token)))
@ -1134,7 +1185,7 @@ class RequestHandler(object):
return version, token, timestamp
else:
# Treat unknown versions as not present instead of failing.
return None, None, None
raise Exception("Unknown xsrf cookie version")
else:
version = 1
try:
@ -1144,6 +1195,11 @@ class RequestHandler(object):
# We don't have a usable timestamp in older versions.
timestamp = int(time.time())
return (version, token, timestamp)
except Exception:
# Catch exceptions and return nothing instead of failing.
gen_log.debug("Uncaught exception in _decode_xsrf_token",
exc_info=True)
return None, None, None
def check_xsrf_cookie(self):
"""Verifies that the ``_xsrf`` cookie matches the ``_xsrf`` argument.
@ -1282,9 +1338,27 @@ class RequestHandler(object):
before completing the request. The ``Etag`` header should be set
(perhaps with `set_etag_header`) before calling this method.
"""
etag = self._headers.get("Etag")
inm = utf8(self.request.headers.get("If-None-Match", ""))
return bool(etag and inm and inm.find(etag) >= 0)
computed_etag = utf8(self._headers.get("Etag", ""))
# Find all weak and strong etag values from If-None-Match header
# because RFC 7232 allows multiple etag values in a single header.
etags = re.findall(
br'\*|(?:W/)?"[^"]*"',
utf8(self.request.headers.get("If-None-Match", ""))
)
if not computed_etag or not etags:
return False
match = False
if etags[0] == b'*':
match = True
else:
# Use a weak comparison when comparing entity-tags.
val = lambda x: x[2:] if x.startswith(b'W/') else x
for etag in etags:
if val(etag) == val(computed_etag):
match = True
break
return match
def _stack_context_handle_exception(self, type, value, traceback):
try:
@ -1344,7 +1418,10 @@ class RequestHandler(object):
if self._auto_finish and not self._finished:
self.finish()
except Exception as e:
try:
self._handle_request_exception(e)
except Exception:
app_log.error("Exception in exception handler", exc_info=True)
if (self._prepared_future is not None and
not self._prepared_future.done()):
# In case we failed before setting _prepared_future, do it
@ -1369,8 +1446,8 @@ class RequestHandler(object):
self.application.log_request(self)
def _request_summary(self):
return self.request.method + " " + self.request.uri + \
" (" + self.request.remote_ip + ")"
return "%s %s (%s)" % (self.request.method, self.request.uri,
self.request.remote_ip)
def _handle_request_exception(self, e):
if isinstance(e, Finish):
@ -1378,7 +1455,12 @@ class RequestHandler(object):
if not self._finished:
self.finish()
return
try:
self.log_exception(*sys.exc_info())
except Exception:
# An error here should still get a best-effort send_error()
# to avoid leaking the connection.
app_log.error("Error in exception logger", exc_info=True)
if self._finished:
# Extra errors after the request has been finished should
# be logged, but there is no reason to continue to try and
@ -1441,10 +1523,11 @@ class RequestHandler(object):
def asynchronous(method):
"""Wrap request handler methods with this if they are asynchronous.
This decorator is unnecessary if the method is also decorated with
``@gen.coroutine`` (it is legal but unnecessary to use the two
decorators together, in which case ``@asynchronous`` must be
first).
This decorator is for callback-style asynchronous methods; for
coroutines, use the ``@gen.coroutine`` decorator without
``@asynchronous``. (It is legal for legacy reasons to use the two
decorators together provided ``@asynchronous`` is first, but
``@asynchronous`` will be ignored in this case)
This decorator should only be applied to the :ref:`HTTP verb
methods <verbs>`; its behavior is undefined for any other method.
@ -1457,10 +1540,12 @@ def asynchronous(method):
method returns. It is up to the request handler to call
`self.finish() <RequestHandler.finish>` to finish the HTTP
request. Without this decorator, the request is automatically
finished when the ``get()`` or ``post()`` method returns. Example::
finished when the ``get()`` or ``post()`` method returns. Example:
class MyRequestHandler(web.RequestHandler):
@web.asynchronous
.. testcode::
class MyRequestHandler(RequestHandler):
@asynchronous
def get(self):
http = httpclient.AsyncHTTPClient()
http.fetch("http://friendfeed.com/", self._on_download)
@ -1469,18 +1554,23 @@ def asynchronous(method):
self.write("Downloaded!")
self.finish()
.. testoutput::
:hide:
.. versionadded:: 3.1
The ability to use ``@gen.coroutine`` without ``@asynchronous``.
"""
# Delay the IOLoop import because it's not available on app engine.
from tornado.ioloop import IOLoop
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
self._auto_finish = False
with stack_context.ExceptionStackContext(
self._stack_context_handle_exception):
result = method(self, *args, **kwargs)
if isinstance(result, Future):
if is_future(result):
# If @asynchronous is used with @gen.coroutine, (but
# not @gen.engine), we can automatically finish the
# request when the future resolves. Additionally,
@ -1521,7 +1611,7 @@ def stream_request_body(cls):
the entire body has been read.
There is a subtle interaction between ``data_received`` and asynchronous
``prepare``: The first call to ``data_recieved`` may occur at any point
``prepare``: The first call to ``data_received`` may occur at any point
after the call to ``prepare`` has returned *or yielded*.
"""
if not issubclass(cls, RequestHandler):
@ -1591,7 +1681,7 @@ class Application(httputil.HTTPServerConnectionDelegate):
])
http_server = httpserver.HTTPServer(application)
http_server.listen(8080)
ioloop.IOLoop.instance().start()
ioloop.IOLoop.current().start()
The constructor for this class takes in a list of `URLSpec` objects
or (regexp, request_class) tuples. When we receive requests, we
@ -1689,7 +1779,7 @@ class Application(httputil.HTTPServerConnectionDelegate):
`.TCPServer.bind`/`.TCPServer.start` methods directly.
Note that after calling this method you still need to call
``IOLoop.instance().start()`` to start the server.
``IOLoop.current().start()`` to start the server.
"""
# import is here rather than top level because HTTPServer
# is not importable on appengine
@ -1732,7 +1822,7 @@ class Application(httputil.HTTPServerConnectionDelegate):
self.transforms.append(transform_class)
def _get_host_handlers(self, request):
host = request.host.lower().split(':')[0]
host = split_host_and_port(request.host.lower())[0]
matches = []
for pattern, handlers in self.handlers:
if pattern.match(host):
@ -1773,9 +1863,9 @@ class Application(httputil.HTTPServerConnectionDelegate):
except TypeError:
pass
def start_request(self, connection):
def start_request(self, server_conn, request_conn):
# Modern HTTPServer interface
return _RequestDispatcher(self, connection)
return _RequestDispatcher(self, request_conn)
def __call__(self, request):
# Legacy HTTPServer interface
@ -1831,7 +1921,8 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate):
def headers_received(self, start_line, headers):
self.set_request(httputil.HTTPServerRequest(
connection=self.connection, start_line=start_line, headers=headers))
connection=self.connection, start_line=start_line,
headers=headers))
if self.stream_request_body:
self.request.body = Future()
return self.execute()
@ -1848,7 +1939,9 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate):
handlers = app._get_host_handlers(self.request)
if not handlers:
self.handler_class = RedirectHandler
self.handler_kwargs = dict(url="http://" + app.default_host + "/")
self.handler_kwargs = dict(url="%s://%s/"
% (self.request.protocol,
app.default_host))
return
for spec in handlers:
match = spec.regex.match(self.request.path)
@ -1914,11 +2007,14 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate):
if self.stream_request_body:
self.handler._prepared_future = Future()
# Note that if an exception escapes handler._execute it will be
# trapped in the Future it returns (which we are ignoring here).
# trapped in the Future it returns (which we are ignoring here,
# leaving it to be logged when the Future is GC'd).
# However, that shouldn't happen because _execute has a blanket
# except handler, and we cannot easily access the IOLoop here to
# call add_future.
self.handler._execute(transforms, *self.path_args, **self.path_kwargs)
# call add_future (because of the requirement to remain compatible
# with WSGI)
f = self.handler._execute(transforms, *self.path_args,
**self.path_kwargs)
# If we are streaming the request body, then execute() is finished
# when the handler has prepared to receive the body. If not,
# it doesn't matter when execute() finishes (so we return None)
@ -1952,6 +2048,8 @@ class HTTPError(Exception):
self.log_message = log_message
self.args = args
self.reason = kwargs.get('reason', None)
if log_message and not args:
self.log_message = log_message.replace('%', '%%')
def __str__(self):
message = "HTTP %d: %s" % (
@ -2212,7 +2310,8 @@ class StaticFileHandler(RequestHandler):
if content_type:
self.set_header("Content-Type", content_type)
cache_time = self.get_cache_time(self.path, self.modified, content_type)
cache_time = self.get_cache_time(self.path, self.modified,
content_type)
if cache_time > 0:
self.set_header("Expires", datetime.datetime.utcnow() +
datetime.timedelta(seconds=cache_time))
@ -2381,7 +2480,8 @@ class StaticFileHandler(RequestHandler):
.. versionadded:: 3.1
"""
stat_result = self._stat()
modified = datetime.datetime.utcfromtimestamp(stat_result[stat.ST_MTIME])
modified = datetime.datetime.utcfromtimestamp(
stat_result[stat.ST_MTIME])
return modified
def get_content_type(self):
@ -2624,6 +2724,8 @@ class UIModule(object):
UI modules often execute additional queries, and they can include
additional CSS and JavaScript that will be included in the output
page, which is automatically inserted on page render.
Subclasses of UIModule must override the `render` method.
"""
def __init__(self, handler):
self.handler = handler
@ -2636,31 +2738,45 @@ class UIModule(object):
return self.handler.current_user
def render(self, *args, **kwargs):
"""Overridden in subclasses to return this module's output."""
"""Override in subclasses to return this module's output."""
raise NotImplementedError()
def embedded_javascript(self):
"""Returns a JavaScript string that will be embedded in the page."""
"""Override to return a JavaScript string
to be embedded in the page."""
return None
def javascript_files(self):
"""Returns a list of JavaScript files required by this module."""
"""Override to return a list of JavaScript files needed by this module.
If the return values are relative paths, they will be passed to
`RequestHandler.static_url`; otherwise they will be used as-is.
"""
return None
def embedded_css(self):
"""Returns a CSS string that will be embedded in the page."""
"""Override to return a CSS string
that will be embedded in the page."""
return None
def css_files(self):
"""Returns a list of CSS files required by this module."""
"""Override to returns a list of CSS files required by this module.
If the return values are relative paths, they will be passed to
`RequestHandler.static_url`; otherwise they will be used as-is.
"""
return None
def html_head(self):
"""Returns a CSS string that will be put in the <head/> element"""
"""Override to return an HTML string that will be put in the <head/>
element.
"""
return None
def html_body(self):
"""Returns an HTML string that will be put in the <body/> element"""
"""Override to return an HTML string that will be put at the end of
the <body/> element.
"""
return None
def render_string(self, path, **kwargs):
@ -2862,11 +2978,13 @@ else:
return result == 0
def create_signed_value(secret, name, value, version=None, clock=None):
def create_signed_value(secret, name, value, version=None, clock=None,
key_version=None):
if version is None:
version = DEFAULT_SIGNED_VALUE_VERSION
if clock is None:
clock = time.time
timestamp = utf8(str(int(clock())))
value = base64.b64encode(utf8(value))
if version == 1:
@ -2883,7 +3001,7 @@ def create_signed_value(secret, name, value, version=None, clock=None):
#
# The fields are:
# - format version (i.e. 2; no length prefix)
# - key version (currently 0; reserved for future key rotation features)
# - key version (integer, default is 0)
# - timestamp (integer seconds since epoch)
# - name (not encoded; assumed to be ~alphanumeric)
# - value (base64-encoded)
@ -2891,34 +3009,32 @@ def create_signed_value(secret, name, value, version=None, clock=None):
def format_field(s):
return utf8("%d:" % len(s)) + utf8(s)
to_sign = b"|".join([
b"2|1:0",
b"2",
format_field(str(key_version or 0)),
format_field(timestamp),
format_field(name),
format_field(value),
b''])
if isinstance(secret, dict):
assert key_version is not None, 'Key version must be set when sign key dict is used'
assert version >= 2, 'Version must be at least 2 for key version support'
secret = secret[key_version]
signature = _create_signature_v2(secret, to_sign)
return to_sign + signature
else:
raise ValueError("Unsupported version %d" % version)
# A leading version number in decimal with no leading zeros, followed by a pipe.
# A leading version number in decimal
# with no leading zeros, followed by a pipe.
_signed_value_version_re = re.compile(br"^([1-9][0-9]*)\|(.*)$")
def decode_signed_value(secret, name, value, max_age_days=31, clock=None, min_version=None):
if clock is None:
clock = time.time
if min_version is None:
min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION
if min_version > 2:
raise ValueError("Unsupported min_version %d" % min_version)
if not value:
return None
# Figure out what version this is. Version 1 did not include an
def _get_version(value):
# Figures out what version value is. Version 1 did not include an
# explicit version field and started with arbitrary base64 data,
# which makes this tricky.
value = utf8(value)
m = _signed_value_version_re.match(value)
if m is None:
version = 1
@ -2935,13 +3051,31 @@ def decode_signed_value(secret, name, value, max_age_days=31, clock=None, min_ve
version = 1
except ValueError:
version = 1
return version
def decode_signed_value(secret, name, value, max_age_days=31,
clock=None, min_version=None):
if clock is None:
clock = time.time
if min_version is None:
min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION
if min_version > 2:
raise ValueError("Unsupported min_version %d" % min_version)
if not value:
return None
value = utf8(value)
version = _get_version(value)
if version < min_version:
return None
if version == 1:
return _decode_signed_value_v1(secret, name, value, max_age_days, clock)
return _decode_signed_value_v1(secret, name, value,
max_age_days, clock)
elif version == 2:
return _decode_signed_value_v2(secret, name, value, max_age_days, clock)
return _decode_signed_value_v2(secret, name, value,
max_age_days, clock)
else:
return None
@ -2964,7 +3098,8 @@ def _decode_signed_value_v1(secret, name, value, max_age_days, clock):
# digits from the payload to the timestamp without altering the
# signature. For backwards compatibility, sanity-check timestamp
# here instead of modifying _cookie_signature.
gen_log.warning("Cookie timestamp in future; possible tampering %r", value)
gen_log.warning("Cookie timestamp in future; possible tampering %r",
value)
return None
if parts[1].startswith(b"0"):
gen_log.warning("Tampered cookie %r", value)
@ -2975,7 +3110,7 @@ def _decode_signed_value_v1(secret, name, value, max_age_days, clock):
return None
def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
def _decode_fields_v2(value):
def _consume_field(s):
length, _, rest = s.partition(b':')
n = int(length)
@ -2986,16 +3121,28 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
raise ValueError("malformed v2 signed value field")
rest = rest[n + 1:]
return field_value, rest
rest = value[2:] # remove version number
try:
key_version, rest = _consume_field(rest)
timestamp, rest = _consume_field(rest)
name_field, rest = _consume_field(rest)
value_field, rest = _consume_field(rest)
value_field, passed_sig = _consume_field(rest)
return int(key_version), timestamp, name_field, value_field, passed_sig
def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
try:
key_version, timestamp, name_field, value_field, passed_sig = _decode_fields_v2(value)
except ValueError:
return None
passed_sig = rest
signed_string = value[:-len(passed_sig)]
if isinstance(secret, dict):
try:
secret = secret[key_version]
except KeyError:
return None
expected_sig = _create_signature_v2(secret, signed_string)
if not _time_independent_equals(passed_sig, expected_sig):
return None
@ -3011,6 +3158,19 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
return None
def get_signature_key_version(value):
value = utf8(value)
version = _get_version(value)
if version < 2:
return None
try:
key_version, _, _, _, _ = _decode_fields_v2(value)
except ValueError:
return None
return key_version
def _create_signature_v1(secret, *parts):
hash = hmac.new(utf8(secret), digestmod=hashlib.sha1)
for part in parts:

View file

@ -16,7 +16,8 @@ the protocol (known as "draft 76") and are not compatible with this module.
Removed support for the draft 76 protocol version.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import (absolute_import, division,
print_function, with_statement)
# Author: Jacob Kristhammar, 2010
import base64
@ -74,17 +75,22 @@ class WebSocketHandler(tornado.web.RequestHandler):
http://tools.ietf.org/html/rfc6455.
Here is an example WebSocket handler that echos back all received messages
back to the client::
back to the client:
class EchoWebSocket(websocket.WebSocketHandler):
.. testcode::
class EchoWebSocket(tornado.websocket.WebSocketHandler):
def open(self):
print "WebSocket opened"
print("WebSocket opened")
def on_message(self, message):
self.write_message(u"You said: " + message)
def on_close(self):
print "WebSocket closed"
print("WebSocket closed")
.. testoutput::
:hide:
WebSockets are not standard HTTP connections. The "handshake" is
HTTP, but after the handshake, the protocol is
@ -129,6 +135,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
self.close_code = None
self.close_reason = None
self.stream = None
self._on_close_called = False
@tornado.web.asynchronous
def get(self, *args, **kwargs):
@ -138,16 +145,22 @@ class WebSocketHandler(tornado.web.RequestHandler):
# Upgrade header should be present and should be equal to WebSocket
if self.request.headers.get("Upgrade", "").lower() != 'websocket':
self.set_status(400)
self.finish("Can \"Upgrade\" only to \"WebSocket\".")
log_msg = "Can \"Upgrade\" only to \"WebSocket\"."
self.finish(log_msg)
gen_log.debug(log_msg)
return
# Connection header should be upgrade. Some proxy servers/load balancers
# Connection header should be upgrade.
# Some proxy servers/load balancers
# might mess with it.
headers = self.request.headers
connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
connection = map(lambda s: s.strip().lower(),
headers.get("Connection", "").split(","))
if 'upgrade' not in connection:
self.set_status(400)
self.finish("\"Connection\" must be \"Upgrade\".")
log_msg = "\"Connection\" must be \"Upgrade\"."
self.finish(log_msg)
gen_log.debug(log_msg)
return
# Handle WebSocket Origin naming convention differences
@ -159,30 +172,29 @@ class WebSocketHandler(tornado.web.RequestHandler):
else:
origin = self.request.headers.get("Sec-Websocket-Origin", None)
# If there was an origin header, check to make sure it matches
# according to check_origin. When the origin is None, we assume it
# did not come from a browser and that it can be passed on.
if origin is not None and not self.check_origin(origin):
self.set_status(403)
self.finish("Cross origin websockets not allowed")
log_msg = "Cross origin websockets not allowed"
self.finish(log_msg)
gen_log.debug(log_msg)
return
self.stream = self.request.connection.detach()
self.stream.set_close_callback(self.on_connection_close)
if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
self.ws_connection = WebSocketProtocol13(
self, compression_options=self.get_compression_options())
self.ws_connection = self.get_websocket_protocol()
if self.ws_connection:
self.ws_connection.accept_connection()
else:
if not self.stream.closed():
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 426 Upgrade Required\r\n"
"Sec-WebSocket-Version: 8\r\n\r\n"))
"Sec-WebSocket-Version: 7, 8, 13\r\n\r\n"))
self.stream.close()
def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket.
@ -229,7 +241,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
"""
return None
def open(self):
def open(self, *args, **kwargs):
"""Invoked when a new WebSocket is opened.
The arguments to `open` are extracted from the `tornado.web.URLSpec`
@ -350,6 +362,8 @@ class WebSocketHandler(tornado.web.RequestHandler):
if self.ws_connection:
self.ws_connection.on_connection_close()
self.ws_connection = None
if not self._on_close_called:
self._on_close_called = True
self.on_close()
def send_error(self, *args, **kwargs):
@ -362,6 +376,13 @@ class WebSocketHandler(tornado.web.RequestHandler):
# we can close the connection more gracefully.
self.stream.close()
def get_websocket_protocol(self):
websocket_version = self.request.headers.get("Sec-WebSocket-Version")
if websocket_version in ("7", "8", "13"):
return WebSocketProtocol13(
self, compression_options=self.get_compression_options())
def _wrap_method(method):
def _disallow_for_websocket(self, *args, **kwargs):
if self.stream is None:
@ -499,7 +520,8 @@ class WebSocketProtocol13(WebSocketProtocol):
self._handle_websocket_headers()
self._accept_connection()
except ValueError:
gen_log.debug("Malformed WebSocket request received", exc_info=True)
gen_log.debug("Malformed WebSocket request received",
exc_info=True)
self._abort()
return
@ -535,7 +557,8 @@ class WebSocketProtocol13(WebSocketProtocol):
selected = self.handler.select_subprotocol(subprotocols)
if selected:
assert selected in subprotocols
subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
subprotocol_header = ("Sec-WebSocket-Protocol: %s\r\n"
% selected)
extension_header = ''
extensions = self._parse_extensions_header(self.request.headers)
@ -703,7 +726,8 @@ class WebSocketProtocol13(WebSocketProtocol):
if self._masked_frame:
self.stream.read_bytes(4, self._on_masking_key)
else:
self.stream.read_bytes(self._frame_length, self._on_frame_data)
self.stream.read_bytes(self._frame_length,
self._on_frame_data)
elif payloadlen == 126:
self.stream.read_bytes(2, self._on_frame_length_16)
elif payloadlen == 127:
@ -737,7 +761,8 @@ class WebSocketProtocol13(WebSocketProtocol):
self._wire_bytes_in += len(data)
self._frame_mask = data
try:
self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
self.stream.read_bytes(self._frame_length,
self._on_masked_frame_data)
except StreamClosedError:
self._abort()
@ -852,12 +877,15 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
This class should not be instantiated directly; use the
`websocket_connect` function instead.
"""
def __init__(self, io_loop, request, compression_options=None):
def __init__(self, io_loop, request, on_message_callback=None,
compression_options=None):
self.compression_options = compression_options
self.connect_future = TracebackFuture()
self.protocol = None
self.read_future = None
self.read_queue = collections.deque()
self.key = base64.b64encode(os.urandom(16))
self._on_message_callback = on_message_callback
scheme, sep, rest = request.url.partition(':')
scheme = {'ws': 'http', 'wss': 'https'}[scheme]
@ -880,7 +908,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
self.tcp_client = TCPClient(io_loop=io_loop)
super(WebSocketClientConnection, self).__init__(
io_loop, None, request, lambda: None, self._on_http_response,
104857600, self.tcp_client, 65536)
104857600, self.tcp_client, 65536, 104857600)
def close(self, code=None, reason=None):
"""Closes the websocket connection.
@ -919,9 +947,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
start_line, headers)
self.headers = headers
self.protocol = WebSocketProtocol13(
self, mask_outgoing=True,
compression_options=self.compression_options)
self.protocol = self.get_websocket_protocol()
self.protocol._process_server_headers(self.key, self.headers)
self.protocol._receive_frame()
@ -946,6 +972,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
def read_message(self, callback=None):
"""Reads a message from the WebSocket server.
If on_message_callback was specified at WebSocket
initialization, this function will never return messages
Returns a future whose result is the message, or None
if the connection is closed. If a callback argument
is given it will be called with the future when it is
@ -962,7 +991,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
return future
def on_message(self, message):
if self.read_future is not None:
if self._on_message_callback:
self._on_message_callback(message)
elif self.read_future is not None:
self.read_future.set_result(message)
self.read_future = None
else:
@ -971,9 +1002,13 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
def on_pong(self, data):
pass
def get_websocket_protocol(self):
return WebSocketProtocol13(self, mask_outgoing=True,
compression_options=self.compression_options)
def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
compression_options=None):
on_message_callback=None, compression_options=None):
"""Client-side websocket support.
Takes a url and returns a Future whose result is a
@ -982,11 +1017,26 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
``compression_options`` is interpreted in the same way as the
return value of `.WebSocketHandler.get_compression_options`.
The connection supports two styles of operation. In the coroutine
style, the application typically calls
`~.WebSocketClientConnection.read_message` in a loop::
conn = yield websocket_connection(loop)
while True:
msg = yield conn.read_message()
if msg is None: break
# Do something with msg
In the callback style, pass an ``on_message_callback`` to
``websocket_connect``. In both styles, a message of ``None``
indicates that the connection has been closed.
.. versionchanged:: 3.2
Also accepts ``HTTPRequest`` objects in place of urls.
.. versionchanged:: 4.1
Added ``compression_options``.
Added ``compression_options`` and ``on_message_callback``.
The ``io_loop`` argument is deprecated.
"""
if io_loop is None:
io_loop = IOLoop.current()
@ -1000,7 +1050,9 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
request = httpclient._RequestProxy(
request, httpclient.HTTPRequest._DEFAULTS)
conn = WebSocketClientConnection(io_loop, request, compression_options)
conn = WebSocketClientConnection(io_loop, request,
on_message_callback=on_message_callback,
compression_options=compression_options)
if callback is not None:
io_loop.add_future(conn.connect_future, callback)
return conn.connect_future

View file

@ -207,7 +207,7 @@ class WSGIAdapter(object):
body = environ["wsgi.input"].read(
int(headers["Content-Length"]))
else:
body = ""
body = b""
protocol = environ["wsgi.url_scheme"]
remote_ip = environ.get("REMOTE_ADDR", "")
if environ.get("HTTP_HOST"):
@ -253,7 +253,7 @@ class WSGIContainer(object):
container = tornado.wsgi.WSGIContainer(simple_app)
http_server = tornado.httpserver.HTTPServer(container)
http_server.listen(8888)
tornado.ioloop.IOLoop.instance().start()
tornado.ioloop.IOLoop.current().start()
This class is intended to let other frameworks (Django, web.py, etc)
run on the Tornado HTTP server and I/O loop.
@ -284,7 +284,8 @@ class WSGIContainer(object):
if not data:
raise Exception("WSGI app did not call start_response")
status_code = int(data["status"].split()[0])
status_code, reason = data["status"].split(' ', 1)
status_code = int(status_code)
headers = data["headers"]
header_set = set(k.lower() for (k, v) in headers)
body = escape.utf8(body)
@ -296,13 +297,12 @@ class WSGIContainer(object):
if "server" not in header_set:
headers.append(("Server", "TornadoServer/%s" % tornado.version))
parts = [escape.utf8("HTTP/1.1 " + data["status"] + "\r\n")]
start_line = httputil.ResponseStartLine("HTTP/1.1", status_code, reason)
header_obj = httputil.HTTPHeaders()
for key, value in headers:
parts.append(escape.utf8(key) + b": " + escape.utf8(value) + b"\r\n")
parts.append(b"\r\n")
parts.append(body)
request.write(b"".join(parts))
request.finish()
header_obj.add(key, value)
request.connection.write_headers(start_line, header_obj, chunk=body)
request.connection.finish()
self._log(status_code, request)
@staticmethod