diff --git a/CHANGES.md b/CHANGES.md
index f6ac5ad0..b8108a21 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,5 +1,6 @@
### 0.x.x (2015-xx-xx xx:xx:xx UTC)
+* Update Tornado webserver to 4.2.dev1 (609dbb9)
* Change network names to only display on top line of Day by Day layout on Episode View
* Reposition country part of network name into the hover over in Day by Day layout
* Add ToTV provider
diff --git a/tornado/__init__.py b/tornado/__init__.py
index 0e39f842..01b926be 100644
--- a/tornado/__init__.py
+++ b/tornado/__init__.py
@@ -25,5 +25,5 @@ from __future__ import absolute_import, division, print_function, with_statement
# is zero for an official release, positive for a development branch,
# or negative for a release candidate or beta (after the base version
# number has been incremented)
-version = "4.1.dev1"
-version_info = (4, 1, 0, -100)
+version = "4.2.dev1"
+version_info = (4, 2, 0, -100)
diff --git a/tornado/auth.py b/tornado/auth.py
index ac2fd0d1..800b10af 100644
--- a/tornado/auth.py
+++ b/tornado/auth.py
@@ -32,7 +32,9 @@ They all take slightly different arguments due to the fact all these
services implement authentication and authorization slightly differently.
See the individual service classes below for complete documentation.
-Example usage for Google OpenID::
+Example usage for Google OAuth:
+
+.. testcode::
class GoogleOAuth2LoginHandler(tornado.web.RequestHandler,
tornado.auth.GoogleOAuth2Mixin):
@@ -51,6 +53,10 @@ Example usage for Google OpenID::
response_type='code',
extra_params={'approval_prompt': 'auto'})
+.. testoutput::
+ :hide:
+
+
.. versionchanged:: 4.0
All of the callback interfaces in this module are now guaranteed
to run their callback with an argument of ``None`` on error.
@@ -69,7 +75,7 @@ import hmac
import time
import uuid
-from tornado.concurrent import TracebackFuture, chain_future, return_future
+from tornado.concurrent import TracebackFuture, return_future
from tornado import gen
from tornado import httpclient
from tornado import escape
@@ -123,6 +129,7 @@ def _auth_return_future(f):
if callback is not None:
future.add_done_callback(
functools.partial(_auth_future_to_callback, callback))
+
def handle_exception(typ, value, tb):
if future.done():
return False
@@ -138,9 +145,6 @@ def _auth_return_future(f):
class OpenIdMixin(object):
"""Abstract implementation of OpenID and Attribute Exchange.
- See `GoogleMixin` below for a customized example (which also
- includes OAuth support).
-
Class attributes:
* ``_OPENID_ENDPOINT``: the identity provider's URI.
@@ -312,8 +316,7 @@ class OpenIdMixin(object):
class OAuthMixin(object):
"""Abstract implementation of OAuth 1.0 and 1.0a.
- See `TwitterMixin` and `FriendFeedMixin` below for example implementations,
- or `GoogleMixin` for an OAuth/OpenID hybrid.
+ See `TwitterMixin` below for an example implementation.
Class attributes:
@@ -565,7 +568,8 @@ class OAuthMixin(object):
class OAuth2Mixin(object):
"""Abstract implementation of OAuth 2.0.
- See `FacebookGraphMixin` below for an example implementation.
+ See `FacebookGraphMixin` or `GoogleOAuth2Mixin` below for example
+ implementations.
Class attributes:
@@ -629,7 +633,9 @@ class TwitterMixin(OAuthMixin):
URL you registered as your application's callback URL.
When your application is set up, you can use this mixin like this
- to authenticate the user with Twitter and get access to their stream::
+ to authenticate the user with Twitter and get access to their stream:
+
+ .. testcode::
class TwitterLoginHandler(tornado.web.RequestHandler,
tornado.auth.TwitterMixin):
@@ -641,6 +647,9 @@ class TwitterMixin(OAuthMixin):
else:
yield self.authorize_redirect()
+ .. testoutput::
+ :hide:
+
The user object returned by `~OAuthMixin.get_authenticated_user`
includes the attributes ``username``, ``name``, ``access_token``,
and all of the custom Twitter user attributes described at
@@ -689,7 +698,9 @@ class TwitterMixin(OAuthMixin):
`~OAuthMixin.get_authenticated_user`. The user returned through that
process includes an 'access_token' attribute that can be used
to make authenticated requests via this method. Example
- usage::
+ usage:
+
+ .. testcode::
class MainHandler(tornado.web.RequestHandler,
tornado.auth.TwitterMixin):
@@ -706,6 +717,9 @@ class TwitterMixin(OAuthMixin):
return
self.finish("Posted a message!")
+ .. testoutput::
+ :hide:
+
"""
if path.startswith('http:') or path.startswith('https:'):
# Raw urls are useful for e.g. search which doesn't follow the
@@ -757,223 +771,6 @@ class TwitterMixin(OAuthMixin):
raise gen.Return(user)
-class FriendFeedMixin(OAuthMixin):
- """FriendFeed OAuth authentication.
-
- To authenticate with FriendFeed, register your application with
- FriendFeed at http://friendfeed.com/api/applications. Then copy
- your Consumer Key and Consumer Secret to the application
- `~tornado.web.Application.settings` ``friendfeed_consumer_key``
- and ``friendfeed_consumer_secret``. Use this mixin on the handler
- for the URL you registered as your application's Callback URL.
-
- When your application is set up, you can use this mixin like this
- to authenticate the user with FriendFeed and get access to their feed::
-
- class FriendFeedLoginHandler(tornado.web.RequestHandler,
- tornado.auth.FriendFeedMixin):
- @tornado.gen.coroutine
- def get(self):
- if self.get_argument("oauth_token", None):
- user = yield self.get_authenticated_user()
- # Save the user using e.g. set_secure_cookie()
- else:
- yield self.authorize_redirect()
-
- The user object returned by `~OAuthMixin.get_authenticated_user()` includes the
- attributes ``username``, ``name``, and ``description`` in addition to
- ``access_token``. You should save the access token with the user;
- it is required to make requests on behalf of the user later with
- `friendfeed_request()`.
- """
- _OAUTH_VERSION = "1.0"
- _OAUTH_REQUEST_TOKEN_URL = "https://friendfeed.com/account/oauth/request_token"
- _OAUTH_ACCESS_TOKEN_URL = "https://friendfeed.com/account/oauth/access_token"
- _OAUTH_AUTHORIZE_URL = "https://friendfeed.com/account/oauth/authorize"
- _OAUTH_NO_CALLBACKS = True
- _OAUTH_VERSION = "1.0"
-
- @_auth_return_future
- def friendfeed_request(self, path, callback, access_token=None,
- post_args=None, **args):
- """Fetches the given relative API path, e.g., "/bret/friends"
-
- If the request is a POST, ``post_args`` should be provided. Query
- string arguments should be given as keyword arguments.
-
- All the FriendFeed methods are documented at
- http://friendfeed.com/api/documentation.
-
- Many methods require an OAuth access token which you can
- obtain through `~OAuthMixin.authorize_redirect` and
- `~OAuthMixin.get_authenticated_user`. The user returned
- through that process includes an ``access_token`` attribute that
- can be used to make authenticated requests via this
- method.
-
- Example usage::
-
- class MainHandler(tornado.web.RequestHandler,
- tornado.auth.FriendFeedMixin):
- @tornado.web.authenticated
- @tornado.gen.coroutine
- def get(self):
- new_entry = yield self.friendfeed_request(
- "/entry",
- post_args={"body": "Testing Tornado Web Server"},
- access_token=self.current_user["access_token"])
-
- if not new_entry:
- # Call failed; perhaps missing permission?
- yield self.authorize_redirect()
- return
- self.finish("Posted a message!")
-
- """
- # Add the OAuth resource request signature if we have credentials
- url = "http://friendfeed-api.com/v2" + path
- if access_token:
- all_args = {}
- all_args.update(args)
- all_args.update(post_args or {})
- method = "POST" if post_args is not None else "GET"
- oauth = self._oauth_request_parameters(
- url, access_token, all_args, method=method)
- args.update(oauth)
- if args:
- url += "?" + urllib_parse.urlencode(args)
- callback = functools.partial(self._on_friendfeed_request, callback)
- http = self.get_auth_http_client()
- if post_args is not None:
- http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args),
- callback=callback)
- else:
- http.fetch(url, callback=callback)
-
- def _on_friendfeed_request(self, future, response):
- if response.error:
- future.set_exception(AuthError(
- "Error response %s fetching %s" % (response.error,
- response.request.url)))
- return
- future.set_result(escape.json_decode(response.body))
-
- def _oauth_consumer_token(self):
- self.require_setting("friendfeed_consumer_key", "FriendFeed OAuth")
- self.require_setting("friendfeed_consumer_secret", "FriendFeed OAuth")
- return dict(
- key=self.settings["friendfeed_consumer_key"],
- secret=self.settings["friendfeed_consumer_secret"])
-
- @gen.coroutine
- def _oauth_get_user_future(self, access_token, callback):
- user = yield self.friendfeed_request(
- "/feedinfo/" + access_token["username"],
- include="id,name,description", access_token=access_token)
- if user:
- user["username"] = user["id"]
- callback(user)
-
- def _parse_user_response(self, callback, user):
- if user:
- user["username"] = user["id"]
- callback(user)
-
-
-class GoogleMixin(OpenIdMixin, OAuthMixin):
- """Google Open ID / OAuth authentication.
-
- .. deprecated:: 4.0
- New applications should use `GoogleOAuth2Mixin`
- below instead of this class. As of May 19, 2014, Google has stopped
- supporting registration-free authentication.
-
- No application registration is necessary to use Google for
- authentication or to access Google resources on behalf of a user.
-
- Google implements both OpenID and OAuth in a hybrid mode. If you
- just need the user's identity, use
- `~OpenIdMixin.authenticate_redirect`. If you need to make
- requests to Google on behalf of the user, use
- `authorize_redirect`. On return, parse the response with
- `~OpenIdMixin.get_authenticated_user`. We send a dict containing
- the values for the user, including ``email``, ``name``, and
- ``locale``.
-
- Example usage::
-
- class GoogleLoginHandler(tornado.web.RequestHandler,
- tornado.auth.GoogleMixin):
- @tornado.gen.coroutine
- def get(self):
- if self.get_argument("openid.mode", None):
- user = yield self.get_authenticated_user()
- # Save the user with e.g. set_secure_cookie()
- else:
- yield self.authenticate_redirect()
- """
- _OPENID_ENDPOINT = "https://www.google.com/accounts/o8/ud"
- _OAUTH_ACCESS_TOKEN_URL = "https://www.google.com/accounts/OAuthGetAccessToken"
-
- @return_future
- def authorize_redirect(self, oauth_scope, callback_uri=None,
- ax_attrs=["name", "email", "language", "username"],
- callback=None):
- """Authenticates and authorizes for the given Google resource.
-
- Some of the available resources which can be used in the ``oauth_scope``
- argument are:
-
- * Gmail Contacts - http://www.google.com/m8/feeds/
- * Calendar - http://www.google.com/calendar/feeds/
- * Finance - http://finance.google.com/finance/feeds/
-
- You can authorize multiple resources by separating the resource
- URLs with a space.
-
- .. versionchanged:: 3.1
- Returns a `.Future` and takes an optional callback. These are
- not strictly necessary as this method is synchronous,
- but they are supplied for consistency with
- `OAuthMixin.authorize_redirect`.
- """
- callback_uri = callback_uri or self.request.uri
- args = self._openid_args(callback_uri, ax_attrs=ax_attrs,
- oauth_scope=oauth_scope)
- self.redirect(self._OPENID_ENDPOINT + "?" + urllib_parse.urlencode(args))
- callback()
-
- @_auth_return_future
- def get_authenticated_user(self, callback):
- """Fetches the authenticated user data upon redirect."""
- # Look to see if we are doing combined OpenID/OAuth
- oauth_ns = ""
- for name, values in self.request.arguments.items():
- if name.startswith("openid.ns.") and \
- values[-1] == b"http://specs.openid.net/extensions/oauth/1.0":
- oauth_ns = name[10:]
- break
- token = self.get_argument("openid." + oauth_ns + ".request_token", "")
- if token:
- http = self.get_auth_http_client()
- token = dict(key=token, secret="")
- http.fetch(self._oauth_access_token_url(token),
- functools.partial(self._on_access_token, callback))
- else:
- chain_future(OpenIdMixin.get_authenticated_user(self),
- callback)
-
- def _oauth_consumer_token(self):
- self.require_setting("google_consumer_key", "Google OAuth")
- self.require_setting("google_consumer_secret", "Google OAuth")
- return dict(
- key=self.settings["google_consumer_key"],
- secret=self.settings["google_consumer_secret"])
-
- def _oauth_get_user_future(self, access_token):
- return OpenIdMixin.get_authenticated_user(self)
-
-
class GoogleOAuth2Mixin(OAuth2Mixin):
"""Google authentication using OAuth2.
@@ -1001,7 +798,9 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
def get_authenticated_user(self, redirect_uri, code, callback):
"""Handles the login for the Google user, returning a user object.
- Example usage::
+ Example usage:
+
+ .. testcode::
class GoogleOAuth2LoginHandler(tornado.web.RequestHandler,
tornado.auth.GoogleOAuth2Mixin):
@@ -1019,6 +818,10 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
scope=['profile', 'email'],
response_type='code',
extra_params={'approval_prompt': 'auto'})
+
+ .. testoutput::
+ :hide:
+
"""
http = self.get_auth_http_client()
body = urllib_parse.urlencode({
@@ -1051,217 +854,6 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
return httpclient.AsyncHTTPClient()
-class FacebookMixin(object):
- """Facebook Connect authentication.
-
- .. deprecated:: 1.1
- New applications should use `FacebookGraphMixin`
- below instead of this class. This class does not support the
- Future-based interface seen on other classes in this module.
-
- To authenticate with Facebook, register your application with
- Facebook at http://www.facebook.com/developers/apps.php. Then
- copy your API Key and Application Secret to the application settings
- ``facebook_api_key`` and ``facebook_secret``.
-
- When your application is set up, you can use this mixin like this
- to authenticate the user with Facebook::
-
- class FacebookHandler(tornado.web.RequestHandler,
- tornado.auth.FacebookMixin):
- @tornado.web.asynchronous
- def get(self):
- if self.get_argument("session", None):
- self.get_authenticated_user(self._on_auth)
- return
- yield self.authenticate_redirect()
-
- def _on_auth(self, user):
- if not user:
- raise tornado.web.HTTPError(500, "Facebook auth failed")
- # Save the user using, e.g., set_secure_cookie()
-
- The user object returned by `get_authenticated_user` includes the
- attributes ``facebook_uid`` and ``name`` in addition to session attributes
- like ``session_key``. You should save the session key with the user; it is
- required to make requests on behalf of the user later with
- `facebook_request`.
- """
- @return_future
- def authenticate_redirect(self, callback_uri=None, cancel_uri=None,
- extended_permissions=None, callback=None):
- """Authenticates/installs this app for the current user.
-
- .. versionchanged:: 3.1
- Returns a `.Future` and takes an optional callback. These are
- not strictly necessary as this method is synchronous,
- but they are supplied for consistency with
- `OAuthMixin.authorize_redirect`.
- """
- self.require_setting("facebook_api_key", "Facebook Connect")
- callback_uri = callback_uri or self.request.uri
- args = {
- "api_key": self.settings["facebook_api_key"],
- "v": "1.0",
- "fbconnect": "true",
- "display": "page",
- "next": urlparse.urljoin(self.request.full_url(), callback_uri),
- "return_session": "true",
- }
- if cancel_uri:
- args["cancel_url"] = urlparse.urljoin(
- self.request.full_url(), cancel_uri)
- if extended_permissions:
- if isinstance(extended_permissions, (unicode_type, bytes)):
- extended_permissions = [extended_permissions]
- args["req_perms"] = ",".join(extended_permissions)
- self.redirect("http://www.facebook.com/login.php?" +
- urllib_parse.urlencode(args))
- callback()
-
- def authorize_redirect(self, extended_permissions, callback_uri=None,
- cancel_uri=None, callback=None):
- """Redirects to an authorization request for the given FB resource.
-
- The available resource names are listed at
- http://wiki.developers.facebook.com/index.php/Extended_permission.
- The most common resource types include:
-
- * publish_stream
- * read_stream
- * email
- * sms
-
- extended_permissions can be a single permission name or a list of
- names. To get the session secret and session key, call
- get_authenticated_user() just as you would with
- authenticate_redirect().
-
- .. versionchanged:: 3.1
- Returns a `.Future` and takes an optional callback. These are
- not strictly necessary as this method is synchronous,
- but they are supplied for consistency with
- `OAuthMixin.authorize_redirect`.
- """
- return self.authenticate_redirect(callback_uri, cancel_uri,
- extended_permissions,
- callback=callback)
-
- def get_authenticated_user(self, callback):
- """Fetches the authenticated Facebook user.
-
- The authenticated user includes the special Facebook attributes
- 'session_key' and 'facebook_uid' in addition to the standard
- user attributes like 'name'.
- """
- self.require_setting("facebook_api_key", "Facebook Connect")
- session = escape.json_decode(self.get_argument("session"))
- self.facebook_request(
- method="facebook.users.getInfo",
- callback=functools.partial(
- self._on_get_user_info, callback, session),
- session_key=session["session_key"],
- uids=session["uid"],
- fields="uid,first_name,last_name,name,locale,pic_square,"
- "profile_url,username")
-
- def facebook_request(self, method, callback, **args):
- """Makes a Facebook API REST request.
-
- We automatically include the Facebook API key and signature, but
- it is the callers responsibility to include 'session_key' and any
- other required arguments to the method.
-
- The available Facebook methods are documented here:
- http://wiki.developers.facebook.com/index.php/API
-
- Here is an example for the stream.get() method::
-
- class MainHandler(tornado.web.RequestHandler,
- tornado.auth.FacebookMixin):
- @tornado.web.authenticated
- @tornado.web.asynchronous
- def get(self):
- self.facebook_request(
- method="stream.get",
- callback=self._on_stream,
- session_key=self.current_user["session_key"])
-
- def _on_stream(self, stream):
- if stream is None:
- # Not authorized to read the stream yet?
- self.redirect(self.authorize_redirect("read_stream"))
- return
- self.render("stream.html", stream=stream)
-
- """
- self.require_setting("facebook_api_key", "Facebook Connect")
- self.require_setting("facebook_secret", "Facebook Connect")
- if not method.startswith("facebook."):
- method = "facebook." + method
- args["api_key"] = self.settings["facebook_api_key"]
- args["v"] = "1.0"
- args["method"] = method
- args["call_id"] = str(long(time.time() * 1e6))
- args["format"] = "json"
- args["sig"] = self._signature(args)
- url = "http://api.facebook.com/restserver.php?" + \
- urllib_parse.urlencode(args)
- http = self.get_auth_http_client()
- http.fetch(url, callback=functools.partial(
- self._parse_response, callback))
-
- def _on_get_user_info(self, callback, session, users):
- if users is None:
- callback(None)
- return
- callback({
- "name": users[0]["name"],
- "first_name": users[0]["first_name"],
- "last_name": users[0]["last_name"],
- "uid": users[0]["uid"],
- "locale": users[0]["locale"],
- "pic_square": users[0]["pic_square"],
- "profile_url": users[0]["profile_url"],
- "username": users[0].get("username"),
- "session_key": session["session_key"],
- "session_expires": session.get("expires"),
- })
-
- def _parse_response(self, callback, response):
- if response.error:
- gen_log.warning("HTTP error from Facebook: %s", response.error)
- callback(None)
- return
- try:
- json = escape.json_decode(response.body)
- except Exception:
- gen_log.warning("Invalid JSON from Facebook: %r", response.body)
- callback(None)
- return
- if isinstance(json, dict) and json.get("error_code"):
- gen_log.warning("Facebook error: %d: %r", json["error_code"],
- json.get("error_msg"))
- callback(None)
- return
- callback(json)
-
- def _signature(self, args):
- parts = ["%s=%s" % (n, args[n]) for n in sorted(args.keys())]
- body = "".join(parts) + self.settings["facebook_secret"]
- if isinstance(body, unicode_type):
- body = body.encode("utf-8")
- return hashlib.md5(body).hexdigest()
-
- def get_auth_http_client(self):
- """Returns the `.AsyncHTTPClient` instance to be used for auth requests.
-
- May be overridden by subclasses to use an HTTP client other than
- the default.
- """
- return httpclient.AsyncHTTPClient()
-
-
class FacebookGraphMixin(OAuth2Mixin):
"""Facebook authentication using the new Graph API and OAuth2."""
_OAUTH_ACCESS_TOKEN_URL = "https://graph.facebook.com/oauth/access_token?"
@@ -1274,9 +866,12 @@ class FacebookGraphMixin(OAuth2Mixin):
code, callback, extra_fields=None):
"""Handles the login for the Facebook user, returning a user object.
- Example usage::
+ Example usage:
- class FacebookGraphLoginHandler(LoginHandler, tornado.auth.FacebookGraphMixin):
+ .. testcode::
+
+ class FacebookGraphLoginHandler(tornado.web.RequestHandler,
+ tornado.auth.FacebookGraphMixin):
@tornado.gen.coroutine
def get(self):
if self.get_argument("code", False):
@@ -1291,6 +886,10 @@ class FacebookGraphMixin(OAuth2Mixin):
redirect_uri='/auth/facebookgraph/',
client_id=self.settings["facebook_api_key"],
extra_params={"scope": "read_stream,offline_access"})
+
+ .. testoutput::
+ :hide:
+
"""
http = self.get_auth_http_client()
args = {
@@ -1307,7 +906,7 @@ class FacebookGraphMixin(OAuth2Mixin):
http.fetch(self._oauth_request_token_url(**args),
functools.partial(self._on_access_token, redirect_uri, client_id,
- client_secret, callback, fields))
+ client_secret, callback, fields))
def _on_access_token(self, redirect_uri, client_id, client_secret,
future, fields, response):
@@ -1358,7 +957,9 @@ class FacebookGraphMixin(OAuth2Mixin):
process includes an ``access_token`` attribute that can be
used to make authenticated requests via this method.
- Example usage::
+ Example usage:
+
+ ..testcode::
class MainHandler(tornado.web.RequestHandler,
tornado.auth.FacebookGraphMixin):
@@ -1376,6 +977,9 @@ class FacebookGraphMixin(OAuth2Mixin):
return
self.finish("Posted a message!")
+ .. testoutput::
+ :hide:
+
The given path is relative to ``self._FACEBOOK_BASE_URL``,
by default "https://graph.facebook.com".
diff --git a/tornado/autoreload.py b/tornado/autoreload.py
index 3982579a..a52ddde4 100644
--- a/tornado/autoreload.py
+++ b/tornado/autoreload.py
@@ -100,6 +100,14 @@ try:
except ImportError:
signal = None
+# os.execv is broken on Windows and can't properly parse command line
+# arguments and executable name if they contain whitespaces. subprocess
+# fixes that behavior.
+# This distinction is also important because when we use execv, we want to
+# close the IOLoop and all its file descriptors, to guard against any
+# file descriptors that were not set CLOEXEC. When execv is not available,
+# we must not close the IOLoop because we want the process to exit cleanly.
+_has_execv = sys.platform != 'win32'
_watched_files = set()
_reload_hooks = []
@@ -108,14 +116,19 @@ _io_loops = weakref.WeakKeyDictionary()
def start(io_loop=None, check_time=500):
- """Begins watching source files for changes using the given `.IOLoop`. """
+ """Begins watching source files for changes.
+
+ .. versionchanged:: 4.1
+ The ``io_loop`` argument is deprecated.
+ """
io_loop = io_loop or ioloop.IOLoop.current()
if io_loop in _io_loops:
return
_io_loops[io_loop] = True
if len(_io_loops) > 1:
gen_log.warning("tornado.autoreload started more than once in the same process")
- add_reload_hook(functools.partial(io_loop.close, all_fds=True))
+ if _has_execv:
+ add_reload_hook(functools.partial(io_loop.close, all_fds=True))
modify_times = {}
callback = functools.partial(_reload_on_update, modify_times)
scheduler = ioloop.PeriodicCallback(callback, check_time, io_loop=io_loop)
@@ -162,7 +175,7 @@ def _reload_on_update(modify_times):
# processes restarted themselves, they'd all restart and then
# all call fork_processes again.
return
- for module in sys.modules.values():
+ for module in list(sys.modules.values()):
# Some modules play games with sys.modules (e.g. email/__init__.py
# in the standard library), and occasionally this can cause strange
# failures in getattr. Just ignore anything that's not an ordinary
@@ -211,10 +224,7 @@ def _reload():
not os.environ.get("PYTHONPATH", "").startswith(path_prefix)):
os.environ["PYTHONPATH"] = (path_prefix +
os.environ.get("PYTHONPATH", ""))
- if sys.platform == 'win32':
- # os.execv is broken on Windows and can't properly parse command line
- # arguments and executable name if they contain whitespaces. subprocess
- # fixes that behavior.
+ if not _has_execv:
subprocess.Popen([sys.executable] + sys.argv)
sys.exit(0)
else:
@@ -234,7 +244,10 @@ def _reload():
# this error specifically.
os.spawnv(os.P_NOWAIT, sys.executable,
[sys.executable] + sys.argv)
- sys.exit(0)
+ # At this point the IOLoop has been closed and finally
+ # blocks will experience errors if we allow the stack to
+ # unwind, so just exit uncleanly.
+ os._exit(0)
_USAGE = """\
Usage:
diff --git a/tornado/concurrent.py b/tornado/concurrent.py
index 6bab5d2e..51ae239f 100644
--- a/tornado/concurrent.py
+++ b/tornado/concurrent.py
@@ -25,11 +25,13 @@ module.
from __future__ import absolute_import, division, print_function, with_statement
import functools
+import platform
+import traceback
import sys
+from tornado.log import app_log
from tornado.stack_context import ExceptionStackContext, wrap
from tornado.util import raise_exc_info, ArgReplacer
-from tornado.log import app_log
try:
from concurrent import futures
@@ -37,9 +39,90 @@ except ImportError:
futures = None
+# Can the garbage collector handle cycles that include __del__ methods?
+# This is true in cpython beginning with version 3.4 (PEP 442).
+_GC_CYCLE_FINALIZERS = (platform.python_implementation() == 'CPython' and
+ sys.version_info >= (3, 4))
+
+
class ReturnValueIgnoredError(Exception):
pass
+# This class and associated code in the future object is derived
+# from the Trollius project, a backport of asyncio to Python 2.x - 3.x
+
+
+class _TracebackLogger(object):
+ """Helper to log a traceback upon destruction if not cleared.
+
+ This solves a nasty problem with Futures and Tasks that have an
+ exception set: if nobody asks for the exception, the exception is
+ never logged. This violates the Zen of Python: 'Errors should
+ never pass silently. Unless explicitly silenced.'
+
+ However, we don't want to log the exception as soon as
+ set_exception() is called: if the calling code is written
+ properly, it will get the exception and handle it properly. But
+ we *do* want to log it if result() or exception() was never called
+ -- otherwise developers waste a lot of time wondering why their
+ buggy code fails silently.
+
+ An earlier attempt added a __del__() method to the Future class
+ itself, but this backfired because the presence of __del__()
+ prevents garbage collection from breaking cycles. A way out of
+ this catch-22 is to avoid having a __del__() method on the Future
+ class itself, but instead to have a reference to a helper object
+ with a __del__() method that logs the traceback, where we ensure
+ that the helper object doesn't participate in cycles, and only the
+ Future has a reference to it.
+
+ The helper object is added when set_exception() is called. When
+ the Future is collected, and the helper is present, the helper
+ object is also collected, and its __del__() method will log the
+ traceback. When the Future's result() or exception() method is
+ called (and a helper object is present), it removes the the helper
+ object, after calling its clear() method to prevent it from
+ logging.
+
+ One downside is that we do a fair amount of work to extract the
+ traceback from the exception, even when it is never logged. It
+ would seem cheaper to just store the exception object, but that
+ references the traceback, which references stack frames, which may
+ reference the Future, which references the _TracebackLogger, and
+ then the _TracebackLogger would be included in a cycle, which is
+ what we're trying to avoid! As an optimization, we don't
+ immediately format the exception; we only do the work when
+ activate() is called, which call is delayed until after all the
+ Future's callbacks have run. Since usually a Future has at least
+ one callback (typically set by 'yield From') and usually that
+ callback extracts the callback, thereby removing the need to
+ format the exception.
+
+ PS. I don't claim credit for this solution. I first heard of it
+ in a discussion about closing files when they are collected.
+ """
+
+ __slots__ = ('exc_info', 'formatted_tb')
+
+ def __init__(self, exc_info):
+ self.exc_info = exc_info
+ self.formatted_tb = None
+
+ def activate(self):
+ exc_info = self.exc_info
+ if exc_info is not None:
+ self.exc_info = None
+ self.formatted_tb = traceback.format_exception(*exc_info)
+
+ def clear(self):
+ self.exc_info = None
+ self.formatted_tb = None
+
+ def __del__(self):
+ if self.formatted_tb:
+ app_log.error('Future exception was never retrieved: %s',
+ ''.join(self.formatted_tb).rstrip())
+
class Future(object):
"""Placeholder for an asynchronous result.
@@ -68,12 +151,23 @@ class Future(object):
if that package was available and fall back to the thread-unsafe
implementation if it was not.
+ .. versionchanged:: 4.1
+ If a `.Future` contains an error but that error is never observed
+ (by calling ``result()``, ``exception()``, or ``exc_info()``),
+ a stack trace will be logged when the `.Future` is garbage collected.
+ This normally indicates an error in the application, but in cases
+ where it results in undesired logging it may be necessary to
+ suppress the logging by ensuring that the exception is observed:
+ ``f.add_done_callback(lambda f: f.exception())``.
"""
def __init__(self):
self._done = False
self._result = None
- self._exception = None
self._exc_info = None
+
+ self._log_traceback = False # Used for Python >= 3.4
+ self._tb_logger = None # Used for Python <= 3.3
+
self._callbacks = []
def cancel(self):
@@ -100,16 +194,21 @@ class Future(object):
"""Returns True if the future has finished running."""
return self._done
+ def _clear_tb_log(self):
+ self._log_traceback = False
+ if self._tb_logger is not None:
+ self._tb_logger.clear()
+ self._tb_logger = None
+
def result(self, timeout=None):
"""If the operation succeeded, return its result. If it failed,
re-raise its exception.
"""
+ self._clear_tb_log()
if self._result is not None:
return self._result
if self._exc_info is not None:
raise_exc_info(self._exc_info)
- elif self._exception is not None:
- raise self._exception
self._check_done()
return self._result
@@ -117,8 +216,9 @@ class Future(object):
"""If the operation raised an exception, return the `Exception`
object. Otherwise returns None.
"""
- if self._exception is not None:
- return self._exception
+ self._clear_tb_log()
+ if self._exc_info is not None:
+ return self._exc_info[1]
else:
self._check_done()
return None
@@ -147,14 +247,17 @@ class Future(object):
def set_exception(self, exception):
"""Sets the exception of a ``Future.``"""
- self._exception = exception
- self._set_done()
+ self.set_exc_info(
+ (exception.__class__,
+ exception,
+ getattr(exception, '__traceback__', None)))
def exc_info(self):
"""Returns a tuple in the same format as `sys.exc_info` or None.
.. versionadded:: 4.0
"""
+ self._clear_tb_log()
return self._exc_info
def set_exc_info(self, exc_info):
@@ -165,7 +268,18 @@ class Future(object):
.. versionadded:: 4.0
"""
self._exc_info = exc_info
- self.set_exception(exc_info[1])
+ self._log_traceback = True
+ if not _GC_CYCLE_FINALIZERS:
+ self._tb_logger = _TracebackLogger(exc_info)
+
+ try:
+ self._set_done()
+ finally:
+ # Activate the logger after all callbacks have had a
+ # chance to call result() or exception().
+ if self._log_traceback and self._tb_logger is not None:
+ self._tb_logger.activate()
+ self._exc_info = exc_info
def _check_done(self):
if not self._done:
@@ -177,10 +291,25 @@ class Future(object):
try:
cb(self)
except Exception:
- app_log.exception('exception calling callback %r for %r',
+ app_log.exception('Exception in callback %r for %r',
cb, self)
self._callbacks = None
+ # On Python 3.3 or older, objects with a destructor part of a reference
+ # cycle are never destroyed. It's no longer the case on Python 3.4 thanks to
+ # the PEP 442.
+ if _GC_CYCLE_FINALIZERS:
+ def __del__(self):
+ if not self._log_traceback:
+ # set_exception() was not called, or result() or exception()
+ # has consumed the exception
+ return
+
+ tb = traceback.format_exception(*self._exc_info)
+
+ app_log.error('Future %r exception was never retrieved: %s',
+ self, ''.join(tb).rstrip())
+
TracebackFuture = Future
if futures is None:
@@ -208,24 +337,42 @@ class DummyExecutor(object):
dummy_executor = DummyExecutor()
-def run_on_executor(fn):
+def run_on_executor(*args, **kwargs):
"""Decorator to run a synchronous method asynchronously on an executor.
The decorated method may be called with a ``callback`` keyword
argument and returns a future.
- This decorator should be used only on methods of objects with attributes
- ``executor`` and ``io_loop``.
+ The `.IOLoop` and executor to be used are determined by the ``io_loop``
+ and ``executor`` attributes of ``self``. To use different attributes,
+ pass keyword arguments to the decorator::
+
+ @run_on_executor(executor='_thread_pool')
+ def foo(self):
+ pass
+
+ .. versionchanged:: 4.2
+ Added keyword arguments to use alternative attributes.
"""
- @functools.wraps(fn)
- def wrapper(self, *args, **kwargs):
- callback = kwargs.pop("callback", None)
- future = self.executor.submit(fn, self, *args, **kwargs)
- if callback:
- self.io_loop.add_future(future,
- lambda future: callback(future.result()))
- return future
- return wrapper
+ def run_on_executor_decorator(fn):
+ executor = kwargs.get("executor", "executor")
+ io_loop = kwargs.get("io_loop", "io_loop")
+ @functools.wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ callback = kwargs.pop("callback", None)
+ future = getattr(self, executor).submit(fn, self, *args, **kwargs)
+ if callback:
+ getattr(self, io_loop).add_future(
+ future, lambda future: callback(future.result()))
+ return future
+ return wrapper
+ if args and kwargs:
+ raise ValueError("cannot combine positional and keyword args")
+ if len(args) == 1:
+ return run_on_executor_decorator(args[0])
+ elif len(args) != 0:
+ raise ValueError("expected 1 argument, got %d", len(args))
+ return run_on_executor_decorator
_NO_RESULT = object()
@@ -250,7 +397,9 @@ def return_future(f):
wait for the function to complete (perhaps by yielding it in a
`.gen.engine` function, or passing it to `.IOLoop.add_future`).
- Usage::
+ Usage:
+
+ .. testcode::
@return_future
def future_func(arg1, arg2, callback):
@@ -262,6 +411,8 @@ def return_future(f):
yield future_func(arg1, arg2)
callback()
+ ..
+
Note that ``@return_future`` and ``@gen.engine`` can be applied to the
same function, provided ``@return_future`` appears first. However,
consider using ``@gen.coroutine`` instead of this combination.
@@ -293,7 +444,7 @@ def return_future(f):
# If the initial synchronous part of f() raised an exception,
# go ahead and raise it to the caller directly without waiting
# for them to inspect the Future.
- raise_exc_info(exc_info)
+ future.result()
# If the caller passed in a callback, schedule it to be called
# when the future resolves. It is important that this happens
diff --git a/tornado/curl_httpclient.py b/tornado/curl_httpclient.py
index 68047cc9..ae6f114a 100644
--- a/tornado/curl_httpclient.py
+++ b/tornado/curl_httpclient.py
@@ -28,12 +28,13 @@ from io import BytesIO
from tornado import httputil
from tornado import ioloop
-from tornado.log import gen_log
from tornado import stack_context
from tornado.escape import utf8, native_str
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main
+curl_log = logging.getLogger('tornado.curl_httpclient')
+
class CurlAsyncHTTPClient(AsyncHTTPClient):
def initialize(self, io_loop, max_clients=10, defaults=None):
@@ -207,9 +208,25 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
"callback": callback,
"curl_start_time": time.time(),
}
- self._curl_setup_request(curl, request, curl.info["buffer"],
- curl.info["headers"])
- self._multi.add_handle(curl)
+ try:
+ self._curl_setup_request(
+ curl, request, curl.info["buffer"],
+ curl.info["headers"])
+ except Exception as e:
+ # If there was an error in setup, pass it on
+ # to the callback. Note that allowing the
+ # error to escape here will appear to work
+ # most of the time since we are still in the
+ # caller's original stack frame, but when
+ # _process_queue() is called from
+ # _finish_pending_requests the exceptions have
+ # nowhere to go.
+ callback(HTTPResponse(
+ request=request,
+ code=599,
+ error=e))
+ else:
+ self._multi.add_handle(curl)
if not started:
break
@@ -257,7 +274,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
def _curl_create(self):
curl = pycurl.Curl()
- if gen_log.isEnabledFor(logging.DEBUG):
+ if curl_log.isEnabledFor(logging.DEBUG):
curl.setopt(pycurl.VERBOSE, 1)
curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug)
return curl
@@ -286,10 +303,10 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
curl.setopt(pycurl.HEADERFUNCTION,
functools.partial(self._curl_header_callback,
- headers, request.header_callback))
+ headers, request.header_callback))
if request.streaming_callback:
- write_function = lambda chunk: self.io_loop.add_callback(
- request.streaming_callback, chunk)
+ def write_function(chunk):
+ self.io_loop.add_callback(request.streaming_callback, chunk)
else:
write_function = buffer.write
if bytes is str: # py2
@@ -381,6 +398,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
% request.method)
request_buffer = BytesIO(utf8(request.body))
+
def ioctl(cmd):
if cmd == curl.IOCMD_RESTARTREAD:
request_buffer.seek(0)
@@ -403,11 +421,11 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
raise ValueError("Unsupported auth_mode %s" % request.auth_mode)
curl.setopt(pycurl.USERPWD, native_str(userpwd))
- gen_log.debug("%s %s (username: %r)", request.method, request.url,
- request.auth_username)
+ curl_log.debug("%s %s (username: %r)", request.method, request.url,
+ request.auth_username)
else:
curl.unsetopt(pycurl.USERPWD)
- gen_log.debug("%s %s", request.method, request.url)
+ curl_log.debug("%s %s", request.method, request.url)
if request.client_cert is not None:
curl.setopt(pycurl.SSLCERT, request.client_cert)
@@ -415,6 +433,9 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
if request.client_key is not None:
curl.setopt(pycurl.SSLKEY, request.client_key)
+ if request.ssl_options is not None:
+ raise ValueError("ssl_options not supported in curl_httpclient")
+
if threading.activeCount() > 1:
# libcurl/pycurl is not thread-safe by default. When multiple threads
# are used, signals should be disabled. This has the side effect
@@ -448,12 +469,12 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
def _curl_debug(self, debug_type, debug_msg):
debug_types = ('I', '<', '>', '<', '>')
if debug_type == 0:
- gen_log.debug('%s', debug_msg.strip())
+ curl_log.debug('%s', debug_msg.strip())
elif debug_type in (1, 2):
for line in debug_msg.splitlines():
- gen_log.debug('%s %s', debug_types[debug_type], line)
+ curl_log.debug('%s %s', debug_types[debug_type], line)
elif debug_type == 4:
- gen_log.debug('%s %r', debug_types[debug_type], debug_msg)
+ curl_log.debug('%s %r', debug_types[debug_type], debug_msg)
class CurlError(HTTPError):
diff --git a/tornado/escape.py b/tornado/escape.py
index 24be2264..2852cf51 100644
--- a/tornado/escape.py
+++ b/tornado/escape.py
@@ -82,7 +82,7 @@ def json_encode(value):
# JSON permits but does not require forward slashes to be escaped.
# This is useful when json data is emitted in a tags from prematurely terminating
- # the javscript. Some json libraries do this escaping by default,
+ # the javascript. Some json libraries do this escaping by default,
# although python's standard library does not, so we do it here.
# http://stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped
return json.dumps(value).replace("", "<\\/")
diff --git a/tornado/gen.py b/tornado/gen.py
index 2fc9b0c7..4cc578eb 100644
--- a/tornado/gen.py
+++ b/tornado/gen.py
@@ -3,7 +3,9 @@ work in an asynchronous environment. Code using the ``gen`` module
is technically asynchronous, but it is written as a single generator
instead of a collection of separate functions.
-For example, the following asynchronous handler::
+For example, the following asynchronous handler:
+
+.. testcode::
class AsyncHandler(RequestHandler):
@asynchronous
@@ -16,7 +18,12 @@ For example, the following asynchronous handler::
do_something_with_response(response)
self.render("template.html")
-could be written with ``gen`` as::
+.. testoutput::
+ :hide:
+
+could be written with ``gen`` as:
+
+.. testcode::
class GenAsyncHandler(RequestHandler):
@gen.coroutine
@@ -26,12 +33,17 @@ could be written with ``gen`` as::
do_something_with_response(response)
self.render("template.html")
+.. testoutput::
+ :hide:
+
Most asynchronous functions in Tornado return a `.Future`;
yielding this object returns its `~.Future.result`.
You can also yield a list or dict of ``Futures``, which will be
started at the same time and run in parallel; a list or dict of results will
-be returned when they are all finished::
+be returned when they are all finished:
+
+.. testcode::
@gen.coroutine
def get(self):
@@ -43,8 +55,24 @@ be returned when they are all finished::
response3 = response_dict['response3']
response4 = response_dict['response4']
+.. testoutput::
+ :hide:
+
+If the `~functools.singledispatch` library is available (standard in
+Python 3.4, available via the `singledispatch
+`_ package on older
+versions), additional types of objects may be yielded. Tornado includes
+support for ``asyncio.Future`` and Twisted's ``Deferred`` class when
+``tornado.platform.asyncio`` and ``tornado.platform.twisted`` are imported.
+See the `convert_yielded` function to extend this mechanism.
+
.. versionchanged:: 3.2
Dict support added.
+
+.. versionchanged:: 4.1
+ Support added for yielding ``asyncio`` Futures and Twisted Deferreds
+ via ``singledispatch``.
+
"""
from __future__ import absolute_import, division, print_function, with_statement
@@ -53,10 +81,21 @@ import functools
import itertools
import sys
import types
+import weakref
from tornado.concurrent import Future, TracebackFuture, is_future, chain_future
from tornado.ioloop import IOLoop
+from tornado.log import app_log
from tornado import stack_context
+from tornado.util import raise_exc_info
+
+try:
+ from functools import singledispatch # py34+
+except ImportError as e:
+ try:
+ from singledispatch import singledispatch # backport
+ except ImportError:
+ singledispatch = None
class KeyReuseError(Exception):
@@ -101,9 +140,11 @@ def engine(func):
which use ``self.finish()`` in place of a callback argument.
"""
func = _make_coroutine_wrapper(func, replace_callback=False)
+
@functools.wraps(func)
def wrapper(*args, **kwargs):
future = func(*args, **kwargs)
+
def final_callback(future):
if future.result() is not None:
raise ReturnValueIgnoredError(
@@ -241,6 +282,113 @@ class Return(Exception):
self.value = value
+class WaitIterator(object):
+ """Provides an iterator to yield the results of futures as they finish.
+
+ Yielding a set of futures like this:
+
+ ``results = yield [future1, future2]``
+
+ pauses the coroutine until both ``future1`` and ``future2``
+ return, and then restarts the coroutine with the results of both
+ futures. If either future is an exception, the expression will
+ raise that exception and all the results will be lost.
+
+ If you need to get the result of each future as soon as possible,
+ or if you need the result of some futures even if others produce
+ errors, you can use ``WaitIterator``::
+
+ wait_iterator = gen.WaitIterator(future1, future2)
+ while not wait_iterator.done():
+ try:
+ result = yield wait_iterator.next()
+ except Exception as e:
+ print("Error {} from {}".format(e, wait_iterator.current_future))
+ else:
+ print("Result {} received from {} at {}".format(
+ result, wait_iterator.current_future,
+ wait_iterator.current_index))
+
+ Because results are returned as soon as they are available the
+ output from the iterator *will not be in the same order as the
+ input arguments*. If you need to know which future produced the
+ current result, you can use the attributes
+ ``WaitIterator.current_future``, or ``WaitIterator.current_index``
+ to get the index of the future from the input list. (if keyword
+ arguments were used in the construction of the `WaitIterator`,
+ ``current_index`` will use the corresponding keyword).
+
+ .. versionadded:: 4.1
+ """
+ def __init__(self, *args, **kwargs):
+ if args and kwargs:
+ raise ValueError(
+ "You must provide args or kwargs, not both")
+
+ if kwargs:
+ self._unfinished = dict((f, k) for (k, f) in kwargs.items())
+ futures = list(kwargs.values())
+ else:
+ self._unfinished = dict((f, i) for (i, f) in enumerate(args))
+ futures = args
+
+ self._finished = collections.deque()
+ self.current_index = self.current_future = None
+ self._running_future = None
+
+ # Use a weak reference to self to avoid cycles that may delay
+ # garbage collection.
+ self_ref = weakref.ref(self)
+ for future in futures:
+ future.add_done_callback(functools.partial(
+ self._done_callback, self_ref))
+
+ def done(self):
+ """Returns True if this iterator has no more results."""
+ if self._finished or self._unfinished:
+ return False
+ # Clear the 'current' values when iteration is done.
+ self.current_index = self.current_future = None
+ return True
+
+ def next(self):
+ """Returns a `.Future` that will yield the next available result.
+
+ Note that this `.Future` will not be the same object as any of
+ the inputs.
+ """
+ self._running_future = TracebackFuture()
+ # As long as there is an active _running_future, we must
+ # ensure that the WaitIterator is not GC'd (due to the
+ # use of weak references in __init__). Add a callback that
+ # references self so there is a hard reference that will be
+ # cleared automatically when this Future finishes.
+ self._running_future.add_done_callback(lambda f: self)
+
+ if self._finished:
+ self._return_result(self._finished.popleft())
+
+ return self._running_future
+
+ @staticmethod
+ def _done_callback(self_ref, done):
+ self = self_ref()
+ if self is not None:
+ if self._running_future and not self._running_future.done():
+ self._return_result(done)
+ else:
+ self._finished.append(done)
+
+ def _return_result(self, done):
+ """Called set the returned future's state that of the future
+ we yielded, and set the current future for the iterator.
+ """
+ chain_future(done, self._running_future)
+
+ self.current_future = done
+ self.current_index = self._unfinished.pop(done)
+
+
class YieldPoint(object):
"""Base class for objects that may be yielded from the generator.
@@ -355,11 +503,13 @@ def Task(func, *args, **kwargs):
yielded.
"""
future = Future()
+
def handle_exception(typ, value, tb):
if future.done():
return False
future.set_exc_info((typ, value, tb))
return True
+
def set_result(result):
if future.done():
return
@@ -371,6 +521,11 @@ def Task(func, *args, **kwargs):
class YieldFuture(YieldPoint):
def __init__(self, future, io_loop=None):
+ """Adapts a `.Future` to the `YieldPoint` interface.
+
+ .. versionchanged:: 4.1
+ The ``io_loop`` argument is deprecated.
+ """
self.future = future
self.io_loop = io_loop or IOLoop.current()
@@ -382,7 +537,7 @@ class YieldFuture(YieldPoint):
self.io_loop.add_future(self.future, runner.result_callback(self.key))
else:
self.runner = None
- self.result = self.future.result()
+ self.result_fn = self.future.result
def is_ready(self):
if self.runner is not None:
@@ -394,7 +549,7 @@ class YieldFuture(YieldPoint):
if self.runner is not None:
return self.runner.pop_result(self.key).result()
else:
- return self.result
+ return self.result_fn()
class Multi(YieldPoint):
@@ -408,8 +563,18 @@ class Multi(YieldPoint):
Instead of a list, the argument may also be a dictionary whose values are
Futures, in which case a parallel dictionary is returned mapping the same
keys to their results.
+
+ It is not normally necessary to call this class directly, as it
+ will be created automatically as needed. However, calling it directly
+ allows you to use the ``quiet_exceptions`` argument to control
+ the logging of multiple exceptions.
+
+ .. versionchanged:: 4.2
+ If multiple ``YieldPoints`` fail, any exceptions after the first
+ (which is raised) will be logged. Added the ``quiet_exceptions``
+ argument to suppress this logging for selected exception types.
"""
- def __init__(self, children):
+ def __init__(self, children, quiet_exceptions=()):
self.keys = None
if isinstance(children, dict):
self.keys = list(children.keys())
@@ -421,6 +586,7 @@ class Multi(YieldPoint):
self.children.append(i)
assert all(isinstance(i, YieldPoint) for i in self.children)
self.unfinished_children = set(self.children)
+ self.quiet_exceptions = quiet_exceptions
def start(self, runner):
for i in self.children:
@@ -433,14 +599,27 @@ class Multi(YieldPoint):
return not self.unfinished_children
def get_result(self):
- result = (i.get_result() for i in self.children)
+ result_list = []
+ exc_info = None
+ for f in self.children:
+ try:
+ result_list.append(f.get_result())
+ except Exception as e:
+ if exc_info is None:
+ exc_info = sys.exc_info()
+ else:
+ if not isinstance(e, self.quiet_exceptions):
+ app_log.error("Multiple exceptions in yield list",
+ exc_info=True)
+ if exc_info is not None:
+ raise_exc_info(exc_info)
if self.keys is not None:
- return dict(zip(self.keys, result))
+ return dict(zip(self.keys, result_list))
else:
- return list(result)
+ return list(result_list)
-def multi_future(children):
+def multi_future(children, quiet_exceptions=()):
"""Wait for multiple asynchronous futures in parallel.
Takes a list of ``Futures`` (but *not* other ``YieldPoints``) and returns
@@ -453,12 +632,21 @@ def multi_future(children):
Futures, in which case a parallel dictionary is returned mapping the same
keys to their results.
- It is not necessary to call `multi_future` explcitly, since the engine will
- do so automatically when the generator yields a list of `Futures`.
- This function is faster than the `Multi` `YieldPoint` because it does not
- require the creation of a stack context.
+ It is not normally necessary to call `multi_future` explcitly,
+ since the engine will do so automatically when the generator
+ yields a list of ``Futures``. However, calling it directly
+ allows you to use the ``quiet_exceptions`` argument to control
+ the logging of multiple exceptions.
+
+ This function is faster than the `Multi` `YieldPoint` because it
+ does not require the creation of a stack context.
.. versionadded:: 4.0
+
+ .. versionchanged:: 4.2
+ If multiple ``Futures`` fail, any exceptions after the first (which is
+ raised) will be logged. Added the ``quiet_exceptions``
+ argument to suppress this logging for selected exception types.
"""
if isinstance(children, dict):
keys = list(children.keys())
@@ -471,20 +659,32 @@ def multi_future(children):
future = Future()
if not children:
future.set_result({} if keys is not None else [])
+
def callback(f):
unfinished_children.remove(f)
if not unfinished_children:
- try:
- result_list = [i.result() for i in children]
- except Exception:
- future.set_exc_info(sys.exc_info())
- else:
+ result_list = []
+ for f in children:
+ try:
+ result_list.append(f.result())
+ except Exception as e:
+ if future.done():
+ if not isinstance(e, quiet_exceptions):
+ app_log.error("Multiple exceptions in yield list",
+ exc_info=True)
+ else:
+ future.set_exc_info(sys.exc_info())
+ if not future.done():
if keys is not None:
future.set_result(dict(zip(keys, result_list)))
else:
future.set_result(result_list)
+
+ listening = set()
for f in children:
- f.add_done_callback(callback)
+ if f not in listening:
+ listening.add(f)
+ f.add_done_callback(callback)
return future
@@ -504,7 +704,7 @@ def maybe_future(x):
return fut
-def with_timeout(timeout, future, io_loop=None):
+def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()):
"""Wraps a `.Future` in a timeout.
Raises `TimeoutError` if the input future does not complete before
@@ -512,9 +712,17 @@ def with_timeout(timeout, future, io_loop=None):
`.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time
relative to `.IOLoop.time`)
+ If the wrapped `.Future` fails after it has timed out, the exception
+ will be logged unless it is of a type contained in ``quiet_exceptions``
+ (which may be an exception type or a sequence of types).
+
Currently only supports Futures, not other `YieldPoint` classes.
.. versionadded:: 4.0
+
+ .. versionchanged:: 4.1
+ Added the ``quiet_exceptions`` argument and the logging of unhandled
+ exceptions.
"""
# TODO: allow yield points in addition to futures?
# Tricky to do with stack_context semantics.
@@ -528,9 +736,21 @@ def with_timeout(timeout, future, io_loop=None):
chain_future(future, result)
if io_loop is None:
io_loop = IOLoop.current()
+
+ def error_callback(future):
+ try:
+ future.result()
+ except Exception as e:
+ if not isinstance(e, quiet_exceptions):
+ app_log.error("Exception in Future %r after timeout",
+ future, exc_info=True)
+
+ def timeout_callback():
+ result.set_exception(TimeoutError("Timeout"))
+ # In case the wrapped future goes on to fail, log it.
+ future.add_done_callback(error_callback)
timeout_handle = io_loop.add_timeout(
- timeout,
- lambda: result.set_exception(TimeoutError("Timeout")))
+ timeout, timeout_callback)
if isinstance(future, Future):
# We know this future will resolve on the IOLoop, so we don't
# need the extra thread-safety of IOLoop.add_future (and we also
@@ -545,6 +765,25 @@ def with_timeout(timeout, future, io_loop=None):
return result
+def sleep(duration):
+ """Return a `.Future` that resolves after the given number of seconds.
+
+ When used with ``yield`` in a coroutine, this is a non-blocking
+ analogue to `time.sleep` (which should not be used in coroutines
+ because it is blocking)::
+
+ yield gen.sleep(0.5)
+
+ Note that calling this function on its own does nothing; you must
+ wait on the `.Future` it returns (usually by yielding it).
+
+ .. versionadded:: 4.1
+ """
+ f = Future()
+ IOLoop.current().call_later(duration, lambda: f.set_result(None))
+ return f
+
+
_null_future = Future()
_null_future.set_result(None)
@@ -638,13 +877,20 @@ class Runner(object):
self.future = None
try:
orig_stack_contexts = stack_context._state.contexts
+ exc_info = None
+
try:
value = future.result()
except Exception:
self.had_exception = True
- yielded = self.gen.throw(*sys.exc_info())
+ exc_info = sys.exc_info()
+
+ if exc_info is not None:
+ yielded = self.gen.throw(*exc_info)
+ exc_info = None
else:
yielded = self.gen.send(value)
+
if stack_context._state.contexts is not orig_stack_contexts:
self.gen.throw(
stack_context.StackContextInconsistentError(
@@ -678,19 +924,20 @@ class Runner(object):
self.running = False
def handle_yield(self, yielded):
- if isinstance(yielded, list):
- if all(is_future(f) for f in yielded):
- yielded = multi_future(yielded)
- else:
- yielded = Multi(yielded)
- elif isinstance(yielded, dict):
- if all(is_future(f) for f in yielded.values()):
- yielded = multi_future(yielded)
- else:
- yielded = Multi(yielded)
+ # Lists containing YieldPoints require stack contexts;
+ # other lists are handled via multi_future in convert_yielded.
+ if (isinstance(yielded, list) and
+ any(isinstance(f, YieldPoint) for f in yielded)):
+ yielded = Multi(yielded)
+ elif (isinstance(yielded, dict) and
+ any(isinstance(f, YieldPoint) for f in yielded.values())):
+ yielded = Multi(yielded)
if isinstance(yielded, YieldPoint):
+ # YieldPoints are too closely coupled to the Runner to go
+ # through the generic convert_yielded mechanism.
self.future = TracebackFuture()
+
def start_yield_point():
try:
yielded.start(self)
@@ -702,12 +949,14 @@ class Runner(object):
except Exception:
self.future = TracebackFuture()
self.future.set_exc_info(sys.exc_info())
+
if self.stack_context_deactivate is None:
# Start a stack context if this is the first
# YieldPoint we've seen.
with stack_context.ExceptionStackContext(
self.handle_exception) as deactivate:
self.stack_context_deactivate = deactivate
+
def cb():
start_yield_point()
self.run()
@@ -715,16 +964,17 @@ class Runner(object):
return False
else:
start_yield_point()
- elif is_future(yielded):
- self.future = yielded
- if not self.future.done() or self.future is moment:
- self.io_loop.add_future(
- self.future, lambda f: self.run())
- return False
else:
- self.future = TracebackFuture()
- self.future.set_exception(BadYieldError(
- "yielded unknown object %r" % (yielded,)))
+ try:
+ self.future = convert_yielded(yielded)
+ except BadYieldError:
+ self.future = TracebackFuture()
+ self.future.set_exc_info(sys.exc_info())
+
+ if not self.future.done() or self.future is moment:
+ self.io_loop.add_future(
+ self.future, lambda f: self.run())
+ return False
return True
def result_callback(self, key):
@@ -763,3 +1013,30 @@ def _argument_adapter(callback):
else:
callback(None)
return wrapper
+
+
+def convert_yielded(yielded):
+ """Convert a yielded object into a `.Future`.
+
+ The default implementation accepts lists, dictionaries, and Futures.
+
+ If the `~functools.singledispatch` library is available, this function
+ may be extended to support additional types. For example::
+
+ @convert_yielded.register(asyncio.Future)
+ def _(asyncio_future):
+ return tornado.platform.asyncio.to_tornado_future(asyncio_future)
+
+ .. versionadded:: 4.1
+ """
+ # Lists and dicts containing YieldPoints were handled separately
+ # via Multi().
+ if isinstance(yielded, (list, dict)):
+ return multi_future(yielded)
+ elif is_future(yielded):
+ return yielded
+ else:
+ raise BadYieldError("yielded unknown object %r" % (yielded,))
+
+if singledispatch is not None:
+ convert_yielded = singledispatch(convert_yielded)
diff --git a/tornado/http1connection.py b/tornado/http1connection.py
index 8a5f46c4..6226ef7a 100644
--- a/tornado/http1connection.py
+++ b/tornado/http1connection.py
@@ -37,6 +37,7 @@ class _QuietException(Exception):
def __init__(self):
pass
+
class _ExceptionLoggingContext(object):
"""Used with the ``with`` statement when calling delegate methods to
log any exceptions with the given logger. Any exceptions caught are
@@ -53,6 +54,7 @@ class _ExceptionLoggingContext(object):
self.logger.error("Uncaught exception", exc_info=(typ, value, tb))
raise _QuietException
+
class HTTP1ConnectionParameters(object):
"""Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`.
"""
@@ -162,7 +164,8 @@ class HTTP1Connection(httputil.HTTPConnection):
header_data = yield gen.with_timeout(
self.stream.io_loop.time() + self.params.header_timeout,
header_future,
- io_loop=self.stream.io_loop)
+ io_loop=self.stream.io_loop,
+ quiet_exceptions=iostream.StreamClosedError)
except gen.TimeoutError:
self.close()
raise gen.Return(False)
@@ -201,7 +204,7 @@ class HTTP1Connection(httputil.HTTPConnection):
# 1xx responses should never indicate the presence of
# a body.
if ('Content-Length' in headers or
- 'Transfer-Encoding' in headers):
+ 'Transfer-Encoding' in headers):
raise httputil.HTTPInputError(
"Response code %d cannot have body" % code)
# TODO: client delegates will get headers_received twice
@@ -221,7 +224,8 @@ class HTTP1Connection(httputil.HTTPConnection):
try:
yield gen.with_timeout(
self.stream.io_loop.time() + self._body_timeout,
- body_future, self.stream.io_loop)
+ body_future, self.stream.io_loop,
+ quiet_exceptions=iostream.StreamClosedError)
except gen.TimeoutError:
gen_log.info("Timeout reading body from %s",
self.context)
@@ -326,8 +330,10 @@ class HTTP1Connection(httputil.HTTPConnection):
def write_headers(self, start_line, headers, chunk=None, callback=None):
"""Implements `.HTTPConnection.write_headers`."""
+ lines = []
if self.is_client:
self._request_start_line = start_line
+ lines.append(utf8('%s %s HTTP/1.1' % (start_line[0], start_line[1])))
# Client requests with a non-empty body must have either a
# Content-Length or a Transfer-Encoding.
self._chunking_output = (
@@ -336,6 +342,7 @@ class HTTP1Connection(httputil.HTTPConnection):
'Transfer-Encoding' not in headers)
else:
self._response_start_line = start_line
+ lines.append(utf8('HTTP/1.1 %s %s' % (start_line[1], start_line[2])))
self._chunking_output = (
# TODO: should this use
# self._request_start_line.version or
@@ -365,7 +372,6 @@ class HTTP1Connection(httputil.HTTPConnection):
self._expected_content_remaining = int(headers['Content-Length'])
else:
self._expected_content_remaining = None
- lines = [utf8("%s %s %s" % start_line)]
lines.extend([utf8(n) + b": " + utf8(v) for n, v in headers.get_all()])
for line in lines:
if b'\n' in line:
@@ -374,6 +380,7 @@ class HTTP1Connection(httputil.HTTPConnection):
if self.stream.closed():
future = self._write_future = Future()
future.set_exception(iostream.StreamClosedError())
+ future.exception()
else:
if callback is not None:
self._write_callback = stack_context.wrap(callback)
@@ -412,6 +419,7 @@ class HTTP1Connection(httputil.HTTPConnection):
if self.stream.closed():
future = self._write_future = Future()
self._write_future.set_exception(iostream.StreamClosedError())
+ self._write_future.exception()
else:
if callback is not None:
self._write_callback = stack_context.wrap(callback)
@@ -451,6 +459,9 @@ class HTTP1Connection(httputil.HTTPConnection):
self._pending_write.add_done_callback(self._finish_request)
def _on_write_complete(self, future):
+ exc = future.exception()
+ if exc is not None and not isinstance(exc, iostream.StreamClosedError):
+ future.result()
if self._write_callback is not None:
callback = self._write_callback
self._write_callback = None
@@ -491,8 +502,9 @@ class HTTP1Connection(httputil.HTTPConnection):
# we SHOULD ignore at least one empty line before the request.
# http://tools.ietf.org/html/rfc7230#section-3.5
data = native_str(data.decode('latin1')).lstrip("\r\n")
- eol = data.find("\r\n")
- start_line = data[:eol]
+ # RFC 7230 section allows for both CRLF and bare LF.
+ eol = data.find("\n")
+ start_line = data[:eol].rstrip("\r")
try:
headers = httputil.HTTPHeaders.parse(data[eol:])
except ValueError:
@@ -686,9 +698,8 @@ class HTTP1ServerConnection(object):
# This exception was already logged.
conn.close()
return
- except Exception as e:
- if 1 != e.errno:
- gen_log.error("Uncaught exception", exc_info=True)
+ except Exception:
+ gen_log.error("Uncaught exception", exc_info=True)
conn.close()
return
if not ret:
diff --git a/tornado/httpclient.py b/tornado/httpclient.py
index df429517..c2e68623 100644
--- a/tornado/httpclient.py
+++ b/tornado/httpclient.py
@@ -72,7 +72,7 @@ class HTTPClient(object):
http_client.close()
"""
def __init__(self, async_client_class=None, **kwargs):
- self._io_loop = IOLoop()
+ self._io_loop = IOLoop(make_current=False)
if async_client_class is None:
async_client_class = AsyncHTTPClient
self._async_client = async_client_class(self._io_loop, **kwargs)
@@ -95,11 +95,11 @@ class HTTPClient(object):
If it is a string, we construct an `HTTPRequest` using any additional
kwargs: ``HTTPRequest(request, **kwargs)``
- If an error occurs during the fetch, we raise an `HTTPError`.
+ If an error occurs during the fetch, we raise an `HTTPError` unless
+ the ``raise_error`` keyword argument is set to False.
"""
response = self._io_loop.run_sync(functools.partial(
self._async_client.fetch, request, **kwargs))
- response.rethrow()
return response
@@ -136,6 +136,9 @@ class AsyncHTTPClient(Configurable):
# or with force_instance:
client = AsyncHTTPClient(force_instance=True,
defaults=dict(user_agent="MyUserAgent"))
+
+ .. versionchanged:: 4.1
+ The ``io_loop`` argument is deprecated.
"""
@classmethod
def configurable_base(cls):
@@ -200,7 +203,7 @@ class AsyncHTTPClient(Configurable):
raise RuntimeError("inconsistent AsyncHTTPClient cache")
del self._instance_cache[self.io_loop]
- def fetch(self, request, callback=None, **kwargs):
+ def fetch(self, request, callback=None, raise_error=True, **kwargs):
"""Executes a request, asynchronously returning an `HTTPResponse`.
The request may be either a string URL or an `HTTPRequest` object.
@@ -208,8 +211,10 @@ class AsyncHTTPClient(Configurable):
kwargs: ``HTTPRequest(request, **kwargs)``
This method returns a `.Future` whose result is an
- `HTTPResponse`. The ``Future`` will raise an `HTTPError` if
- the request returned a non-200 response code.
+ `HTTPResponse`. By default, the ``Future`` will raise an `HTTPError`
+ if the request returned a non-200 response code. Instead, if
+ ``raise_error`` is set to False, the response will always be
+ returned regardless of the response code.
If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
In the callback interface, `HTTPError` is not automatically raised.
@@ -243,7 +248,7 @@ class AsyncHTTPClient(Configurable):
future.add_done_callback(handle_future)
def handle_response(response):
- if response.error:
+ if raise_error and response.error:
future.set_exception(response.error)
else:
future.set_result(response)
@@ -304,7 +309,8 @@ class HTTPRequest(object):
validate_cert=None, ca_certs=None,
allow_ipv6=None,
client_key=None, client_cert=None, body_producer=None,
- expect_100_continue=False, decompress_response=None):
+ expect_100_continue=False, decompress_response=None,
+ ssl_options=None):
r"""All parameters except ``url`` are optional.
:arg string url: URL to fetch
@@ -374,12 +380,15 @@ class HTTPRequest(object):
:arg string ca_certs: filename of CA certificates in PEM format,
or None to use defaults. See note below when used with
``curl_httpclient``.
- :arg bool allow_ipv6: Use IPv6 when available? Default is false in
- ``simple_httpclient`` and true in ``curl_httpclient``
:arg string client_key: Filename for client SSL key, if any. See
note below when used with ``curl_httpclient``.
:arg string client_cert: Filename for client SSL certificate, if any.
See note below when used with ``curl_httpclient``.
+ :arg ssl.SSLContext ssl_options: `ssl.SSLContext` object for use in
+ ``simple_httpclient`` (unsupported by ``curl_httpclient``).
+ Overrides ``validate_cert``, ``ca_certs``, ``client_key``,
+ and ``client_cert``.
+ :arg bool allow_ipv6: Use IPv6 when available? Default is true.
:arg bool expect_100_continue: If true, send the
``Expect: 100-continue`` header and wait for a continue response
before sending the request body. Only supported with
@@ -402,6 +411,9 @@ class HTTPRequest(object):
.. versionadded:: 4.0
The ``body_producer`` and ``expect_100_continue`` arguments.
+
+ .. versionadded:: 4.2
+ The ``ssl_options`` argument.
"""
# Note that some of these attributes go through property setters
# defined below.
@@ -439,6 +451,7 @@ class HTTPRequest(object):
self.allow_ipv6 = allow_ipv6
self.client_key = client_key
self.client_cert = client_cert
+ self.ssl_options = ssl_options
self.expect_100_continue = expect_100_continue
self.start_time = time.time()
diff --git a/tornado/httpserver.py b/tornado/httpserver.py
index 05d0e186..2dd04dd7 100644
--- a/tornado/httpserver.py
+++ b/tornado/httpserver.py
@@ -37,35 +37,17 @@ from tornado import httputil
from tornado import iostream
from tornado import netutil
from tornado.tcpserver import TCPServer
+from tornado.util import Configurable
-class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
+class HTTPServer(TCPServer, Configurable,
+ httputil.HTTPServerConnectionDelegate):
r"""A non-blocking, single-threaded HTTP server.
- A server is defined by either a request callback that takes a
- `.HTTPServerRequest` as an argument or a `.HTTPServerConnectionDelegate`
- instance.
-
- A simple example server that echoes back the URI you requested::
-
- import tornado.httpserver
- import tornado.ioloop
- from tornado import httputil
-
- def handle_request(request):
- message = "You requested %s\n" % request.uri
- request.connection.write_headers(
- httputil.ResponseStartLine('HTTP/1.1', 200, 'OK'),
- {"Content-Length": str(len(message))})
- request.connection.write(message)
- request.connection.finish()
-
- http_server = tornado.httpserver.HTTPServer(handle_request)
- http_server.listen(8888)
- tornado.ioloop.IOLoop.instance().start()
-
- Applications should use the methods of `.HTTPConnection` to write
- their response.
+ A server is defined by a subclass of `.HTTPServerConnectionDelegate`,
+ or, for backwards compatibility, a callback that takes an
+ `.HTTPServerRequest` as an argument. The delegate is usually a
+ `tornado.web.Application`.
`HTTPServer` supports keep-alive connections by default
(automatically for HTTP/1.1, or for HTTP/1.0 when the client
@@ -80,15 +62,15 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
if Tornado is run behind an SSL-decoding proxy that does not set one of
the supported ``xheaders``.
- To make this server serve SSL traffic, send the ``ssl_options`` dictionary
- argument with the arguments required for the `ssl.wrap_socket` method,
- including ``certfile`` and ``keyfile``. (In Python 3.2+ you can pass
- an `ssl.SSLContext` object instead of a dict)::
+ To make this server serve SSL traffic, send the ``ssl_options`` keyword
+ argument with an `ssl.SSLContext` object. For compatibility with older
+ versions of Python ``ssl_options`` may also be a dictionary of keyword
+ arguments for the `ssl.wrap_socket` method.::
- HTTPServer(applicaton, ssl_options={
- "certfile": os.path.join(data_dir, "mydomain.crt"),
- "keyfile": os.path.join(data_dir, "mydomain.key"),
- })
+ ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+ ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"),
+ os.path.join(data_dir, "mydomain.key"))
+ HTTPServer(applicaton, ssl_options=ssl_ctx)
`HTTPServer` initialization follows one of three patterns (the
initialization methods are defined on `tornado.tcpserver.TCPServer`):
@@ -97,7 +79,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
server = HTTPServer(app)
server.listen(8888)
- IOLoop.instance().start()
+ IOLoop.current().start()
In many cases, `tornado.web.Application.listen` can be used to avoid
the need to explicitly create the `HTTPServer`.
@@ -108,7 +90,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
server = HTTPServer(app)
server.bind(8888)
server.start(0) # Forks multiple sub-processes
- IOLoop.instance().start()
+ IOLoop.current().start()
When using this interface, an `.IOLoop` must *not* be passed
to the `HTTPServer` constructor. `~.TCPServer.start` will always start
@@ -120,7 +102,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
tornado.process.fork_processes(0)
server = HTTPServer(app)
server.add_sockets(sockets)
- IOLoop.instance().start()
+ IOLoop.current().start()
The `~.TCPServer.add_sockets` interface is more complicated,
but it can be used with `tornado.process.fork_processes` to
@@ -134,13 +116,29 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
``idle_connection_timeout``, ``body_timeout``, ``max_body_size``
arguments. Added support for `.HTTPServerConnectionDelegate`
instances as ``request_callback``.
+
+ .. versionchanged:: 4.1
+ `.HTTPServerConnectionDelegate.start_request` is now called with
+ two arguments ``(server_conn, request_conn)`` (in accordance with the
+ documentation) instead of one ``(request_conn)``.
+
+ .. versionchanged:: 4.2
+ `HTTPServer` is now a subclass of `tornado.util.Configurable`.
"""
- def __init__(self, request_callback, no_keep_alive=False, io_loop=None,
- xheaders=False, ssl_options=None, protocol=None,
- decompress_request=False,
- chunk_size=None, max_header_size=None,
- idle_connection_timeout=None, body_timeout=None,
- max_body_size=None, max_buffer_size=None):
+ def __init__(self, *args, **kwargs):
+ # Ignore args to __init__; real initialization belongs in
+ # initialize since we're Configurable. (there's something
+ # weird in initialization order between this class,
+ # Configurable, and TCPServer so we can't leave __init__ out
+ # completely)
+ pass
+
+ def initialize(self, request_callback, no_keep_alive=False, io_loop=None,
+ xheaders=False, ssl_options=None, protocol=None,
+ decompress_request=False,
+ chunk_size=None, max_header_size=None,
+ idle_connection_timeout=None, body_timeout=None,
+ max_body_size=None, max_buffer_size=None):
self.request_callback = request_callback
self.no_keep_alive = no_keep_alive
self.xheaders = xheaders
@@ -157,6 +155,14 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
read_chunk_size=chunk_size)
self._connections = set()
+ @classmethod
+ def configurable_base(cls):
+ return HTTPServer
+
+ @classmethod
+ def configurable_default(cls):
+ return HTTPServer
+
@gen.coroutine
def close_all_connections(self):
while self._connections:
@@ -173,7 +179,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
conn.start_serving(self)
def start_request(self, server_conn, request_conn):
- return _ServerRequestAdapter(self, request_conn)
+ return _ServerRequestAdapter(self, server_conn, request_conn)
def on_close(self, server_conn):
self._connections.remove(server_conn)
@@ -246,13 +252,14 @@ class _ServerRequestAdapter(httputil.HTTPMessageDelegate):
"""Adapts the `HTTPMessageDelegate` interface to the interface expected
by our clients.
"""
- def __init__(self, server, connection):
+ def __init__(self, server, server_conn, request_conn):
self.server = server
- self.connection = connection
+ self.connection = request_conn
self.request = None
if isinstance(server.request_callback,
httputil.HTTPServerConnectionDelegate):
- self.delegate = server.request_callback.start_request(connection)
+ self.delegate = server.request_callback.start_request(
+ server_conn, request_conn)
self._chunks = None
else:
self.delegate = None
diff --git a/tornado/httputil.py b/tornado/httputil.py
index f5c9c04f..f5fea213 100644
--- a/tornado/httputil.py
+++ b/tornado/httputil.py
@@ -62,6 +62,11 @@ except ImportError:
pass
+# RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line
+# terminator and ignore any preceding CR.
+_CRLF_RE = re.compile(r'\r?\n')
+
+
class _NormalizedHeaderCache(dict):
"""Dynamic cached mapping of header names to Http-Header-Case.
@@ -193,7 +198,7 @@ class HTTPHeaders(dict):
[('Content-Length', '42'), ('Content-Type', 'text/html')]
"""
h = cls()
- for line in headers.splitlines():
+ for line in _CRLF_RE.split(headers):
if line:
h.parse_line(line)
return h
@@ -229,6 +234,14 @@ class HTTPHeaders(dict):
# default implementation returns dict(self), not the subclass
return HTTPHeaders(self)
+ # Use our overridden copy method for the copy.copy module.
+ __copy__ = copy
+
+ def __deepcopy__(self, memo_dict):
+ # Our values are immutable strings, so our standard copy is
+ # effectively a deep copy.
+ return self.copy()
+
class HTTPServerRequest(object):
"""A single HTTP request.
@@ -331,7 +344,7 @@ class HTTPServerRequest(object):
self.uri = uri
self.version = version
self.headers = headers or HTTPHeaders()
- self.body = body or ""
+ self.body = body or b""
# set remote IP and protocol
context = getattr(connection, 'context', None)
@@ -380,6 +393,8 @@ class HTTPServerRequest(object):
to write the response.
"""
assert isinstance(chunk, bytes)
+ assert self.version.startswith("HTTP/1."), \
+ "deprecated interface ony supported in HTTP/1.x"
self.connection.write(chunk, callback=callback)
def finish(self):
@@ -406,15 +421,14 @@ class HTTPServerRequest(object):
def get_ssl_certificate(self, binary_form=False):
"""Returns the client's SSL certificate, if any.
- To use client certificates, the HTTPServer must have been constructed
- with cert_reqs set in ssl_options, e.g.::
+ To use client certificates, the HTTPServer's
+ `ssl.SSLContext.verify_mode` field must be set, e.g.::
- server = HTTPServer(app,
- ssl_options=dict(
- certfile="foo.crt",
- keyfile="foo.key",
- cert_reqs=ssl.CERT_REQUIRED,
- ca_certs="cacert.crt"))
+ ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+ ssl_ctx.load_cert_chain("foo.crt", "foo.key")
+ ssl_ctx.load_verify_locations("cacerts.pem")
+ ssl_ctx.verify_mode = ssl.CERT_REQUIRED
+ server = HTTPServer(app, ssl_options=ssl_ctx)
By default, the return value is a dictionary (or None, if no
client certificate is present). If ``binary_form`` is true, a
@@ -543,6 +557,8 @@ class HTTPConnection(object):
headers.
:arg callback: a callback to be run when the write is complete.
+ The ``version`` field of ``start_line`` is ignored.
+
Returns a `.Future` if no callback is given.
"""
raise NotImplementedError()
@@ -689,14 +705,17 @@ def parse_body_arguments(content_type, body, arguments, files, headers=None):
if values:
arguments.setdefault(name, []).extend(values)
elif content_type.startswith("multipart/form-data"):
- fields = content_type.split(";")
- for field in fields:
- k, sep, v = field.strip().partition("=")
- if k == "boundary" and v:
- parse_multipart_form_data(utf8(v), body, arguments, files)
- break
- else:
- gen_log.warning("Invalid multipart/form-data")
+ try:
+ fields = content_type.split(";")
+ for field in fields:
+ k, sep, v = field.strip().partition("=")
+ if k == "boundary" and v:
+ parse_multipart_form_data(utf8(v), body, arguments, files)
+ break
+ else:
+ raise ValueError("multipart boundary not found")
+ except Exception as e:
+ gen_log.warning("Invalid multipart/form-data: %s", e)
def parse_multipart_form_data(boundary, data, arguments, files):
@@ -782,7 +801,7 @@ def parse_request_start_line(line):
method, path, version = line.split(" ")
except ValueError:
raise HTTPInputError("Malformed HTTP request line")
- if not version.startswith("HTTP/"):
+ if not re.match(r"^HTTP/1\.[0-9]$", version):
raise HTTPInputError(
"Malformed HTTP version in HTTP Request-Line: %r" % version)
return RequestStartLine(method, path, version)
@@ -801,7 +820,7 @@ def parse_response_start_line(line):
ResponseStartLine(version='HTTP/1.1', code=200, reason='OK')
"""
line = native_str(line)
- match = re.match("(HTTP/1.[01]) ([0-9]+) ([^\r]*)", line)
+ match = re.match("(HTTP/1.[0-9]) ([0-9]+) ([^\r]*)", line)
if not match:
raise HTTPInputError("Error parsing response start line")
return ResponseStartLine(match.group(1), int(match.group(2)),
@@ -873,3 +892,20 @@ def _encode_header(key, pdict):
def doctests():
import doctest
return doctest.DocTestSuite()
+
+
+def split_host_and_port(netloc):
+ """Returns ``(host, port)`` tuple from ``netloc``.
+
+ Returned ``port`` will be ``None`` if not present.
+
+ .. versionadded:: 4.1
+ """
+ match = re.match(r'^(.+):(\d+)$', netloc)
+ if match:
+ host = match.group(1)
+ port = int(match.group(2))
+ else:
+ host = netloc
+ port = None
+ return (host, port)
diff --git a/tornado/ioloop.py b/tornado/ioloop.py
index 03193865..67e33b52 100644
--- a/tornado/ioloop.py
+++ b/tornado/ioloop.py
@@ -41,6 +41,7 @@ import sys
import threading
import time
import traceback
+import math
from tornado.concurrent import TracebackFuture, is_future
from tornado.log import app_log, gen_log
@@ -76,35 +77,52 @@ class IOLoop(Configurable):
simultaneous connections, you should use a system that supports
either ``epoll`` or ``kqueue``.
- Example usage for a simple TCP server::
+ Example usage for a simple TCP server:
+
+ .. testcode::
import errno
import functools
- import ioloop
+ import tornado.ioloop
import socket
def connection_ready(sock, fd, events):
while True:
try:
connection, address = sock.accept()
- except socket.error, e:
+ except socket.error as e:
if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN):
raise
return
connection.setblocking(0)
handle_connection(connection, address)
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- sock.setblocking(0)
- sock.bind(("", port))
- sock.listen(128)
+ if __name__ == '__main__':
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.setblocking(0)
+ sock.bind(("", port))
+ sock.listen(128)
- io_loop = ioloop.IOLoop.instance()
- callback = functools.partial(connection_ready, sock)
- io_loop.add_handler(sock.fileno(), callback, io_loop.READ)
- io_loop.start()
+ io_loop = tornado.ioloop.IOLoop.current()
+ callback = functools.partial(connection_ready, sock)
+ io_loop.add_handler(sock.fileno(), callback, io_loop.READ)
+ io_loop.start()
+ .. testoutput::
+ :hide:
+
+ By default, a newly-constructed `IOLoop` becomes the thread's current
+ `IOLoop`, unless there already is a current `IOLoop`. This behavior
+ can be controlled with the ``make_current`` argument to the `IOLoop`
+ constructor: if ``make_current=True``, the new `IOLoop` will always
+ try to become current and it raises an error if there is already a
+ current instance. If ``make_current=False``, the new `IOLoop` will
+ not try to become current.
+
+ .. versionchanged:: 4.2
+ Added the ``make_current`` keyword argument to the `IOLoop`
+ constructor.
"""
# Constants from the epoll module
_EPOLLIN = 0x001
@@ -133,7 +151,8 @@ class IOLoop(Configurable):
Most applications have a single, global `IOLoop` running on the
main thread. Use this method to get this instance from
- another thread. To get the current thread's `IOLoop`, use `current()`.
+ another thread. In most other cases, it is better to use `current()`
+ to get the current thread's `IOLoop`.
"""
if not hasattr(IOLoop, "_instance"):
with IOLoop._instance_lock:
@@ -167,28 +186,26 @@ class IOLoop(Configurable):
del IOLoop._instance
@staticmethod
- def current():
+ def current(instance=True):
"""Returns the current thread's `IOLoop`.
- If an `IOLoop` is currently running or has been marked as current
- by `make_current`, returns that instance. Otherwise returns
- `IOLoop.instance()`, i.e. the main thread's `IOLoop`.
-
- A common pattern for classes that depend on ``IOLoops`` is to use
- a default argument to enable programs with multiple ``IOLoops``
- but not require the argument for simpler applications::
-
- class MyClass(object):
- def __init__(self, io_loop=None):
- self.io_loop = io_loop or IOLoop.current()
+ If an `IOLoop` is currently running or has been marked as
+ current by `make_current`, returns that instance. If there is
+ no current `IOLoop`, returns `IOLoop.instance()` (i.e. the
+ main thread's `IOLoop`, creating one if necessary) if ``instance``
+ is true.
In general you should use `IOLoop.current` as the default when
constructing an asynchronous object, and use `IOLoop.instance`
when you mean to communicate to the main thread from a different
one.
+
+ .. versionchanged:: 4.1
+ Added ``instance`` argument to control the fallback to
+ `IOLoop.instance()`.
"""
current = getattr(IOLoop._current, "instance", None)
- if current is None:
+ if current is None and instance:
return IOLoop.instance()
return current
@@ -200,6 +217,10 @@ class IOLoop(Configurable):
`make_current` explicitly before starting the `IOLoop`,
so that code run at startup time can find the right
instance.
+
+ .. versionchanged:: 4.1
+ An `IOLoop` created while there is no current `IOLoop`
+ will automatically become current.
"""
IOLoop._current.instance = self
@@ -223,8 +244,14 @@ class IOLoop(Configurable):
from tornado.platform.select import SelectIOLoop
return SelectIOLoop
- def initialize(self):
- pass
+ def initialize(self, make_current=None):
+ if make_current is None:
+ if IOLoop.current(instance=False) is None:
+ self.make_current()
+ elif make_current:
+ if IOLoop.current(instance=False) is None:
+ raise RuntimeError("current IOLoop already exists")
+ self.make_current()
def close(self, all_fds=False):
"""Closes the `IOLoop`, freeing any resources used.
@@ -390,7 +417,7 @@ class IOLoop(Configurable):
# do stuff...
if __name__ == '__main__':
- IOLoop.instance().run_sync(main)
+ IOLoop.current().run_sync(main)
"""
future_cell = [None]
@@ -633,8 +660,8 @@ class PollIOLoop(IOLoop):
(Linux), `tornado.platform.kqueue.KQueueIOLoop` (BSD and Mac), or
`tornado.platform.select.SelectIOLoop` (all platforms).
"""
- def initialize(self, impl, time_func=None):
- super(PollIOLoop, self).initialize()
+ def initialize(self, impl, time_func=None, **kwargs):
+ super(PollIOLoop, self).initialize(**kwargs)
self._impl = impl
if hasattr(self._impl, 'fileno'):
set_close_exec(self._impl.fileno())
@@ -739,8 +766,10 @@ class PollIOLoop(IOLoop):
# IOLoop is just started once at the beginning.
signal.set_wakeup_fd(old_wakeup_fd)
old_wakeup_fd = None
- except ValueError: # non-main thread
- pass
+ except ValueError:
+ # Non-main thread, or the previous value of wakeup_fd
+ # is no longer valid.
+ old_wakeup_fd = None
try:
while True:
@@ -944,8 +973,16 @@ class PeriodicCallback(object):
"""Schedules the given callback to be called periodically.
The callback is called every ``callback_time`` milliseconds.
+ Note that the timeout is given in milliseconds, while most other
+ time-related functions in Tornado use seconds.
+
+ If the callback runs for longer than ``callback_time`` milliseconds,
+ subsequent invocations will be skipped to get back on schedule.
`start` must be called after the `PeriodicCallback` is created.
+
+ .. versionchanged:: 4.1
+ The ``io_loop`` argument is deprecated.
"""
def __init__(self, callback, callback_time, io_loop=None):
self.callback = callback
@@ -969,6 +1006,13 @@ class PeriodicCallback(object):
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
+ def is_running(self):
+ """Return True if this `.PeriodicCallback` has been started.
+
+ .. versionadded:: 4.1
+ """
+ return self._running
+
def _run(self):
if not self._running:
return
@@ -982,6 +1026,9 @@ class PeriodicCallback(object):
def _schedule_next(self):
if self._running:
current_time = self.io_loop.time()
- while self._next_timeout <= current_time:
- self._next_timeout += self.callback_time / 1000.0
+
+ if self._next_timeout <= current_time:
+ callback_time_sec = self.callback_time / 1000.0
+ self._next_timeout += (math.floor((current_time - self._next_timeout) / callback_time_sec) + 1) * callback_time_sec
+
self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run)
diff --git a/tornado/iostream.py b/tornado/iostream.py
index 69f43957..3a175a67 100644
--- a/tornado/iostream.py
+++ b/tornado/iostream.py
@@ -37,7 +37,7 @@ import re
from tornado.concurrent import TracebackFuture
from tornado import ioloop
from tornado.log import gen_log, app_log
-from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError
+from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError, _client_ssl_defaults, _server_ssl_defaults
from tornado import stack_context
from tornado.util import errno_from_exception
@@ -68,13 +68,21 @@ _ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE,
if hasattr(errno, "WSAECONNRESET"):
_ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT)
+if sys.platform == 'darwin':
+ # OSX appears to have a race condition that causes send(2) to return
+ # EPROTOTYPE if called while a socket is being torn down:
+ # http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/
+ # Since the socket is being closed anyway, treat this as an ECONNRESET
+ # instead of an unexpected error.
+ _ERRNO_CONNRESET += (errno.EPROTOTYPE,)
+
# More non-portable errnos:
_ERRNO_INPROGRESS = (errno.EINPROGRESS,)
if hasattr(errno, "WSAEINPROGRESS"):
_ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,)
-#######################################################
+
class StreamClosedError(IOError):
"""Exception raised by `IOStream` methods when the stream is closed.
@@ -122,6 +130,7 @@ class BaseIOStream(object):
"""`BaseIOStream` constructor.
:arg io_loop: The `.IOLoop` to use; defaults to `.IOLoop.current`.
+ Deprecated since Tornado 4.1.
:arg max_buffer_size: Maximum amount of incoming data to buffer;
defaults to 100MB.
:arg read_chunk_size: Amount of data to read at one time from the
@@ -160,6 +169,11 @@ class BaseIOStream(object):
self._close_callback = None
self._connect_callback = None
self._connect_future = None
+ # _ssl_connect_future should be defined in SSLIOStream
+ # but it's here so we can clean it up in maybe_run_close_callback.
+ # TODO: refactor that so subclasses can add additional futures
+ # to be cancelled.
+ self._ssl_connect_future = None
self._connecting = False
self._state = None
self._pending_callbacks = 0
@@ -230,6 +244,12 @@ class BaseIOStream(object):
gen_log.info("Unsatisfiable read, closing connection: %s" % e)
self.close(exc_info=True)
return future
+ except:
+ if future is not None:
+ # Ensure that the future doesn't log an error because its
+ # failure was never examined.
+ future.add_done_callback(lambda f: f.exception())
+ raise
return future
def read_until(self, delimiter, callback=None, max_bytes=None):
@@ -257,6 +277,10 @@ class BaseIOStream(object):
gen_log.info("Unsatisfiable read, closing connection: %s" % e)
self.close(exc_info=True)
return future
+ except:
+ if future is not None:
+ future.add_done_callback(lambda f: f.exception())
+ raise
return future
def read_bytes(self, num_bytes, callback=None, streaming_callback=None,
@@ -281,7 +305,12 @@ class BaseIOStream(object):
self._read_bytes = num_bytes
self._read_partial = partial
self._streaming_callback = stack_context.wrap(streaming_callback)
- self._try_inline_read()
+ try:
+ self._try_inline_read()
+ except:
+ if future is not None:
+ future.add_done_callback(lambda f: f.exception())
+ raise
return future
def read_until_close(self, callback=None, streaming_callback=None):
@@ -293,9 +322,16 @@ class BaseIOStream(object):
If a callback is given, it will be run with the data as an argument;
if not, this method returns a `.Future`.
+ Note that if a ``streaming_callback`` is used, data will be
+ read from the socket as quickly as it becomes available; there
+ is no way to apply backpressure or cancel the reads. If flow
+ control or cancellation are desired, use a loop with
+ `read_bytes(partial=True) <.read_bytes>` instead.
+
.. versionchanged:: 4.0
The callback argument is now optional and a `.Future` will
be returned if it is omitted.
+
"""
future = self._set_read_callback(callback)
self._streaming_callback = stack_context.wrap(streaming_callback)
@@ -305,7 +341,11 @@ class BaseIOStream(object):
self._run_read_callback(self._read_buffer_size, False)
return future
self._read_until_close = True
- self._try_inline_read()
+ try:
+ self._try_inline_read()
+ except:
+ future.add_done_callback(lambda f: f.exception())
+ raise
return future
def write(self, data, callback=None):
@@ -331,7 +371,7 @@ class BaseIOStream(object):
if data:
if (self.max_write_buffer_size is not None and
self._write_buffer_size + len(data) > self.max_write_buffer_size):
- raise StreamBufferFullError("Reached maximum read buffer size")
+ raise StreamBufferFullError("Reached maximum write buffer size")
# Break up large contiguous strings before inserting them in the
# write buffer, so we don't have to recopy the entire thing
# as we slice off pieces to send to the socket.
@@ -344,6 +384,7 @@ class BaseIOStream(object):
future = None
else:
future = self._write_future = TracebackFuture()
+ future.add_done_callback(lambda f: f.exception())
if not self._connecting:
self._handle_write()
if self._write_buffer:
@@ -401,9 +442,11 @@ class BaseIOStream(object):
if self._connect_future is not None:
futures.append(self._connect_future)
self._connect_future = None
+ if self._ssl_connect_future is not None:
+ futures.append(self._ssl_connect_future)
+ self._ssl_connect_future = None
for future in futures:
- if (isinstance(self.error, (socket.error, IOError)) and
- errno_from_exception(self.error) in _ERRNO_CONNRESET):
+ if self._is_connreset(self.error):
# Treat connection resets as closed connections so
# clients only have to catch one kind of exception
# to avoid logging.
@@ -601,9 +644,8 @@ class BaseIOStream(object):
pos = self._read_to_buffer_loop()
except UnsatisfiableReadError:
raise
- except Exception as e:
- if 1 != e.errno:
- gen_log.warning("error on read", exc_info=True)
+ except Exception:
+ gen_log.warning("error on read", exc_info=True)
self.close(exc_info=True)
return
if pos is not None:
@@ -627,13 +669,13 @@ class BaseIOStream(object):
else:
callback = self._read_callback
self._read_callback = self._streaming_callback = None
- if self._read_future is not None:
- assert callback is None
- future = self._read_future
- self._read_future = None
- future.set_result(self._consume(size))
+ if self._read_future is not None:
+ assert callback is None
+ future = self._read_future
+ self._read_future = None
+ future.set_result(self._consume(size))
if callback is not None:
- assert self._read_future is None
+ assert (self._read_future is None) or streaming
self._run_callback(callback, self._consume(size))
else:
# If we scheduled a callback, we will add the error listener
@@ -684,7 +726,7 @@ class BaseIOStream(object):
chunk = self.read_from_fd()
except (socket.error, IOError, OSError) as e:
# ssl.SSLError is a subclass of socket.error
- if e.args[0] in _ERRNO_CONNRESET:
+ if self._is_connreset(e):
# Treat ECONNRESET as a connection close rather than
# an error to minimize log spam (the exception will
# be available on self.error for apps that care).
@@ -806,7 +848,7 @@ class BaseIOStream(object):
self._write_buffer_frozen = True
break
else:
- if e.args[0] not in _ERRNO_CONNRESET:
+ if not self._is_connreset(e):
# Broken pipe errors are usually caused by connection
# reset, and its better to not log EPIPE errors to
# minimize log spam
@@ -884,6 +926,14 @@ class BaseIOStream(object):
self._state = self._state | state
self.io_loop.update_handler(self.fileno(), self._state)
+ def _is_connreset(self, exc):
+ """Return true if exc is ECONNRESET or equivalent.
+
+ May be overridden in subclasses.
+ """
+ return (isinstance(exc, (socket.error, IOError)) and
+ errno_from_exception(exc) in _ERRNO_CONNRESET)
+
class IOStream(BaseIOStream):
r"""Socket-based `IOStream` implementation.
@@ -898,7 +948,9 @@ class IOStream(BaseIOStream):
connected before passing it to the `IOStream` or connected with
`IOStream.connect`.
- A very simple (and broken) HTTP client using this class::
+ A very simple (and broken) HTTP client using this class:
+
+ .. testcode::
import tornado.ioloop
import tornado.iostream
@@ -917,14 +969,19 @@ class IOStream(BaseIOStream):
stream.read_bytes(int(headers[b"Content-Length"]), on_body)
def on_body(data):
- print data
+ print(data)
stream.close()
- tornado.ioloop.IOLoop.instance().stop()
+ tornado.ioloop.IOLoop.current().stop()
+
+ if __name__ == '__main__':
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+ stream = tornado.iostream.IOStream(s)
+ stream.connect(("friendfeed.com", 80), send_request)
+ tornado.ioloop.IOLoop.current().start()
+
+ .. testoutput::
+ :hide:
- s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
- stream = tornado.iostream.IOStream(s)
- stream.connect(("friendfeed.com", 80), send_request)
- tornado.ioloop.IOLoop.instance().start()
"""
def __init__(self, socket, *args, **kwargs):
self.socket = socket
@@ -978,10 +1035,10 @@ class IOStream(BaseIOStream):
returns a `.Future` (whose result after a successful
connection will be the stream itself).
- If specified, the ``server_hostname`` parameter will be used
- in SSL connections for certificate validation (if requested in
- the ``ssl_options``) and SNI (if supported; requires
- Python 3.2+).
+ In SSL mode, the ``server_hostname`` parameter will be used
+ for certificate validation (unless disabled in the
+ ``ssl_options``) and SNI (if supported; requires Python
+ 2.7.9+).
Note that it is safe to call `IOStream.write
` while the connection is pending, in
@@ -992,6 +1049,11 @@ class IOStream(BaseIOStream):
.. versionchanged:: 4.0
If no callback is given, returns a `.Future`.
+ .. versionchanged:: 4.2
+ SSL certificates are validated by default; pass
+ ``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a
+ suitably-configured `ssl.SSLContext` to the
+ `SSLIOStream` constructor to disable.
"""
self._connecting = True
if callback is not None:
@@ -1011,8 +1073,9 @@ class IOStream(BaseIOStream):
# reported later in _handle_connect.
if (errno_from_exception(e) not in _ERRNO_INPROGRESS and
errno_from_exception(e) not in _ERRNO_WOULDBLOCK):
- gen_log.warning("Connect error on fd %s: %s",
- self.socket.fileno(), e)
+ if future is None:
+ gen_log.warning("Connect error on fd %s: %s",
+ self.socket.fileno(), e)
self.close(exc_info=True)
return future
self._add_io_state(self.io_loop.WRITE)
@@ -1033,10 +1096,11 @@ class IOStream(BaseIOStream):
data. It can also be used immediately after connecting,
before any reads or writes.
- The ``ssl_options`` argument may be either a dictionary
- of options or an `ssl.SSLContext`. If a ``server_hostname``
- is given, it will be used for certificate verification
- (as configured in the ``ssl_options``).
+ The ``ssl_options`` argument may be either an `ssl.SSLContext`
+ object or a dictionary of keyword arguments for the
+ `ssl.wrap_socket` function. The ``server_hostname`` argument
+ will be used for certificate validation unless disabled
+ in the ``ssl_options``.
This method returns a `.Future` whose result is the new
`SSLIOStream`. After this method has been called,
@@ -1046,6 +1110,11 @@ class IOStream(BaseIOStream):
transferred to the new stream.
.. versionadded:: 4.0
+
+ .. versionchanged:: 4.2
+ SSL certificates are validated by default; pass
+ ``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a
+ suitably-configured `ssl.SSLContext` to disable.
"""
if (self._read_callback or self._read_future or
self._write_callback or self._write_future or
@@ -1054,12 +1123,17 @@ class IOStream(BaseIOStream):
self._read_buffer or self._write_buffer):
raise ValueError("IOStream is not idle; cannot convert to SSL")
if ssl_options is None:
- ssl_options = {}
+ if server_side:
+ ssl_options = _server_ssl_defaults
+ else:
+ ssl_options = _client_ssl_defaults
socket = self.socket
self.io_loop.remove_handler(socket)
self.socket = None
- socket = ssl_wrap_socket(socket, ssl_options, server_side=server_side,
+ socket = ssl_wrap_socket(socket, ssl_options,
+ server_hostname=server_hostname,
+ server_side=server_side,
do_handshake_on_connect=False)
orig_close_callback = self._close_callback
self._close_callback = None
@@ -1071,6 +1145,7 @@ class IOStream(BaseIOStream):
# If we had an "unwrap" counterpart to this method we would need
# to restore the original callback after our Future resolves
# so that repeated wrap/unwrap calls don't build up layers.
+
def close_callback():
if not future.done():
future.set_exception(ssl_stream.error or StreamClosedError())
@@ -1115,7 +1190,7 @@ class IOStream(BaseIOStream):
# Sometimes setsockopt will fail if the socket is closed
# at the wrong time. This can happen with HTTPServer
# resetting the value to false between requests.
- if e.errno not in (errno.EINVAL, errno.ECONNRESET):
+ if e.errno != errno.EINVAL and not self._is_connreset(e):
raise
@@ -1131,11 +1206,11 @@ class SSLIOStream(IOStream):
wrapped when `IOStream.connect` is finished.
"""
def __init__(self, *args, **kwargs):
- """The ``ssl_options`` keyword argument may either be a dictionary
- of keywords arguments for `ssl.wrap_socket`, or an `ssl.SSLContext`
- object.
+ """The ``ssl_options`` keyword argument may either be an
+ `ssl.SSLContext` object or a dictionary of keywords arguments
+ for `ssl.wrap_socket`
"""
- self._ssl_options = kwargs.pop('ssl_options', {})
+ self._ssl_options = kwargs.pop('ssl_options', _client_ssl_defaults)
super(SSLIOStream, self).__init__(*args, **kwargs)
self._ssl_accepting = True
self._handshake_reading = False
@@ -1190,8 +1265,7 @@ class SSLIOStream(IOStream):
# to cause do_handshake to raise EBADF, so make that error
# quiet as well.
# https://groups.google.com/forum/?fromgroups#!topic/python-tornado/ApucKJat1_0
- if (err.args[0] in _ERRNO_CONNRESET or
- err.args[0] == errno.EBADF):
+ if self._is_connreset(err) or err.args[0] == errno.EBADF:
return self.close(exc_info=True)
raise
except AttributeError:
@@ -1204,10 +1278,17 @@ class SSLIOStream(IOStream):
if not self._verify_cert(self.socket.getpeercert()):
self.close()
return
- if self._ssl_connect_callback is not None:
- callback = self._ssl_connect_callback
- self._ssl_connect_callback = None
- self._run_callback(callback)
+ self._run_ssl_connect_callback()
+
+ def _run_ssl_connect_callback(self):
+ if self._ssl_connect_callback is not None:
+ callback = self._ssl_connect_callback
+ self._ssl_connect_callback = None
+ self._run_callback(callback)
+ if self._ssl_connect_future is not None:
+ future = self._ssl_connect_future
+ self._ssl_connect_future = None
+ future.set_result(self)
def _verify_cert(self, peercert):
"""Returns True if peercert is valid according to the configured
@@ -1249,14 +1330,11 @@ class SSLIOStream(IOStream):
super(SSLIOStream, self)._handle_write()
def connect(self, address, callback=None, server_hostname=None):
- # Save the user's callback and run it after the ssl handshake
- # has completed.
- self._ssl_connect_callback = stack_context.wrap(callback)
self._server_hostname = server_hostname
- # Note: Since we don't pass our callback argument along to
- # super.connect(), this will always return a Future.
- # This is harmless, but a bit less efficient than it could be.
- return super(SSLIOStream, self).connect(address, callback=None)
+ # Pass a dummy callback to super.connect(), which is slightly
+ # more efficient than letting it return a Future we ignore.
+ super(SSLIOStream, self).connect(address, callback=lambda: None)
+ return self.wait_for_handshake(callback)
def _handle_connect(self):
# Call the superclass method to check for errors.
@@ -1281,6 +1359,51 @@ class SSLIOStream(IOStream):
do_handshake_on_connect=False)
self._add_io_state(old_state)
+ def wait_for_handshake(self, callback=None):
+ """Wait for the initial SSL handshake to complete.
+
+ If a ``callback`` is given, it will be called with no
+ arguments once the handshake is complete; otherwise this
+ method returns a `.Future` which will resolve to the
+ stream itself after the handshake is complete.
+
+ Once the handshake is complete, information such as
+ the peer's certificate and NPN/ALPN selections may be
+ accessed on ``self.socket``.
+
+ This method is intended for use on server-side streams
+ or after using `IOStream.start_tls`; it should not be used
+ with `IOStream.connect` (which already waits for the
+ handshake to complete). It may only be called once per stream.
+
+ .. versionadded:: 4.2
+ """
+ if (self._ssl_connect_callback is not None or
+ self._ssl_connect_future is not None):
+ raise RuntimeError("Already waiting")
+ if callback is not None:
+ self._ssl_connect_callback = stack_context.wrap(callback)
+ future = None
+ else:
+ future = self._ssl_connect_future = TracebackFuture()
+ if not self._ssl_accepting:
+ self._run_ssl_connect_callback()
+ return future
+
+ def write_to_fd(self, data):
+ try:
+ return self.socket.send(data)
+ except ssl.SSLError as e:
+ if e.args[0] == ssl.SSL_ERROR_WANT_WRITE:
+ # In Python 3.5+, SSLSocket.send raises a WANT_WRITE error if
+ # the socket is not writeable; we need to transform this into
+ # an EWOULDBLOCK socket.error or a zero return value,
+ # either of which will be recognized by the caller of this
+ # method. Prior to Python 3.5, an unwriteable socket would
+ # simply return 0 bytes written.
+ return 0
+ raise
+
def read_from_fd(self):
if self._ssl_accepting:
# If the handshake hasn't finished yet, there can't be anything
@@ -1311,6 +1434,11 @@ class SSLIOStream(IOStream):
return None
return chunk
+ def _is_connreset(self, e):
+ if isinstance(e, ssl.SSLError) and e.args[0] == ssl.SSL_ERROR_EOF:
+ return True
+ return super(SSLIOStream, self)._is_connreset(e)
+
class PipeIOStream(BaseIOStream):
"""Pipe-based `IOStream` implementation.
diff --git a/tornado/locale.py b/tornado/locale.py
index 07c6d582..a668765b 100644
--- a/tornado/locale.py
+++ b/tornado/locale.py
@@ -55,6 +55,7 @@ _default_locale = "en_US"
_translations = {}
_supported_locales = frozenset([_default_locale])
_use_gettext = False
+CONTEXT_SEPARATOR = "\x04"
def get(*locale_codes):
@@ -273,6 +274,9 @@ class Locale(object):
"""
raise NotImplementedError()
+ def pgettext(self, context, message, plural_message=None, count=None):
+ raise NotImplementedError()
+
def format_date(self, date, gmt_offset=0, relative=True, shorter=False,
full_format=False):
"""Formats the given date (which should be GMT).
@@ -422,6 +426,11 @@ class CSVLocale(Locale):
message_dict = self.translations.get("unknown", {})
return message_dict.get(message, message)
+ def pgettext(self, context, message, plural_message=None, count=None):
+ if self.translations:
+ gen_log.warning('pgettext is not supported by CSVLocale')
+ return self.translate(message, plural_message, count)
+
class GettextLocale(Locale):
"""Locale implementation using the `gettext` module."""
@@ -445,6 +454,44 @@ class GettextLocale(Locale):
else:
return self.gettext(message)
+ def pgettext(self, context, message, plural_message=None, count=None):
+ """Allows to set context for translation, accepts plural forms.
+
+ Usage example::
+
+ pgettext("law", "right")
+ pgettext("good", "right")
+
+ Plural message example::
+
+ pgettext("organization", "club", "clubs", len(clubs))
+ pgettext("stick", "club", "clubs", len(clubs))
+
+ To generate POT file with context, add following options to step 1
+ of `load_gettext_translations` sequence::
+
+ xgettext [basic options] --keyword=pgettext:1c,2 --keyword=pgettext:1c,2,3
+
+ .. versionadded:: 4.2
+ """
+ if plural_message is not None:
+ assert count is not None
+ msgs_with_ctxt = ("%s%s%s" % (context, CONTEXT_SEPARATOR, message),
+ "%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message),
+ count)
+ result = self.ngettext(*msgs_with_ctxt)
+ if CONTEXT_SEPARATOR in result:
+ # Translation not found
+ result = self.ngettext(message, plural_message, count)
+ return result
+ else:
+ msg_with_ctxt = "%s%s%s" % (context, CONTEXT_SEPARATOR, message)
+ result = self.gettext(msg_with_ctxt)
+ if CONTEXT_SEPARATOR in result:
+ # Translation not found
+ result = message
+ return result
+
LOCALE_NAMES = {
"af_ZA": {"name_en": u("Afrikaans"), "name": u("Afrikaans")},
"am_ET": {"name_en": u("Amharic"), "name": u('\u12a0\u121b\u122d\u129b')},
diff --git a/tornado/locks.py b/tornado/locks.py
new file mode 100644
index 00000000..4b0bdb38
--- /dev/null
+++ b/tornado/locks.py
@@ -0,0 +1,460 @@
+# Copyright 2015 The Tornado Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+.. testsetup:: *
+
+ from tornado import ioloop, gen, locks
+ io_loop = ioloop.IOLoop.current()
+"""
+
+from __future__ import absolute_import, division, print_function, with_statement
+
+__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock']
+
+import collections
+
+from tornado import gen, ioloop
+from tornado.concurrent import Future
+
+
+class _TimeoutGarbageCollector(object):
+ """Base class for objects that periodically clean up timed-out waiters.
+
+ Avoids memory leak in a common pattern like:
+
+ while True:
+ yield condition.wait(short_timeout)
+ print('looping....')
+ """
+ def __init__(self):
+ self._waiters = collections.deque() # Futures.
+ self._timeouts = 0
+
+ def _garbage_collect(self):
+ # Occasionally clear timed-out waiters.
+ self._timeouts += 1
+ if self._timeouts > 100:
+ self._timeouts = 0
+ self._waiters = collections.deque(
+ w for w in self._waiters if not w.done())
+
+
+class Condition(_TimeoutGarbageCollector):
+ """A condition allows one or more coroutines to wait until notified.
+
+ Like a standard `threading.Condition`, but does not need an underlying lock
+ that is acquired and released.
+
+ With a `Condition`, coroutines can wait to be notified by other coroutines:
+
+ .. testcode::
+
+ condition = locks.Condition()
+
+ @gen.coroutine
+ def waiter():
+ print("I'll wait right here")
+ yield condition.wait() # Yield a Future.
+ print("I'm done waiting")
+
+ @gen.coroutine
+ def notifier():
+ print("About to notify")
+ condition.notify()
+ print("Done notifying")
+
+ @gen.coroutine
+ def runner():
+ # Yield two Futures; wait for waiter() and notifier() to finish.
+ yield [waiter(), notifier()]
+
+ io_loop.run_sync(runner)
+
+ .. testoutput::
+
+ I'll wait right here
+ About to notify
+ Done notifying
+ I'm done waiting
+
+ `wait` takes an optional ``timeout`` argument, which is either an absolute
+ timestamp::
+
+ io_loop = ioloop.IOLoop.current()
+
+ # Wait up to 1 second for a notification.
+ yield condition.wait(timeout=io_loop.time() + 1)
+
+ ...or a `datetime.timedelta` for a timeout relative to the current time::
+
+ # Wait up to 1 second.
+ yield condition.wait(timeout=datetime.timedelta(seconds=1))
+
+ The method raises `tornado.gen.TimeoutError` if there's no notification
+ before the deadline.
+ """
+
+ def __init__(self):
+ super(Condition, self).__init__()
+ self.io_loop = ioloop.IOLoop.current()
+
+ def __repr__(self):
+ result = '<%s' % (self.__class__.__name__, )
+ if self._waiters:
+ result += ' waiters[%s]' % len(self._waiters)
+ return result + '>'
+
+ def wait(self, timeout=None):
+ """Wait for `.notify`.
+
+ Returns a `.Future` that resolves ``True`` if the condition is notified,
+ or ``False`` after a timeout.
+ """
+ waiter = Future()
+ self._waiters.append(waiter)
+ if timeout:
+ def on_timeout():
+ waiter.set_result(False)
+ self._garbage_collect()
+ io_loop = ioloop.IOLoop.current()
+ timeout_handle = io_loop.add_timeout(timeout, on_timeout)
+ waiter.add_done_callback(
+ lambda _: io_loop.remove_timeout(timeout_handle))
+ return waiter
+
+ def notify(self, n=1):
+ """Wake ``n`` waiters."""
+ waiters = [] # Waiters we plan to run right now.
+ while n and self._waiters:
+ waiter = self._waiters.popleft()
+ if not waiter.done(): # Might have timed out.
+ n -= 1
+ waiters.append(waiter)
+
+ for waiter in waiters:
+ waiter.set_result(True)
+
+ def notify_all(self):
+ """Wake all waiters."""
+ self.notify(len(self._waiters))
+
+
+class Event(object):
+ """An event blocks coroutines until its internal flag is set to True.
+
+ Similar to `threading.Event`.
+
+ A coroutine can wait for an event to be set. Once it is set, calls to
+ ``yield event.wait()`` will not block unless the event has been cleared:
+
+ .. testcode::
+
+ event = locks.Event()
+
+ @gen.coroutine
+ def waiter():
+ print("Waiting for event")
+ yield event.wait()
+ print("Not waiting this time")
+ yield event.wait()
+ print("Done")
+
+ @gen.coroutine
+ def setter():
+ print("About to set the event")
+ event.set()
+
+ @gen.coroutine
+ def runner():
+ yield [waiter(), setter()]
+
+ io_loop.run_sync(runner)
+
+ .. testoutput::
+
+ Waiting for event
+ About to set the event
+ Not waiting this time
+ Done
+ """
+ def __init__(self):
+ self._future = Future()
+
+ def __repr__(self):
+ return '<%s %s>' % (
+ self.__class__.__name__, 'set' if self.is_set() else 'clear')
+
+ def is_set(self):
+ """Return ``True`` if the internal flag is true."""
+ return self._future.done()
+
+ def set(self):
+ """Set the internal flag to ``True``. All waiters are awakened.
+
+ Calling `.wait` once the flag is set will not block.
+ """
+ if not self._future.done():
+ self._future.set_result(None)
+
+ def clear(self):
+ """Reset the internal flag to ``False``.
+
+ Calls to `.wait` will block until `.set` is called.
+ """
+ if self._future.done():
+ self._future = Future()
+
+ def wait(self, timeout=None):
+ """Block until the internal flag is true.
+
+ Returns a Future, which raises `tornado.gen.TimeoutError` after a
+ timeout.
+ """
+ if timeout is None:
+ return self._future
+ else:
+ return gen.with_timeout(timeout, self._future)
+
+
+class _ReleasingContextManager(object):
+ """Releases a Lock or Semaphore at the end of a "with" statement.
+
+ with (yield semaphore.acquire()):
+ pass
+
+ # Now semaphore.release() has been called.
+ """
+ def __init__(self, obj):
+ self._obj = obj
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._obj.release()
+
+
+class Semaphore(_TimeoutGarbageCollector):
+ """A lock that can be acquired a fixed number of times before blocking.
+
+ A Semaphore manages a counter representing the number of `.release` calls
+ minus the number of `.acquire` calls, plus an initial value. The `.acquire`
+ method blocks if necessary until it can return without making the counter
+ negative.
+
+ Semaphores limit access to a shared resource. To allow access for two
+ workers at a time:
+
+ .. testsetup:: semaphore
+
+ from collections import deque
+
+ from tornado import gen, ioloop
+ from tornado.concurrent import Future
+
+ # Ensure reliable doctest output: resolve Futures one at a time.
+ futures_q = deque([Future() for _ in range(3)])
+
+ @gen.coroutine
+ def simulator(futures):
+ for f in futures:
+ yield gen.moment
+ f.set_result(None)
+
+ ioloop.IOLoop.current().add_callback(simulator, list(futures_q))
+
+ def use_some_resource():
+ return futures_q.popleft()
+
+ .. testcode:: semaphore
+
+ sem = locks.Semaphore(2)
+
+ @gen.coroutine
+ def worker(worker_id):
+ yield sem.acquire()
+ try:
+ print("Worker %d is working" % worker_id)
+ yield use_some_resource()
+ finally:
+ print("Worker %d is done" % worker_id)
+ sem.release()
+
+ @gen.coroutine
+ def runner():
+ # Join all workers.
+ yield [worker(i) for i in range(3)]
+
+ io_loop.run_sync(runner)
+
+ .. testoutput:: semaphore
+
+ Worker 0 is working
+ Worker 1 is working
+ Worker 0 is done
+ Worker 2 is working
+ Worker 1 is done
+ Worker 2 is done
+
+ Workers 0 and 1 are allowed to run concurrently, but worker 2 waits until
+ the semaphore has been released once, by worker 0.
+
+ `.acquire` is a context manager, so ``worker`` could be written as::
+
+ @gen.coroutine
+ def worker(worker_id):
+ with (yield sem.acquire()):
+ print("Worker %d is working" % worker_id)
+ yield use_some_resource()
+
+ # Now the semaphore has been released.
+ print("Worker %d is done" % worker_id)
+ """
+ def __init__(self, value=1):
+ super(Semaphore, self).__init__()
+ if value < 0:
+ raise ValueError('semaphore initial value must be >= 0')
+
+ self._value = value
+
+ def __repr__(self):
+ res = super(Semaphore, self).__repr__()
+ extra = 'locked' if self._value == 0 else 'unlocked,value:{0}'.format(
+ self._value)
+ if self._waiters:
+ extra = '{0},waiters:{1}'.format(extra, len(self._waiters))
+ return '<{0} [{1}]>'.format(res[1:-1], extra)
+
+ def release(self):
+ """Increment the counter and wake one waiter."""
+ self._value += 1
+ while self._waiters:
+ waiter = self._waiters.popleft()
+ if not waiter.done():
+ self._value -= 1
+
+ # If the waiter is a coroutine paused at
+ #
+ # with (yield semaphore.acquire()):
+ #
+ # then the context manager's __exit__ calls release() at the end
+ # of the "with" block.
+ waiter.set_result(_ReleasingContextManager(self))
+ break
+
+ def acquire(self, timeout=None):
+ """Decrement the counter. Returns a Future.
+
+ Block if the counter is zero and wait for a `.release`. The Future
+ raises `.TimeoutError` after the deadline.
+ """
+ waiter = Future()
+ if self._value > 0:
+ self._value -= 1
+ waiter.set_result(_ReleasingContextManager(self))
+ else:
+ self._waiters.append(waiter)
+ if timeout:
+ def on_timeout():
+ waiter.set_exception(gen.TimeoutError())
+ self._garbage_collect()
+ io_loop = ioloop.IOLoop.current()
+ timeout_handle = io_loop.add_timeout(timeout, on_timeout)
+ waiter.add_done_callback(
+ lambda _: io_loop.remove_timeout(timeout_handle))
+ return waiter
+
+ def __enter__(self):
+ raise RuntimeError(
+ "Use Semaphore like 'with (yield semaphore.acquire())', not like"
+ " 'with semaphore'")
+
+ __exit__ = __enter__
+
+
+class BoundedSemaphore(Semaphore):
+ """A semaphore that prevents release() being called too many times.
+
+ If `.release` would increment the semaphore's value past the initial
+ value, it raises `ValueError`. Semaphores are mostly used to guard
+ resources with limited capacity, so a semaphore released too many times
+ is a sign of a bug.
+ """
+ def __init__(self, value=1):
+ super(BoundedSemaphore, self).__init__(value=value)
+ self._initial_value = value
+
+ def release(self):
+ """Increment the counter and wake one waiter."""
+ if self._value >= self._initial_value:
+ raise ValueError("Semaphore released too many times")
+ super(BoundedSemaphore, self).release()
+
+
+class Lock(object):
+ """A lock for coroutines.
+
+ A Lock begins unlocked, and `acquire` locks it immediately. While it is
+ locked, a coroutine that yields `acquire` waits until another coroutine
+ calls `release`.
+
+ Releasing an unlocked lock raises `RuntimeError`.
+
+ `acquire` supports the context manager protocol:
+
+ >>> from tornado import gen, locks
+ >>> lock = locks.Lock()
+ >>>
+ >>> @gen.coroutine
+ ... def f():
+ ... with (yield lock.acquire()):
+ ... # Do something holding the lock.
+ ... pass
+ ...
+ ... # Now the lock is released.
+ """
+ def __init__(self):
+ self._block = BoundedSemaphore(value=1)
+
+ def __repr__(self):
+ return "<%s _block=%s>" % (
+ self.__class__.__name__,
+ self._block)
+
+ def acquire(self, timeout=None):
+ """Attempt to lock. Returns a Future.
+
+ Returns a Future, which raises `tornado.gen.TimeoutError` after a
+ timeout.
+ """
+ return self._block.acquire(timeout)
+
+ def release(self):
+ """Unlock.
+
+ The first coroutine in line waiting for `acquire` gets the lock.
+
+ If not locked, raise a `RuntimeError`.
+ """
+ try:
+ self._block.release()
+ except ValueError:
+ raise RuntimeError('release unlocked lock')
+
+ def __enter__(self):
+ raise RuntimeError(
+ "Use Lock like 'with (yield lock)', not like 'with lock'")
+
+ __exit__ = __enter__
diff --git a/tornado/log.py b/tornado/log.py
index 374071d4..c68dec46 100644
--- a/tornado/log.py
+++ b/tornado/log.py
@@ -206,6 +206,14 @@ def enable_pretty_logging(options=None, logger=None):
def define_logging_options(options=None):
+ """Add logging-related flags to ``options``.
+
+ These options are present automatically on the default options instance;
+ this method is only necessary if you have created your own `.OptionParser`.
+
+ .. versionadded:: 4.2
+ This function existed in prior versions but was broken and undocumented until 4.2.
+ """
if options is None:
# late import to prevent cycle
from tornado.options import options
@@ -227,4 +235,4 @@ def define_logging_options(options=None):
options.define("log_file_num_backups", type=int, default=10,
help="number of log files to keep")
- options.add_parse_callback(enable_pretty_logging)
+ options.add_parse_callback(lambda: enable_pretty_logging(options))
diff --git a/tornado/netutil.py b/tornado/netutil.py
index f147c974..9aa292c4 100644
--- a/tornado/netutil.py
+++ b/tornado/netutil.py
@@ -20,7 +20,7 @@ from __future__ import absolute_import, division, print_function, with_statement
import errno
import os
-import platform
+import sys
import socket
import stat
@@ -35,6 +35,15 @@ except ImportError:
# ssl is not available on Google App Engine
ssl = None
+try:
+ import certifi
+except ImportError:
+ # certifi is optional as long as we have ssl.create_default_context.
+ if ssl is None or hasattr(ssl, 'create_default_context'):
+ certifi = None
+ else:
+ raise
+
try:
xrange # py2
except NameError:
@@ -50,6 +59,38 @@ else:
ssl_match_hostname = backports.ssl_match_hostname.match_hostname
SSLCertificateError = backports.ssl_match_hostname.CertificateError
+if hasattr(ssl, 'SSLContext'):
+ if hasattr(ssl, 'create_default_context'):
+ # Python 2.7.9+, 3.4+
+ # Note that the naming of ssl.Purpose is confusing; the purpose
+ # of a context is to authentiate the opposite side of the connection.
+ _client_ssl_defaults = ssl.create_default_context(
+ ssl.Purpose.SERVER_AUTH)
+ _server_ssl_defaults = ssl.create_default_context(
+ ssl.Purpose.CLIENT_AUTH)
+ else:
+ # Python 3.2-3.3
+ _client_ssl_defaults = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ _client_ssl_defaults.verify_mode = ssl.CERT_REQUIRED
+ _client_ssl_defaults.load_verify_locations(certifi.where())
+ _server_ssl_defaults = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ if hasattr(ssl, 'OP_NO_COMPRESSION'):
+ # Disable TLS compression to avoid CRIME and related attacks.
+ # This constant wasn't added until python 3.3.
+ _client_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
+ _server_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
+
+elif ssl:
+ # Python 2.6-2.7.8
+ _client_ssl_defaults = dict(cert_reqs=ssl.CERT_REQUIRED,
+ ca_certs=certifi.where())
+ _server_ssl_defaults = {}
+else:
+ # Google App Engine
+ _client_ssl_defaults = dict(cert_reqs=None,
+ ca_certs=None)
+ _server_ssl_defaults = {}
+
# ThreadedResolver runs getaddrinfo on a thread. If the hostname is unicode,
# getaddrinfo attempts to import encodings.idna. If this is done at
# module-import time, the import lock is already held by the main thread,
@@ -68,6 +109,7 @@ if hasattr(errno, "WSAEWOULDBLOCK"):
# Default backlog used when calling sock.listen()
_DEFAULT_BACKLOG = 128
+
def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
backlog=_DEFAULT_BACKLOG, flags=None):
"""Creates listening sockets bound to the given port and address.
@@ -105,7 +147,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM,
0, flags)):
af, socktype, proto, canonname, sockaddr = res
- if (platform.system() == 'Darwin' and address == 'localhost' and
+ if (sys.platform == 'darwin' and address == 'localhost' and
af == socket.AF_INET6 and sockaddr[3] != 0):
# Mac OS X includes a link-local address fe80::1%lo0 in the
# getaddrinfo results for 'localhost'. However, the firewall
@@ -187,6 +229,9 @@ def add_accept_handler(sock, callback, io_loop=None):
address of the other end of the connection). Note that this signature
is different from the ``callback(fd, events)`` signature used for
`.IOLoop` handlers.
+
+ .. versionchanged:: 4.1
+ The ``io_loop`` argument is deprecated.
"""
if io_loop is None:
io_loop = IOLoop.current()
@@ -301,6 +346,9 @@ class ExecutorResolver(Resolver):
The executor will be shut down when the resolver is closed unless
``close_resolver=False``; use this if you want to reuse the same
executor elsewhere.
+
+ .. versionchanged:: 4.1
+ The ``io_loop`` argument is deprecated.
"""
def initialize(self, io_loop=None, executor=None, close_executor=True):
self.io_loop = io_loop or IOLoop.current()
@@ -413,7 +461,7 @@ def ssl_options_to_context(ssl_options):
`~ssl.SSLContext` object.
The ``ssl_options`` dictionary contains keywords to be passed to
- `ssl.wrap_socket`. In Python 3.2+, `ssl.SSLContext` objects can
+ `ssl.wrap_socket`. In Python 2.7.9+, `ssl.SSLContext` objects can
be used instead. This function converts the dict form to its
`~ssl.SSLContext` equivalent, and may be used when a component which
accepts both forms needs to upgrade to the `~ssl.SSLContext` version
@@ -444,11 +492,11 @@ def ssl_options_to_context(ssl_options):
def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs):
"""Returns an ``ssl.SSLSocket`` wrapping the given socket.
- ``ssl_options`` may be either a dictionary (as accepted by
- `ssl_options_to_context`) or an `ssl.SSLContext` object.
- Additional keyword arguments are passed to ``wrap_socket``
- (either the `~ssl.SSLContext` method or the `ssl` module function
- as appropriate).
+ ``ssl_options`` may be either an `ssl.SSLContext` object or a
+ dictionary (as accepted by `ssl_options_to_context`). Additional
+ keyword arguments are passed to ``wrap_socket`` (either the
+ `~ssl.SSLContext` method or the `ssl` module function as
+ appropriate).
"""
context = ssl_options_to_context(ssl_options)
if hasattr(ssl, 'SSLContext') and isinstance(context, ssl.SSLContext):
diff --git a/tornado/options.py b/tornado/options.py
index 5e23e291..89a9e432 100644
--- a/tornado/options.py
+++ b/tornado/options.py
@@ -204,6 +204,13 @@ class OptionParser(object):
(name, self._options[name].file_name))
frame = sys._getframe(0)
options_file = frame.f_code.co_filename
+
+ # Can be called directly, or through top level define() fn, in which
+ # case, step up above that frame to look for real caller.
+ if (frame.f_back.f_code.co_filename == options_file and
+ frame.f_back.f_code.co_name == 'define'):
+ frame = frame.f_back
+
file_name = frame.f_back.f_code.co_filename
if file_name == options_file:
file_name = ""
@@ -249,7 +256,7 @@ class OptionParser(object):
arg = args[i].lstrip("-")
name, equals, value = arg.partition("=")
name = name.replace('-', '_')
- if not name in self._options:
+ if name not in self._options:
self.print_help()
raise Error('Unrecognized command line option: %r' % name)
option = self._options[name]
diff --git a/tornado/platform/asyncio.py b/tornado/platform/asyncio.py
index dd6722a4..8f3dbff6 100644
--- a/tornado/platform/asyncio.py
+++ b/tornado/platform/asyncio.py
@@ -12,6 +12,8 @@ unfinished callbacks on the event loop that fail when it resumes)
from __future__ import absolute_import, division, print_function, with_statement
import functools
+import tornado.concurrent
+from tornado.gen import convert_yielded
from tornado.ioloop import IOLoop
from tornado import stack_context
@@ -27,8 +29,10 @@ except ImportError as e:
# Re-raise the original asyncio error, not the trollius one.
raise e
+
class BaseAsyncIOLoop(IOLoop):
- def initialize(self, asyncio_loop, close_loop=False):
+ def initialize(self, asyncio_loop, close_loop=False, **kwargs):
+ super(BaseAsyncIOLoop, self).initialize(**kwargs)
self.asyncio_loop = asyncio_loop
self.close_loop = close_loop
self.asyncio_loop.call_soon(self.make_current)
@@ -129,12 +133,29 @@ class BaseAsyncIOLoop(IOLoop):
class AsyncIOMainLoop(BaseAsyncIOLoop):
- def initialize(self):
+ def initialize(self, **kwargs):
super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(),
- close_loop=False)
+ close_loop=False, **kwargs)
class AsyncIOLoop(BaseAsyncIOLoop):
- def initialize(self):
+ def initialize(self, **kwargs):
super(AsyncIOLoop, self).initialize(asyncio.new_event_loop(),
- close_loop=True)
+ close_loop=True, **kwargs)
+
+
+def to_tornado_future(asyncio_future):
+ """Convert an ``asyncio.Future`` to a `tornado.concurrent.Future`."""
+ tf = tornado.concurrent.Future()
+ tornado.concurrent.chain_future(asyncio_future, tf)
+ return tf
+
+
+def to_asyncio_future(tornado_future):
+ """Convert a `tornado.concurrent.Future` to an ``asyncio.Future``."""
+ af = asyncio.Future()
+ tornado.concurrent.chain_future(tornado_future, af)
+ return af
+
+if hasattr(convert_yielded, 'register'):
+ convert_yielded.register(asyncio.Future, to_tornado_future)
diff --git a/tornado/platform/auto.py b/tornado/platform/auto.py
index ddfe06b4..fc40c9d9 100644
--- a/tornado/platform/auto.py
+++ b/tornado/platform/auto.py
@@ -27,13 +27,14 @@ from __future__ import absolute_import, division, print_function, with_statement
import os
-if os.name == 'nt':
- from tornado.platform.common import Waker
- from tornado.platform.windows import set_close_exec
-elif 'APPENGINE_RUNTIME' in os.environ:
+if 'APPENGINE_RUNTIME' in os.environ:
from tornado.platform.common import Waker
+
def set_close_exec(fd):
pass
+elif os.name == 'nt':
+ from tornado.platform.common import Waker
+ from tornado.platform.windows import set_close_exec
else:
from tornado.platform.posix import set_close_exec, Waker
@@ -41,9 +42,13 @@ try:
# monotime monkey-patches the time module to have a monotonic function
# in versions of python before 3.3.
import monotime
+ # Silence pyflakes warning about this unused import
+ monotime
except ImportError:
pass
try:
from time import monotonic as monotonic_time
except ImportError:
monotonic_time = None
+
+__all__ = ['Waker', 'set_close_exec', 'monotonic_time']
diff --git a/tornado/platform/caresresolver.py b/tornado/platform/caresresolver.py
index c4648c22..5559614f 100644
--- a/tornado/platform/caresresolver.py
+++ b/tornado/platform/caresresolver.py
@@ -18,6 +18,9 @@ class CaresResolver(Resolver):
so it is only recommended for use in ``AF_INET`` (i.e. IPv4). This is
the default for ``tornado.simple_httpclient``, but other libraries
may default to ``AF_UNSPEC``.
+
+ .. versionchanged:: 4.1
+ The ``io_loop`` argument is deprecated.
"""
def initialize(self, io_loop=None):
self.io_loop = io_loop or IOLoop.current()
diff --git a/tornado/platform/kqueue.py b/tornado/platform/kqueue.py
index de8c046d..f8f3e4a6 100644
--- a/tornado/platform/kqueue.py
+++ b/tornado/platform/kqueue.py
@@ -54,8 +54,7 @@ class _KQueue(object):
if events & IOLoop.WRITE:
kevents.append(select.kevent(
fd, filter=select.KQ_FILTER_WRITE, flags=flags))
- if events & IOLoop.READ or not kevents:
- # Always read when there is not a write
+ if events & IOLoop.READ:
kevents.append(select.kevent(
fd, filter=select.KQ_FILTER_READ, flags=flags))
# Even though control() takes a list, it seems to return EINVAL
diff --git a/tornado/platform/select.py b/tornado/platform/select.py
index 9a879562..db52ef91 100644
--- a/tornado/platform/select.py
+++ b/tornado/platform/select.py
@@ -47,7 +47,7 @@ class _Select(object):
# Closed connections are reported as errors by epoll and kqueue,
# but as zero-byte reads by select, so when errors are requested
# we need to listen for both read and error.
- self.read_fds.add(fd)
+ # self.read_fds.add(fd)
def modify(self, fd, events):
self.unregister(fd)
diff --git a/tornado/platform/twisted.py b/tornado/platform/twisted.py
index 27d991cd..7b3c8ca5 100644
--- a/tornado/platform/twisted.py
+++ b/tornado/platform/twisted.py
@@ -35,7 +35,7 @@ of the application::
tornado.platform.twisted.install()
from twisted.internet import reactor
-When the app is ready to start, call `IOLoop.instance().start()`
+When the app is ready to start, call `IOLoop.current().start()`
instead of `reactor.run()`.
It is also possible to create a non-global reactor by calling
@@ -70,8 +70,10 @@ import datetime
import functools
import numbers
import socket
+import sys
import twisted.internet.abstract
+from twisted.internet.defer import Deferred
from twisted.internet.posixbase import PosixReactorBase
from twisted.internet.interfaces import \
IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor
@@ -84,6 +86,7 @@ import twisted.names.resolve
from zope.interface import implementer
+from tornado.concurrent import Future
from tornado.escape import utf8
from tornado import gen
import tornado.ioloop
@@ -147,6 +150,9 @@ class TornadoReactor(PosixReactorBase):
We override `mainLoop` instead of `doIteration` and must implement
timed call functionality on top of `IOLoop.add_timeout` rather than
using the implementation in `PosixReactorBase`.
+
+ .. versionchanged:: 4.1
+ The ``io_loop`` argument is deprecated.
"""
def __init__(self, io_loop=None):
if not io_loop:
@@ -356,7 +362,11 @@ class _TestReactor(TornadoReactor):
def install(io_loop=None):
- """Install this package as the default Twisted reactor."""
+ """Install this package as the default Twisted reactor.
+
+ .. versionchanged:: 4.1
+ The ``io_loop`` argument is deprecated.
+ """
if not io_loop:
io_loop = tornado.ioloop.IOLoop.current()
reactor = TornadoReactor(io_loop)
@@ -406,7 +416,8 @@ class TwistedIOLoop(tornado.ioloop.IOLoop):
because the ``SIGCHLD`` handlers used by Tornado and Twisted conflict
with each other.
"""
- def initialize(self, reactor=None):
+ def initialize(self, reactor=None, **kwargs):
+ super(TwistedIOLoop, self).initialize(**kwargs)
if reactor is None:
import twisted.internet.reactor
reactor = twisted.internet.reactor
@@ -512,6 +523,9 @@ class TwistedResolver(Resolver):
``socket.AF_UNSPEC``.
Requires Twisted 12.1 or newer.
+
+ .. versionchanged:: 4.1
+ The ``io_loop`` argument is deprecated.
"""
def initialize(self, io_loop=None):
self.io_loop = io_loop or IOLoop.current()
@@ -554,3 +568,18 @@ class TwistedResolver(Resolver):
(resolved_family, (resolved, port)),
]
raise gen.Return(result)
+
+if hasattr(gen.convert_yielded, 'register'):
+ @gen.convert_yielded.register(Deferred)
+ def _(d):
+ f = Future()
+
+ def errback(failure):
+ try:
+ failure.raiseException()
+ # Should never happen, but just in case
+ raise Exception("errback called without error")
+ except:
+ f.set_exc_info(sys.exc_info())
+ d.addCallbacks(f.set_result, errback)
+ return f
diff --git a/tornado/process.py b/tornado/process.py
index cea3dbd0..f580e192 100644
--- a/tornado/process.py
+++ b/tornado/process.py
@@ -29,6 +29,7 @@ import time
from binascii import hexlify
+from tornado.concurrent import Future
from tornado import ioloop
from tornado.iostream import PipeIOStream
from tornado.log import gen_log
@@ -48,6 +49,10 @@ except NameError:
long = int # py3
+# Re-export this exception for convenience.
+CalledProcessError = subprocess.CalledProcessError
+
+
def cpu_count():
"""Returns the number of processors on this machine."""
if multiprocessing is None:
@@ -191,6 +196,9 @@ class Subprocess(object):
``tornado.process.Subprocess.STREAM``, which will make the corresponding
attribute of the resulting Subprocess a `.PipeIOStream`.
* A new keyword argument ``io_loop`` may be used to pass in an IOLoop.
+
+ .. versionchanged:: 4.1
+ The ``io_loop`` argument is deprecated.
"""
STREAM = object()
@@ -255,6 +263,33 @@ class Subprocess(object):
Subprocess._waiting[self.pid] = self
Subprocess._try_cleanup_process(self.pid)
+ def wait_for_exit(self, raise_error=True):
+ """Returns a `.Future` which resolves when the process exits.
+
+ Usage::
+
+ ret = yield proc.wait_for_exit()
+
+ This is a coroutine-friendly alternative to `set_exit_callback`
+ (and a replacement for the blocking `subprocess.Popen.wait`).
+
+ By default, raises `subprocess.CalledProcessError` if the process
+ has a non-zero exit status. Use ``wait_for_exit(raise_error=False)``
+ to suppress this behavior and return the exit status without raising.
+
+ .. versionadded:: 4.2
+ """
+ future = Future()
+
+ def callback(ret):
+ if ret != 0 and raise_error:
+ # Unfortunately we don't have the original args any more.
+ future.set_exception(CalledProcessError(ret, None))
+ else:
+ future.set_result(ret)
+ self.set_exit_callback(callback)
+ return future
+
@classmethod
def initialize(cls, io_loop=None):
"""Initializes the ``SIGCHLD`` handler.
@@ -263,6 +298,9 @@ class Subprocess(object):
Note that the `.IOLoop` used for signal handling need not be the
same one used by individual Subprocess objects (as long as the
``IOLoops`` are each running in separate threads).
+
+ .. versionchanged:: 4.1
+ The ``io_loop`` argument is deprecated.
"""
if cls._initialized:
return
diff --git a/tornado/queues.py b/tornado/queues.py
new file mode 100644
index 00000000..55ab4834
--- /dev/null
+++ b/tornado/queues.py
@@ -0,0 +1,321 @@
+# Copyright 2015 The Tornado Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from __future__ import absolute_import, division, print_function, with_statement
+
+__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
+
+import collections
+import heapq
+
+from tornado import gen, ioloop
+from tornado.concurrent import Future
+from tornado.locks import Event
+
+
+class QueueEmpty(Exception):
+ """Raised by `.Queue.get_nowait` when the queue has no items."""
+ pass
+
+
+class QueueFull(Exception):
+ """Raised by `.Queue.put_nowait` when a queue is at its maximum size."""
+ pass
+
+
+def _set_timeout(future, timeout):
+ if timeout:
+ def on_timeout():
+ future.set_exception(gen.TimeoutError())
+ io_loop = ioloop.IOLoop.current()
+ timeout_handle = io_loop.add_timeout(timeout, on_timeout)
+ future.add_done_callback(
+ lambda _: io_loop.remove_timeout(timeout_handle))
+
+
+class Queue(object):
+ """Coordinate producer and consumer coroutines.
+
+ If maxsize is 0 (the default) the queue size is unbounded.
+
+ .. testcode::
+
+ q = queues.Queue(maxsize=2)
+
+ @gen.coroutine
+ def consumer():
+ while True:
+ item = yield q.get()
+ try:
+ print('Doing work on %s' % item)
+ yield gen.sleep(0.01)
+ finally:
+ q.task_done()
+
+ @gen.coroutine
+ def producer():
+ for item in range(5):
+ yield q.put(item)
+ print('Put %s' % item)
+
+ @gen.coroutine
+ def main():
+ consumer() # Start consumer.
+ yield producer() # Wait for producer to put all tasks.
+ yield q.join() # Wait for consumer to finish all tasks.
+ print('Done')
+
+ io_loop.run_sync(main)
+
+ .. testoutput::
+
+ Put 0
+ Put 1
+ Put 2
+ Doing work on 0
+ Doing work on 1
+ Put 3
+ Doing work on 2
+ Put 4
+ Doing work on 3
+ Doing work on 4
+ Done
+ """
+ def __init__(self, maxsize=0):
+ if maxsize is None:
+ raise TypeError("maxsize can't be None")
+
+ if maxsize < 0:
+ raise ValueError("maxsize can't be negative")
+
+ self._maxsize = maxsize
+ self._init()
+ self._getters = collections.deque([]) # Futures.
+ self._putters = collections.deque([]) # Pairs of (item, Future).
+ self._unfinished_tasks = 0
+ self._finished = Event()
+ self._finished.set()
+
+ @property
+ def maxsize(self):
+ """Number of items allowed in the queue."""
+ return self._maxsize
+
+ def qsize(self):
+ """Number of items in the queue."""
+ return len(self._queue)
+
+ def empty(self):
+ return not self._queue
+
+ def full(self):
+ if self.maxsize == 0:
+ return False
+ else:
+ return self.qsize() >= self.maxsize
+
+ def put(self, item, timeout=None):
+ """Put an item into the queue, perhaps waiting until there is room.
+
+ Returns a Future, which raises `tornado.gen.TimeoutError` after a
+ timeout.
+ """
+ try:
+ self.put_nowait(item)
+ except QueueFull:
+ future = Future()
+ self._putters.append((item, future))
+ _set_timeout(future, timeout)
+ return future
+ else:
+ return gen._null_future
+
+ def put_nowait(self, item):
+ """Put an item into the queue without blocking.
+
+ If no free slot is immediately available, raise `QueueFull`.
+ """
+ self._consume_expired()
+ if self._getters:
+ assert self.empty(), "queue non-empty, why are getters waiting?"
+ getter = self._getters.popleft()
+ self.__put_internal(item)
+ getter.set_result(self._get())
+ elif self.full():
+ raise QueueFull
+ else:
+ self.__put_internal(item)
+
+ def get(self, timeout=None):
+ """Remove and return an item from the queue.
+
+ Returns a Future which resolves once an item is available, or raises
+ `tornado.gen.TimeoutError` after a timeout.
+ """
+ future = Future()
+ try:
+ future.set_result(self.get_nowait())
+ except QueueEmpty:
+ self._getters.append(future)
+ _set_timeout(future, timeout)
+ return future
+
+ def get_nowait(self):
+ """Remove and return an item from the queue without blocking.
+
+ Return an item if one is immediately available, else raise
+ `QueueEmpty`.
+ """
+ self._consume_expired()
+ if self._putters:
+ assert self.full(), "queue not full, why are putters waiting?"
+ item, putter = self._putters.popleft()
+ self.__put_internal(item)
+ putter.set_result(None)
+ return self._get()
+ elif self.qsize():
+ return self._get()
+ else:
+ raise QueueEmpty
+
+ def task_done(self):
+ """Indicate that a formerly enqueued task is complete.
+
+ Used by queue consumers. For each `.get` used to fetch a task, a
+ subsequent call to `.task_done` tells the queue that the processing
+ on the task is complete.
+
+ If a `.join` is blocking, it resumes when all items have been
+ processed; that is, when every `.put` is matched by a `.task_done`.
+
+ Raises `ValueError` if called more times than `.put`.
+ """
+ if self._unfinished_tasks <= 0:
+ raise ValueError('task_done() called too many times')
+ self._unfinished_tasks -= 1
+ if self._unfinished_tasks == 0:
+ self._finished.set()
+
+ def join(self, timeout=None):
+ """Block until all items in the queue are processed.
+
+ Returns a Future, which raises `tornado.gen.TimeoutError` after a
+ timeout.
+ """
+ return self._finished.wait(timeout)
+
+ # These three are overridable in subclasses.
+ def _init(self):
+ self._queue = collections.deque()
+
+ def _get(self):
+ return self._queue.popleft()
+
+ def _put(self, item):
+ self._queue.append(item)
+ # End of the overridable methods.
+
+ def __put_internal(self, item):
+ self._unfinished_tasks += 1
+ self._finished.clear()
+ self._put(item)
+
+ def _consume_expired(self):
+ # Remove timed-out waiters.
+ while self._putters and self._putters[0][1].done():
+ self._putters.popleft()
+
+ while self._getters and self._getters[0].done():
+ self._getters.popleft()
+
+ def __repr__(self):
+ return '<%s at %s %s>' % (
+ type(self).__name__, hex(id(self)), self._format())
+
+ def __str__(self):
+ return '<%s %s>' % (type(self).__name__, self._format())
+
+ def _format(self):
+ result = 'maxsize=%r' % (self.maxsize, )
+ if getattr(self, '_queue', None):
+ result += ' queue=%r' % self._queue
+ if self._getters:
+ result += ' getters[%s]' % len(self._getters)
+ if self._putters:
+ result += ' putters[%s]' % len(self._putters)
+ if self._unfinished_tasks:
+ result += ' tasks=%s' % self._unfinished_tasks
+ return result
+
+
+class PriorityQueue(Queue):
+ """A `.Queue` that retrieves entries in priority order, lowest first.
+
+ Entries are typically tuples like ``(priority number, data)``.
+
+ .. testcode::
+
+ q = queues.PriorityQueue()
+ q.put((1, 'medium-priority item'))
+ q.put((0, 'high-priority item'))
+ q.put((10, 'low-priority item'))
+
+ print(q.get_nowait())
+ print(q.get_nowait())
+ print(q.get_nowait())
+
+ .. testoutput::
+
+ (0, 'high-priority item')
+ (1, 'medium-priority item')
+ (10, 'low-priority item')
+ """
+ def _init(self):
+ self._queue = []
+
+ def _put(self, item):
+ heapq.heappush(self._queue, item)
+
+ def _get(self):
+ return heapq.heappop(self._queue)
+
+
+class LifoQueue(Queue):
+ """A `.Queue` that retrieves the most recently put items first.
+
+ .. testcode::
+
+ q = queues.LifoQueue()
+ q.put(3)
+ q.put(2)
+ q.put(1)
+
+ print(q.get_nowait())
+ print(q.get_nowait())
+ print(q.get_nowait())
+
+ .. testoutput::
+
+ 1
+ 2
+ 3
+ """
+ def _init(self):
+ self._queue = []
+
+ def _put(self, item):
+ self._queue.append(item)
+
+ def _get(self):
+ return self._queue.pop()
diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py
index cf30f072..cf58e162 100644
--- a/tornado/simple_httpclient.py
+++ b/tornado/simple_httpclient.py
@@ -7,7 +7,7 @@ from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _
from tornado import httputil
from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters
from tornado.iostream import StreamClosedError
-from tornado.netutil import Resolver, OverrideResolver
+from tornado.netutil import Resolver, OverrideResolver, _client_ssl_defaults
from tornado.log import gen_log
from tornado import stack_context
from tornado.tcpclient import TCPClient
@@ -34,7 +34,7 @@ except ImportError:
ssl = None
try:
- import lib.certifi
+ import certifi
except ImportError:
certifi = None
@@ -50,9 +50,6 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
"""Non-blocking HTTP client with no external dependencies.
This class implements an HTTP 1.1 client on top of Tornado's IOStreams.
- It does not currently implement all applicable parts of the HTTP
- specification, but it does enough to work with major web service APIs.
-
Some features found in the curl-based AsyncHTTPClient are not yet
supported. In particular, proxies are not supported, connections
are not reused, and callers cannot select the network interface to be
@@ -60,25 +57,39 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
"""
def initialize(self, io_loop, max_clients=10,
hostname_mapping=None, max_buffer_size=104857600,
- resolver=None, defaults=None, max_header_size=None):
+ resolver=None, defaults=None, max_header_size=None,
+ max_body_size=None):
"""Creates a AsyncHTTPClient.
Only a single AsyncHTTPClient instance exists per IOLoop
in order to provide limitations on the number of pending connections.
- force_instance=True may be used to suppress this behavior.
+ ``force_instance=True`` may be used to suppress this behavior.
- max_clients is the number of concurrent requests that can be
- in progress. Note that this arguments are only used when the
- client is first created, and will be ignored when an existing
- client is reused.
+ Note that because of this implicit reuse, unless ``force_instance``
+ is used, only the first call to the constructor actually uses
+ its arguments. It is recommended to use the ``configure`` method
+ instead of the constructor to ensure that arguments take effect.
- hostname_mapping is a dictionary mapping hostnames to IP addresses.
+ ``max_clients`` is the number of concurrent requests that can be
+ in progress; when this limit is reached additional requests will be
+ queued. Note that time spent waiting in this queue still counts
+ against the ``request_timeout``.
+
+ ``hostname_mapping`` is a dictionary mapping hostnames to IP addresses.
It can be used to make local DNS changes when modifying system-wide
- settings like /etc/hosts is not possible or desirable (e.g. in
+ settings like ``/etc/hosts`` is not possible or desirable (e.g. in
unittests).
- max_buffer_size is the number of bytes that can be read by IOStream. It
- defaults to 100mb.
+ ``max_buffer_size`` (default 100MB) is the number of bytes
+ that can be read into memory at once. ``max_body_size``
+ (defaults to ``max_buffer_size``) is the largest response body
+ that the client will accept. Without a
+ ``streaming_callback``, the smaller of these two limits
+ applies; with a ``streaming_callback`` only ``max_body_size``
+ does.
+
+ .. versionchanged:: 4.2
+ Added the ``max_body_size`` argument.
"""
super(SimpleAsyncHTTPClient, self).initialize(io_loop,
defaults=defaults)
@@ -88,6 +99,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
self.waiting = {}
self.max_buffer_size = max_buffer_size
self.max_header_size = max_header_size
+ self.max_body_size = max_body_size
# TCPClient could create a Resolver for us, but we have to do it
# ourselves to support hostname_mapping.
if resolver:
@@ -135,10 +147,14 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
release_callback = functools.partial(self._release_fetch, key)
self._handle_request(request, release_callback, callback)
+ def _connection_class(self):
+ return _HTTPConnection
+
def _handle_request(self, request, release_callback, final_callback):
- _HTTPConnection(self.io_loop, self, request, release_callback,
- final_callback, self.max_buffer_size, self.tcp_client,
- self.max_header_size)
+ self._connection_class()(
+ self.io_loop, self, request, release_callback,
+ final_callback, self.max_buffer_size, self.tcp_client,
+ self.max_header_size, self.max_body_size)
def _release_fetch(self, key):
del self.active[key]
@@ -166,7 +182,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
def __init__(self, io_loop, client, request, release_callback,
final_callback, max_buffer_size, tcp_client,
- max_header_size):
+ max_header_size, max_body_size):
self.start_time = io_loop.time()
self.io_loop = io_loop
self.client = client
@@ -176,6 +192,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
self.max_buffer_size = max_buffer_size
self.tcp_client = tcp_client
self.max_header_size = max_header_size
+ self.max_body_size = max_body_size
self.code = None
self.headers = None
self.chunks = []
@@ -193,12 +210,8 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
netloc = self.parsed.netloc
if "@" in netloc:
userpass, _, netloc = netloc.rpartition("@")
- match = re.match(r'^(.+):(\d+)$', netloc)
- if match:
- host = match.group(1)
- port = int(match.group(2))
- else:
- host = netloc
+ host, port = httputil.split_host_and_port(netloc)
+ if port is None:
port = 443 if self.parsed.scheme == "https" else 80
if re.match(r'^\[.*\]$', host):
# raw ipv6 addresses in urls are enclosed in brackets
@@ -224,12 +237,24 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
def _get_ssl_options(self, scheme):
if scheme == "https":
+ if self.request.ssl_options is not None:
+ return self.request.ssl_options
+ # If we are using the defaults, don't construct a
+ # new SSLContext.
+ if (self.request.validate_cert and
+ self.request.ca_certs is None and
+ self.request.client_cert is None and
+ self.request.client_key is None):
+ return _client_ssl_defaults
ssl_options = {}
if self.request.validate_cert:
ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
if self.request.ca_certs is not None:
ssl_options["ca_certs"] = self.request.ca_certs
- else:
+ elif not hasattr(ssl, 'create_default_context'):
+ # When create_default_context is present,
+ # we can omit the "ca_certs" parameter entirely,
+ # which avoids the dependency on "certifi" for py34.
ssl_options["ca_certs"] = _default_ca_certs()
if self.request.client_key is not None:
ssl_options["keyfile"] = self.request.client_key
@@ -321,9 +346,9 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
body_present = (self.request.body is not None or
self.request.body_producer is not None)
if ((body_expected and not body_present) or
- (body_present and not body_expected)):
+ (body_present and not body_expected)):
raise ValueError(
- 'Body must %sbe None for method %s (unelss '
+ 'Body must %sbe None for method %s (unless '
'allow_nonstandard_methods is true)' %
('not ' if body_expected else '', self.request.method))
if self.request.expect_100_continue:
@@ -340,26 +365,30 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
self.request.headers["Accept-Encoding"] = "gzip"
req_path = ((self.parsed.path or '/') +
(('?' + self.parsed.query) if self.parsed.query else ''))
- self.stream.set_nodelay(True)
- self.connection = HTTP1Connection(
- self.stream, True,
- HTTP1ConnectionParameters(
- no_keep_alive=True,
- max_header_size=self.max_header_size,
- decompress=self.request.decompress_response),
- self._sockaddr)
+ self.connection = self._create_connection(stream)
start_line = httputil.RequestStartLine(self.request.method,
- req_path, 'HTTP/1.1')
+ req_path, '')
self.connection.write_headers(start_line, self.request.headers)
if self.request.expect_100_continue:
self._read_response()
else:
self._write_body(True)
+ def _create_connection(self, stream):
+ stream.set_nodelay(True)
+ connection = HTTP1Connection(
+ stream, True,
+ HTTP1ConnectionParameters(
+ no_keep_alive=True,
+ max_header_size=self.max_header_size,
+ max_body_size=self.max_body_size,
+ decompress=self.request.decompress_response),
+ self._sockaddr)
+ return connection
+
def _write_body(self, start_read):
if self.request.body is not None:
self.connection.write(self.request.body)
- self.connection.finish()
elif self.request.body_producer is not None:
fut = self.request.body_producer(self.connection.write)
if is_future(fut):
@@ -370,7 +399,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
self._read_response()
self.io_loop.add_future(fut, on_body_written)
return
- self.connection.finish()
+ self.connection.finish()
if start_read:
self._read_response()
diff --git a/tornado/tcpclient.py b/tornado/tcpclient.py
index 0abbea20..f594d91b 100644
--- a/tornado/tcpclient.py
+++ b/tornado/tcpclient.py
@@ -111,6 +111,7 @@ class _Connector(object):
if self.timeout is not None:
# If the first attempt failed, don't wait for the
# timeout to try an address from the secondary queue.
+ self.io_loop.remove_timeout(self.timeout)
self.on_timeout()
return
self.clear_timeout()
@@ -135,6 +136,9 @@ class _Connector(object):
class TCPClient(object):
"""A non-blocking TCP connection factory.
+
+ .. versionchanged:: 4.1
+ The ``io_loop`` argument is deprecated.
"""
def __init__(self, resolver=None, io_loop=None):
self.io_loop = io_loop or IOLoop.current()
diff --git a/tornado/tcpserver.py b/tornado/tcpserver.py
index 427acec5..c9d148a8 100644
--- a/tornado/tcpserver.py
+++ b/tornado/tcpserver.py
@@ -41,14 +41,15 @@ class TCPServer(object):
To use `TCPServer`, define a subclass which overrides the `handle_stream`
method.
- To make this server serve SSL traffic, send the ssl_options dictionary
- argument with the arguments required for the `ssl.wrap_socket` method,
- including "certfile" and "keyfile"::
+ To make this server serve SSL traffic, send the ``ssl_options`` keyword
+ argument with an `ssl.SSLContext` object. For compatibility with older
+ versions of Python ``ssl_options`` may also be a dictionary of keyword
+ arguments for the `ssl.wrap_socket` method.::
- TCPServer(ssl_options={
- "certfile": os.path.join(data_dir, "mydomain.crt"),
- "keyfile": os.path.join(data_dir, "mydomain.key"),
- })
+ ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+ ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"),
+ os.path.join(data_dir, "mydomain.key"))
+ TCPServer(ssl_options=ssl_ctx)
`TCPServer` initialization follows one of three patterns:
@@ -56,14 +57,14 @@ class TCPServer(object):
server = TCPServer()
server.listen(8888)
- IOLoop.instance().start()
+ IOLoop.current().start()
2. `bind`/`start`: simple multi-process::
server = TCPServer()
server.bind(8888)
server.start(0) # Forks multiple sub-processes
- IOLoop.instance().start()
+ IOLoop.current().start()
When using this interface, an `.IOLoop` must *not* be passed
to the `TCPServer` constructor. `start` will always start
@@ -75,7 +76,7 @@ class TCPServer(object):
tornado.process.fork_processes(0)
server = TCPServer()
server.add_sockets(sockets)
- IOLoop.instance().start()
+ IOLoop.current().start()
The `add_sockets` interface is more complicated, but it can be
used with `tornado.process.fork_processes` to give you more
@@ -95,7 +96,7 @@ class TCPServer(object):
self._pending_sockets = []
self._started = False
self.max_buffer_size = max_buffer_size
- self.read_chunk_size = None
+ self.read_chunk_size = read_chunk_size
# Verify the SSL options. Otherwise we don't get errors until clients
# connect. This doesn't verify that the keys are legitimate, but
@@ -212,7 +213,20 @@ class TCPServer(object):
sock.close()
def handle_stream(self, stream, address):
- """Override to handle a new `.IOStream` from an incoming connection."""
+ """Override to handle a new `.IOStream` from an incoming connection.
+
+ This method may be a coroutine; if so any exceptions it raises
+ asynchronously will be logged. Accepting of incoming connections
+ will not be blocked by this coroutine.
+
+ If this `TCPServer` is configured for SSL, ``handle_stream``
+ may be called before the SSL handshake has completed. Use
+ `.SSLIOStream.wait_for_handshake` if you need to verify the client's
+ certificate or use NPN/ALPN.
+
+ .. versionchanged:: 4.2
+ Added the option for this method to be a coroutine.
+ """
raise NotImplementedError()
def _handle_connection(self, connection, address):
@@ -252,6 +266,8 @@ class TCPServer(object):
stream = IOStream(connection, io_loop=self.io_loop,
max_buffer_size=self.max_buffer_size,
read_chunk_size=self.read_chunk_size)
- self.handle_stream(stream, address)
+ future = self.handle_stream(stream, address)
+ if future is not None:
+ self.io_loop.add_future(future, lambda f: f.result())
except Exception:
app_log.error("Error in connection callback", exc_info=True)
diff --git a/tornado/test/README b/tornado/test/README
deleted file mode 100644
index 33edba98..00000000
--- a/tornado/test/README
+++ /dev/null
@@ -1,4 +0,0 @@
-Test coverage is almost non-existent, but it's a start. Be sure to
-set PYTHONPATH appropriately (generally to the root directory of your
-tornado checkout) when running tests to make sure you're getting the
-version of the tornado package that you expect.
\ No newline at end of file
diff --git a/tornado/test/asyncio_test.py b/tornado/test/asyncio_test.py
new file mode 100644
index 00000000..1be0e54f
--- /dev/null
+++ b/tornado/test/asyncio_test.py
@@ -0,0 +1,69 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from __future__ import absolute_import, division, print_function, with_statement
+
+import sys
+import textwrap
+
+from tornado import gen
+from tornado.testing import AsyncTestCase, gen_test
+from tornado.test.util import unittest
+
+try:
+ from tornado.platform.asyncio import asyncio, AsyncIOLoop
+except ImportError:
+ asyncio = None
+
+skipIfNoSingleDispatch = unittest.skipIf(
+ gen.singledispatch is None, "singledispatch module not present")
+
+
+@unittest.skipIf(asyncio is None, "asyncio module not present")
+class AsyncIOLoopTest(AsyncTestCase):
+ def get_new_ioloop(self):
+ io_loop = AsyncIOLoop()
+ asyncio.set_event_loop(io_loop.asyncio_loop)
+ return io_loop
+
+ def test_asyncio_callback(self):
+ # Basic test that the asyncio loop is set up correctly.
+ asyncio.get_event_loop().call_soon(self.stop)
+ self.wait()
+
+ @skipIfNoSingleDispatch
+ @gen_test
+ def test_asyncio_future(self):
+ # Test that we can yield an asyncio future from a tornado coroutine.
+ # Without 'yield from', we must wrap coroutines in asyncio.async.
+ x = yield asyncio.async(
+ asyncio.get_event_loop().run_in_executor(None, lambda: 42))
+ self.assertEqual(x, 42)
+
+ @unittest.skipIf(sys.version_info < (3, 3),
+ 'PEP 380 not available')
+ @skipIfNoSingleDispatch
+ @gen_test
+ def test_asyncio_yield_from(self):
+ # Test that we can use asyncio coroutines with 'yield from'
+ # instead of asyncio.async(). This requires python 3.3 syntax.
+ global_namespace = dict(globals(), **locals())
+ local_namespace = {}
+ exec(textwrap.dedent("""
+ @gen.coroutine
+ def f():
+ event_loop = asyncio.get_event_loop()
+ x = yield from event_loop.run_in_executor(None, lambda: 42)
+ return x
+ """), global_namespace, local_namespace)
+ result = yield local_namespace['f']()
+ self.assertEqual(result, 42)
diff --git a/tornado/test/auth_test.py b/tornado/test/auth_test.py
index 254e1ae1..541ecf16 100644
--- a/tornado/test/auth_test.py
+++ b/tornado/test/auth_test.py
@@ -5,7 +5,7 @@
from __future__ import absolute_import, division, print_function, with_statement
-from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, GoogleMixin, AuthError
+from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, AuthError
from tornado.concurrent import Future
from tornado.escape import json_decode
from tornado import gen
@@ -238,28 +238,6 @@ class TwitterServerVerifyCredentialsHandler(RequestHandler):
self.write(dict(screen_name='foo', name='Foo'))
-class GoogleOpenIdClientLoginHandler(RequestHandler, GoogleMixin):
- def initialize(self, test):
- self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
-
- @asynchronous
- def get(self):
- if self.get_argument("openid.mode", None):
- self.get_authenticated_user(self.on_user)
- return
- res = self.authenticate_redirect()
- assert isinstance(res, Future)
- assert res.done()
-
- def on_user(self, user):
- if user is None:
- raise Exception("user is None")
- self.finish(user)
-
- def get_auth_http_client(self):
- return self.settings['http_client']
-
-
class AuthTest(AsyncHTTPTestCase):
def get_app(self):
return Application(
@@ -286,7 +264,6 @@ class AuthTest(AsyncHTTPTestCase):
('/twitter/client/login_gen_coroutine', TwitterClientLoginGenCoroutineHandler, dict(test=self)),
('/twitter/client/show_user', TwitterClientShowUserHandler, dict(test=self)),
('/twitter/client/show_user_future', TwitterClientShowUserFutureHandler, dict(test=self)),
- ('/google/client/openid_login', GoogleOpenIdClientLoginHandler, dict(test=self)),
# simulated servers
('/openid/server/authenticate', OpenIdServerAuthenticateHandler),
@@ -436,16 +413,3 @@ class AuthTest(AsyncHTTPTestCase):
response = self.fetch('/twitter/client/show_user_future?name=error')
self.assertEqual(response.code, 500)
self.assertIn(b'Error response HTTP 500', response.body)
-
- def test_google_redirect(self):
- # same as test_openid_redirect
- response = self.fetch('/google/client/openid_login', follow_redirects=False)
- self.assertEqual(response.code, 302)
- self.assertTrue(
- '/openid/server/authenticate?' in response.headers['Location'])
-
- def test_google_get_user(self):
- response = self.fetch('/google/client/openid_login?openid.mode=blah&openid.ns.ax=http://openid.net/srv/ax/1.0&openid.ax.type.email=http://axschema.org/contact/email&openid.ax.value.email=foo@example.com', follow_redirects=False)
- response.rethrow()
- parsed = json_decode(response.body)
- self.assertEqual(parsed["email"], "foo@example.com")
diff --git a/tornado/test/concurrent_test.py b/tornado/test/concurrent_test.py
index 5e93ad6a..bf90ad0e 100644
--- a/tornado/test/concurrent_test.py
+++ b/tornado/test/concurrent_test.py
@@ -21,13 +21,14 @@ import socket
import sys
import traceback
-from tornado.concurrent import Future, return_future, ReturnValueIgnoredError
+from tornado.concurrent import Future, return_future, ReturnValueIgnoredError, run_on_executor
from tornado.escape import utf8, to_unicode
from tornado import gen
from tornado.iostream import IOStream
from tornado import stack_context
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, LogTrapTestCase, bind_unused_port, gen_test
+from tornado.test.util import unittest
try:
@@ -334,3 +335,81 @@ class DecoratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
class GeneratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
client_class = GeneratorCapClient
+
+
+@unittest.skipIf(futures is None, "concurrent.futures module not present")
+class RunOnExecutorTest(AsyncTestCase):
+ @gen_test
+ def test_no_calling(self):
+ class Object(object):
+ def __init__(self, io_loop):
+ self.io_loop = io_loop
+ self.executor = futures.thread.ThreadPoolExecutor(1)
+
+ @run_on_executor
+ def f(self):
+ return 42
+
+ o = Object(io_loop=self.io_loop)
+ answer = yield o.f()
+ self.assertEqual(answer, 42)
+
+ @gen_test
+ def test_call_with_no_args(self):
+ class Object(object):
+ def __init__(self, io_loop):
+ self.io_loop = io_loop
+ self.executor = futures.thread.ThreadPoolExecutor(1)
+
+ @run_on_executor()
+ def f(self):
+ return 42
+
+ o = Object(io_loop=self.io_loop)
+ answer = yield o.f()
+ self.assertEqual(answer, 42)
+
+ @gen_test
+ def test_call_with_io_loop(self):
+ class Object(object):
+ def __init__(self, io_loop):
+ self._io_loop = io_loop
+ self.executor = futures.thread.ThreadPoolExecutor(1)
+
+ @run_on_executor(io_loop='_io_loop')
+ def f(self):
+ return 42
+
+ o = Object(io_loop=self.io_loop)
+ answer = yield o.f()
+ self.assertEqual(answer, 42)
+
+ @gen_test
+ def test_call_with_executor(self):
+ class Object(object):
+ def __init__(self, io_loop):
+ self.io_loop = io_loop
+ self.__executor = futures.thread.ThreadPoolExecutor(1)
+
+ @run_on_executor(executor='_Object__executor')
+ def f(self):
+ return 42
+
+ o = Object(io_loop=self.io_loop)
+ answer = yield o.f()
+ self.assertEqual(answer, 42)
+
+ @gen_test
+ def test_call_with_both(self):
+ class Object(object):
+ def __init__(self, io_loop):
+ self._io_loop = io_loop
+ self.__executor = futures.thread.ThreadPoolExecutor(1)
+
+ @run_on_executor(io_loop='_io_loop', executor='_Object__executor')
+ def f(self):
+ return 42
+
+ o = Object(io_loop=self.io_loop)
+ answer = yield o.f()
+ self.assertEqual(answer, 42)
diff --git a/tornado/test/curl_httpclient_test.py b/tornado/test/curl_httpclient_test.py
index 8d7065df..3ac21f4d 100644
--- a/tornado/test/curl_httpclient_test.py
+++ b/tornado/test/curl_httpclient_test.py
@@ -8,7 +8,7 @@ from tornado.stack_context import ExceptionStackContext
from tornado.testing import AsyncHTTPTestCase
from tornado.test import httpclient_test
from tornado.test.util import unittest
-from tornado.web import Application, RequestHandler, URLSpec
+from tornado.web import Application, RequestHandler
try:
diff --git a/tornado/test/escape_test.py b/tornado/test/escape_test.py
index f6404288..98a23463 100644
--- a/tornado/test/escape_test.py
+++ b/tornado/test/escape_test.py
@@ -217,9 +217,8 @@ class EscapeTestCase(unittest.TestCase):
self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9")
def test_squeeze(self):
- self.assertEqual(squeeze(u('sequences of whitespace chars'))
- , u('sequences of whitespace chars'))
-
+ self.assertEqual(squeeze(u('sequences of whitespace chars')), u('sequences of whitespace chars'))
+
def test_recursive_unicode(self):
tests = {
'dict': {b"foo": b"bar"},
diff --git a/tornado/test/gen_test.py b/tornado/test/gen_test.py
index a15cdf73..fdaa0ec8 100644
--- a/tornado/test/gen_test.py
+++ b/tornado/test/gen_test.py
@@ -62,6 +62,11 @@ class GenEngineTest(AsyncTestCase):
def async_future(self, result, callback):
self.io_loop.add_callback(callback, result)
+ @gen.coroutine
+ def async_exception(self, e):
+ yield gen.moment
+ raise e
+
def test_no_yield(self):
@gen.engine
def f():
@@ -385,11 +390,56 @@ class GenEngineTest(AsyncTestCase):
results = yield [self.async_future(1), self.async_future(2)]
self.assertEqual(results, [1, 2])
+ @gen_test
+ def test_multi_future_duplicate(self):
+ f = self.async_future(2)
+ results = yield [self.async_future(1), f, self.async_future(3), f]
+ self.assertEqual(results, [1, 2, 3, 2])
+
@gen_test
def test_multi_dict_future(self):
results = yield dict(foo=self.async_future(1), bar=self.async_future(2))
self.assertEqual(results, dict(foo=1, bar=2))
+ @gen_test
+ def test_multi_exceptions(self):
+ with ExpectLog(app_log, "Multiple exceptions in yield list"):
+ with self.assertRaises(RuntimeError) as cm:
+ yield gen.Multi([self.async_exception(RuntimeError("error 1")),
+ self.async_exception(RuntimeError("error 2"))])
+ self.assertEqual(str(cm.exception), "error 1")
+
+ # With only one exception, no error is logged.
+ with self.assertRaises(RuntimeError):
+ yield gen.Multi([self.async_exception(RuntimeError("error 1")),
+ self.async_future(2)])
+
+ # Exception logging may be explicitly quieted.
+ with self.assertRaises(RuntimeError):
+ yield gen.Multi([self.async_exception(RuntimeError("error 1")),
+ self.async_exception(RuntimeError("error 2"))],
+ quiet_exceptions=RuntimeError)
+
+ @gen_test
+ def test_multi_future_exceptions(self):
+ with ExpectLog(app_log, "Multiple exceptions in yield list"):
+ with self.assertRaises(RuntimeError) as cm:
+ yield [self.async_exception(RuntimeError("error 1")),
+ self.async_exception(RuntimeError("error 2"))]
+ self.assertEqual(str(cm.exception), "error 1")
+
+ # With only one exception, no error is logged.
+ with self.assertRaises(RuntimeError):
+ yield [self.async_exception(RuntimeError("error 1")),
+ self.async_future(2)]
+
+ # Exception logging may be explicitly quieted.
+ with self.assertRaises(RuntimeError):
+ yield gen.multi_future(
+ [self.async_exception(RuntimeError("error 1")),
+ self.async_exception(RuntimeError("error 2"))],
+ quiet_exceptions=RuntimeError)
+
def test_arguments(self):
@gen.engine
def f():
@@ -816,6 +866,7 @@ class GenCoroutineTest(AsyncTestCase):
@gen_test
def test_moment(self):
calls = []
+
@gen.coroutine
def f(name, yieldable):
for i in range(5):
@@ -838,6 +889,34 @@ class GenCoroutineTest(AsyncTestCase):
yield [f('a', gen.moment), f('b', immediate)]
self.assertEqual(''.join(calls), 'abbbbbaaaa')
+ @gen_test
+ def test_sleep(self):
+ yield gen.sleep(0.01)
+ self.finished = True
+
+ @skipBefore33
+ @gen_test
+ def test_py3_leak_exception_context(self):
+ class LeakedException(Exception):
+ pass
+
+ @gen.coroutine
+ def inner(iteration):
+ raise LeakedException(iteration)
+
+ try:
+ yield inner(1)
+ except LeakedException as e:
+ self.assertEqual(str(e), "1")
+ self.assertIsNone(e.__context__)
+
+ try:
+ yield inner(2)
+ except LeakedException as e:
+ self.assertEqual(str(e), "2")
+ self.assertIsNone(e.__context__)
+
+ self.finished = True
class GenSequenceHandler(RequestHandler):
@asynchronous
@@ -1031,7 +1110,7 @@ class WithTimeoutTest(AsyncTestCase):
self.io_loop.add_timeout(datetime.timedelta(seconds=0.1),
lambda: future.set_result('asdf'))
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
- future)
+ future, io_loop=self.io_loop)
self.assertEqual(result, 'asdf')
@gen_test
@@ -1039,16 +1118,17 @@ class WithTimeoutTest(AsyncTestCase):
future = Future()
self.io_loop.add_timeout(
datetime.timedelta(seconds=0.1),
- lambda: future.set_exception(ZeroDivisionError))
+ lambda: future.set_exception(ZeroDivisionError()))
with self.assertRaises(ZeroDivisionError):
- yield gen.with_timeout(datetime.timedelta(seconds=3600), future)
+ yield gen.with_timeout(datetime.timedelta(seconds=3600),
+ future, io_loop=self.io_loop)
@gen_test
def test_already_resolved(self):
future = Future()
future.set_result('asdf')
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
- future)
+ future, io_loop=self.io_loop)
self.assertEqual(result, 'asdf')
@unittest.skipIf(futures is None, 'futures module not present')
@@ -1067,5 +1147,117 @@ class WithTimeoutTest(AsyncTestCase):
executor.submit(lambda: None))
+class WaitIteratorTest(AsyncTestCase):
+ @gen_test
+ def test_empty_iterator(self):
+ g = gen.WaitIterator()
+ self.assertTrue(g.done(), 'empty generator iterated')
+
+ with self.assertRaises(ValueError):
+ g = gen.WaitIterator(False, bar=False)
+
+ self.assertEqual(g.current_index, None, "bad nil current index")
+ self.assertEqual(g.current_future, None, "bad nil current future")
+
+ @gen_test
+ def test_already_done(self):
+ f1 = Future()
+ f2 = Future()
+ f3 = Future()
+ f1.set_result(24)
+ f2.set_result(42)
+ f3.set_result(84)
+
+ g = gen.WaitIterator(f1, f2, f3)
+ i = 0
+ while not g.done():
+ r = yield g.next()
+ # Order is not guaranteed, but the current implementation
+ # preserves ordering of already-done Futures.
+ if i == 0:
+ self.assertEqual(g.current_index, 0)
+ self.assertIs(g.current_future, f1)
+ self.assertEqual(r, 24)
+ elif i == 1:
+ self.assertEqual(g.current_index, 1)
+ self.assertIs(g.current_future, f2)
+ self.assertEqual(r, 42)
+ elif i == 2:
+ self.assertEqual(g.current_index, 2)
+ self.assertIs(g.current_future, f3)
+ self.assertEqual(r, 84)
+ i += 1
+
+ self.assertEqual(g.current_index, None, "bad nil current index")
+ self.assertEqual(g.current_future, None, "bad nil current future")
+
+ dg = gen.WaitIterator(f1=f1, f2=f2)
+
+ while not dg.done():
+ dr = yield dg.next()
+ if dg.current_index == "f1":
+ self.assertTrue(dg.current_future == f1 and dr == 24,
+ "WaitIterator dict status incorrect")
+ elif dg.current_index == "f2":
+ self.assertTrue(dg.current_future == f2 and dr == 42,
+ "WaitIterator dict status incorrect")
+ else:
+ self.fail("got bad WaitIterator index {}".format(
+ dg.current_index))
+
+ i += 1
+
+ self.assertEqual(dg.current_index, None, "bad nil current index")
+ self.assertEqual(dg.current_future, None, "bad nil current future")
+
+ def finish_coroutines(self, iteration, futures):
+ if iteration == 3:
+ futures[2].set_result(24)
+ elif iteration == 5:
+ futures[0].set_exception(ZeroDivisionError())
+ elif iteration == 8:
+ futures[1].set_result(42)
+ futures[3].set_result(84)
+
+ if iteration < 8:
+ self.io_loop.add_callback(self.finish_coroutines, iteration + 1, futures)
+
+ @gen_test
+ def test_iterator(self):
+ futures = [Future(), Future(), Future(), Future()]
+
+ self.finish_coroutines(0, futures)
+
+ g = gen.WaitIterator(*futures)
+
+ i = 0
+ while not g.done():
+ try:
+ r = yield g.next()
+ except ZeroDivisionError:
+ self.assertIs(g.current_future, futures[0],
+ 'exception future invalid')
+ else:
+ if i == 0:
+ self.assertEqual(r, 24, 'iterator value incorrect')
+ self.assertEqual(g.current_index, 2, 'wrong index')
+ elif i == 2:
+ self.assertEqual(r, 42, 'iterator value incorrect')
+ self.assertEqual(g.current_index, 1, 'wrong index')
+ elif i == 3:
+ self.assertEqual(r, 84, 'iterator value incorrect')
+ self.assertEqual(g.current_index, 3, 'wrong index')
+ i += 1
+
+ @gen_test
+ def test_no_ref(self):
+ # In this usage, there is no direct hard reference to the
+ # WaitIterator itself, only the Future it returns. Since
+ # WaitIterator uses weak references internally to improve GC
+ # performance, this used to cause problems.
+ yield gen.with_timeout(datetime.timedelta(seconds=0.1),
+ gen.WaitIterator(gen.sleep(0)).next())
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/tornado/test/gettext_translations/extract_me.py b/tornado/test/gettext_translations/extract_me.py
index 75406ecc..45321cce 100644
--- a/tornado/test/gettext_translations/extract_me.py
+++ b/tornado/test/gettext_translations/extract_me.py
@@ -1,11 +1,16 @@
+# flake8: noqa
# Dummy source file to allow creation of the initial .po file in the
# same way as a real project. I'm not entirely sure about the real
# workflow here, but this seems to work.
#
-# 1) xgettext --language=Python --keyword=_:1,2 -d tornado_test extract_me.py -o tornado_test.po
-# 2) Edit tornado_test.po, setting CHARSET and setting msgstr
+# 1) xgettext --language=Python --keyword=_:1,2 --keyword=pgettext:1c,2 --keyword=pgettext:1c,2,3 extract_me.py -o tornado_test.po
+# 2) Edit tornado_test.po, setting CHARSET, Plural-Forms and setting msgstr
# 3) msgfmt tornado_test.po -o tornado_test.mo
# 4) Put the file in the proper location: $LANG/LC_MESSAGES
from __future__ import absolute_import, division, print_function, with_statement
_("school")
+pgettext("law", "right")
+pgettext("good", "right")
+pgettext("organization", "club", "clubs", 1)
+pgettext("stick", "club", "clubs", 1)
diff --git a/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo b/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo
index 089f6c7a..a97bf9c5 100644
Binary files a/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo and b/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo differ
diff --git a/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po b/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po
index 732ee6da..88d72c86 100644
--- a/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po
+++ b/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
-"POT-Creation-Date: 2012-06-14 01:10-0700\n"
+"POT-Creation-Date: 2015-01-27 11:05+0300\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME \n"
"Language-Team: LANGUAGE \n"
@@ -16,7 +16,32 @@ msgstr ""
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
+"Plural-Forms: nplurals=2; plural=(n > 1);\n"
-#: extract_me.py:1
+#: extract_me.py:11
msgid "school"
msgstr "école"
+
+#: extract_me.py:12
+msgctxt "law"
+msgid "right"
+msgstr "le droit"
+
+#: extract_me.py:13
+msgctxt "good"
+msgid "right"
+msgstr "le bien"
+
+#: extract_me.py:14
+msgctxt "organization"
+msgid "club"
+msgid_plural "clubs"
+msgstr[0] "le club"
+msgstr[1] "les clubs"
+
+#: extract_me.py:15
+msgctxt "stick"
+msgid "club"
+msgid_plural "clubs"
+msgstr[0] "le bâton"
+msgstr[1] "les bâtons"
diff --git a/tornado/test/httpclient_test.py b/tornado/test/httpclient_test.py
index bfb50d87..ecc63e4a 100644
--- a/tornado/test/httpclient_test.py
+++ b/tornado/test/httpclient_test.py
@@ -12,6 +12,7 @@ import datetime
from io import BytesIO
from tornado.escape import utf8
+from tornado import gen
from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
@@ -52,9 +53,12 @@ class RedirectHandler(RequestHandler):
class ChunkHandler(RequestHandler):
+ @gen.coroutine
def get(self):
self.write("asdf")
self.flush()
+ # Wait a bit to ensure the chunks are sent and received separately.
+ yield gen.sleep(0.01)
self.write("qwer")
@@ -178,6 +182,8 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
sock, port = bind_unused_port()
with closing(sock):
def write_response(stream, request_data):
+ if b"HTTP/1." not in request_data:
+ self.skipTest("requires HTTP/1.x")
stream.write(b"""\
HTTP/1.1 200 OK
Transfer-Encoding: chunked
@@ -300,23 +306,26 @@ Transfer-Encoding: chunked
chunks = []
def header_callback(header_line):
- if header_line.startswith('HTTP/'):
+ if header_line.startswith('HTTP/1.1 101'):
+ # Upgrading to HTTP/2
+ pass
+ elif header_line.startswith('HTTP/'):
first_line.append(header_line)
elif header_line != '\r\n':
k, v = header_line.split(':', 1)
- headers[k] = v.strip()
+ headers[k.lower()] = v.strip()
def streaming_callback(chunk):
# All header callbacks are run before any streaming callbacks,
# so the header data is available to process the data as it
# comes in.
- self.assertEqual(headers['Content-Type'], 'text/html; charset=UTF-8')
+ self.assertEqual(headers['content-type'], 'text/html; charset=UTF-8')
chunks.append(chunk)
self.fetch('/chunk', header_callback=header_callback,
streaming_callback=streaming_callback)
- self.assertEqual(len(first_line), 1)
- self.assertRegexpMatches(first_line[0], 'HTTP/1.[01] 200 OK\r\n')
+ self.assertEqual(len(first_line), 1, first_line)
+ self.assertRegexpMatches(first_line[0], 'HTTP/[0-9]\\.[0-9] 200.*\r\n')
self.assertEqual(chunks, [b'asdf', b'qwer'])
def test_header_callback_stack_context(self):
@@ -327,7 +336,7 @@ Transfer-Encoding: chunked
return True
def header_callback(header_line):
- if header_line.startswith('Content-Type:'):
+ if header_line.lower().startswith('content-type:'):
1 / 0
with ExceptionStackContext(error_handler):
@@ -404,6 +413,11 @@ Transfer-Encoding: chunked
self.assertEqual(context.exception.code, 404)
self.assertEqual(context.exception.response.code, 404)
+ @gen_test
+ def test_future_http_error_no_raise(self):
+ response = yield self.http_client.fetch(self.get_url('/notfound'), raise_error=False)
+ self.assertEqual(response.code, 404)
+
@gen_test
def test_reuse_request_from_response(self):
# The response.request attribute should be an HTTPRequest, not
@@ -454,7 +468,7 @@ Transfer-Encoding: chunked
# Twisted's reactor does not. The removeReader call fails and so
# do all future removeAll calls (which our tests do at cleanup).
#
- #def test_post_307(self):
+ # def test_post_307(self):
# response = self.fetch("/redirect?status=307&url=/post",
# method="POST", body=b"arg1=foo&arg2=bar")
# self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
@@ -536,14 +550,19 @@ class SyncHTTPClientTest(unittest.TestCase):
def tearDown(self):
def stop_server():
self.server.stop()
- self.server_ioloop.stop()
+ # Delay the shutdown of the IOLoop by one iteration because
+ # the server may still have some cleanup work left when
+ # the client finishes with the response (this is noticable
+ # with http/2, which leaves a Future with an unexamined
+ # StreamClosedError on the loop).
+ self.server_ioloop.add_callback(self.server_ioloop.stop)
self.server_ioloop.add_callback(stop_server)
self.server_thread.join()
self.http_client.close()
self.server_ioloop.close(all_fds=True)
def get_url(self, path):
- return 'http://localhost:%d%s' % (self.port, path)
+ return 'http://127.0.0.1:%d%s' % (self.port, path)
def test_sync_client(self):
response = self.http_client.fetch(self.get_url('/'))
@@ -584,5 +603,5 @@ class HTTPRequestTestCase(unittest.TestCase):
def test_if_modified_since(self):
http_date = datetime.datetime.utcnow()
request = HTTPRequest('http://example.com', if_modified_since=http_date)
- self.assertEqual(request.headers,
- {'If-Modified-Since': format_timestamp(http_date)})
+ self.assertEqual(request.headers,
+ {'If-Modified-Since': format_timestamp(http_date)})
diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py
index 156a027b..f05599dd 100644
--- a/tornado/test/httpserver_test.py
+++ b/tornado/test/httpserver_test.py
@@ -32,6 +32,7 @@ def read_stream_body(stream, callback):
"""Reads an HTTP response from `stream` and runs callback with its
headers and body."""
chunks = []
+
class Delegate(HTTPMessageDelegate):
def headers_received(self, start_line, headers):
self.headers = headers
@@ -161,19 +162,22 @@ class BadSSLOptionsTest(unittest.TestCase):
application = Application()
module_dir = os.path.dirname(__file__)
existing_certificate = os.path.join(module_dir, 'test.crt')
+ existing_key = os.path.join(module_dir, 'test.key')
- self.assertRaises(ValueError, HTTPServer, application, ssl_options={
- "certfile": "/__mising__.crt",
+ self.assertRaises((ValueError, IOError),
+ HTTPServer, application, ssl_options={
+ "certfile": "/__mising__.crt",
})
- self.assertRaises(ValueError, HTTPServer, application, ssl_options={
- "certfile": existing_certificate,
- "keyfile": "/__missing__.key"
+ self.assertRaises((ValueError, IOError),
+ HTTPServer, application, ssl_options={
+ "certfile": existing_certificate,
+ "keyfile": "/__missing__.key"
})
# This actually works because both files exist
HTTPServer(application, ssl_options={
"certfile": existing_certificate,
- "keyfile": existing_certificate
+ "keyfile": existing_key,
})
@@ -195,14 +199,14 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
def get_app(self):
return Application(self.get_handlers())
- def raw_fetch(self, headers, body):
+ def raw_fetch(self, headers, body, newline=b"\r\n"):
with closing(IOStream(socket.socket())) as stream:
stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
stream.write(
- b"\r\n".join(headers +
- [utf8("Content-Length: %d\r\n" % len(body))]) +
- b"\r\n" + body)
+ newline.join(headers +
+ [utf8("Content-Length: %d" % len(body))]) +
+ newline + newline + body)
read_stream_body(stream, self.stop)
headers, body = self.wait()
return body
@@ -232,12 +236,19 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
self.assertEqual(u("\u00f3"), data["filename"])
self.assertEqual(u("\u00fa"), data["filebody"])
+ def test_newlines(self):
+ # We support both CRLF and bare LF as line separators.
+ for newline in (b"\r\n", b"\n"):
+ response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"",
+ newline=newline)
+ self.assertEqual(response, b'Hello world')
+
def test_100_continue(self):
# Run through a 100-continue interaction by hand:
# When given Expect: 100-continue, we get a 100 response after the
# headers, and then the real response after the body.
stream = IOStream(socket.socket(), io_loop=self.io_loop)
- stream.connect(("localhost", self.get_http_port()), callback=self.stop)
+ stream.connect(("127.0.0.1", self.get_http_port()), callback=self.stop)
self.wait()
stream.write(b"\r\n".join([b"POST /hello HTTP/1.1",
b"Content-Length: 1024",
@@ -374,7 +385,7 @@ class HTTPServerRawTest(AsyncHTTPTestCase):
def setUp(self):
super(HTTPServerRawTest, self).setUp()
self.stream = IOStream(socket.socket())
- self.stream.connect(('localhost', self.get_http_port()), self.stop)
+ self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
def tearDown(self):
@@ -555,7 +566,7 @@ class UnixSocketTest(AsyncTestCase):
self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
self.stream.read_until(b"\r\n", self.stop)
response = self.wait()
- self.assertEqual(response, b"HTTP/1.0 200 OK\r\n")
+ self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
self.stream.read_until(b"\r\n\r\n", self.stop)
headers = HTTPHeaders.parse(self.wait().decode('latin1'))
self.stream.read_bytes(int(headers["Content-Length"]), self.stop)
@@ -582,6 +593,7 @@ class KeepAliveTest(AsyncHTTPTestCase):
class HelloHandler(RequestHandler):
def get(self):
self.finish('Hello world')
+
def post(self):
self.finish('Hello world')
@@ -623,13 +635,13 @@ class KeepAliveTest(AsyncHTTPTestCase):
# The next few methods are a crude manual http client
def connect(self):
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
- self.stream.connect(('localhost', self.get_http_port()), self.stop)
+ self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
def read_headers(self):
self.stream.read_until(b'\r\n', self.stop)
first_line = self.wait()
- self.assertTrue(first_line.startswith(self.http_version + b' 200'), first_line)
+ self.assertTrue(first_line.startswith(b'HTTP/1.1 200'), first_line)
self.stream.read_until(b'\r\n\r\n', self.stop)
header_bytes = self.wait()
headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
@@ -808,8 +820,8 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase):
def get_app(self):
class App(HTTPServerConnectionDelegate):
- def start_request(self, connection):
- return StreamingChunkSizeTest.MessageDelegate(connection)
+ def start_request(self, server_conn, request_conn):
+ return StreamingChunkSizeTest.MessageDelegate(request_conn)
return App()
def fetch_chunk_sizes(self, **kwargs):
@@ -856,6 +868,7 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase):
def test_chunked_compressed(self):
compressed = self.compress(self.BODY)
self.assertGreater(len(compressed), 20)
+
def body_producer(write):
write(compressed[:20])
write(compressed[20:])
@@ -900,7 +913,7 @@ class IdleTimeoutTest(AsyncHTTPTestCase):
def connect(self):
stream = IOStream(socket.socket())
- stream.connect(('localhost', self.get_http_port()), self.stop)
+ stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
self.streams.append(stream)
return stream
@@ -1045,6 +1058,15 @@ class LegacyInterfaceTest(AsyncHTTPTestCase):
# delegate interface, and writes its response via request.write
# instead of request.connection.write_headers.
def handle_request(request):
+ self.http1 = request.version.startswith("HTTP/1.")
+ if not self.http1:
+ # This test will be skipped if we're using HTTP/2,
+ # so just close it out cleanly using the modern interface.
+ request.connection.write_headers(
+ ResponseStartLine('', 200, 'OK'),
+ HTTPHeaders())
+ request.connection.finish()
+ return
message = b"Hello world"
request.write(utf8("HTTP/1.1 200 OK\r\n"
"Content-Length: %d\r\n\r\n" % len(message)))
@@ -1054,4 +1076,6 @@ class LegacyInterfaceTest(AsyncHTTPTestCase):
def test_legacy_interface(self):
response = self.fetch('/')
+ if not self.http1:
+ self.skipTest("requires HTTP/1.x")
self.assertEqual(response.body, b"Hello world")
diff --git a/tornado/test/httputil_test.py b/tornado/test/httputil_test.py
index 5ca5cf9f..6e953601 100644
--- a/tornado/test/httputil_test.py
+++ b/tornado/test/httputil_test.py
@@ -3,11 +3,13 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado.httputil import url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp, HTTPServerRequest, parse_request_start_line
-from tornado.escape import utf8
+from tornado.escape import utf8, native_str
from tornado.log import gen_log
from tornado.testing import ExpectLog
from tornado.test.util import unittest
+from tornado.util import u
+import copy
import datetime
import logging
import time
@@ -228,6 +230,75 @@ Foo: even
("Foo", "bar baz"),
("Foo", "even more lines")])
+ def test_unicode_newlines(self):
+ # Ensure that only \r\n is recognized as a header separator, and not
+ # the other newline-like unicode characters.
+ # Characters that are likely to be problematic can be found in
+ # http://unicode.org/standard/reports/tr13/tr13-5.html
+ # and cpython's unicodeobject.c (which defines the implementation
+ # of unicode_type.splitlines(), and uses a different list than TR13).
+ newlines = [
+ u('\u001b'), # VERTICAL TAB
+ u('\u001c'), # FILE SEPARATOR
+ u('\u001d'), # GROUP SEPARATOR
+ u('\u001e'), # RECORD SEPARATOR
+ u('\u0085'), # NEXT LINE
+ u('\u2028'), # LINE SEPARATOR
+ u('\u2029'), # PARAGRAPH SEPARATOR
+ ]
+ for newline in newlines:
+ # Try the utf8 and latin1 representations of each newline
+ for encoding in ['utf8', 'latin1']:
+ try:
+ try:
+ encoded = newline.encode(encoding)
+ except UnicodeEncodeError:
+ # Some chars cannot be represented in latin1
+ continue
+ data = b'Cookie: foo=' + encoded + b'bar'
+ # parse() wants a native_str, so decode through latin1
+ # in the same way the real parser does.
+ headers = HTTPHeaders.parse(
+ native_str(data.decode('latin1')))
+ expected = [('Cookie', 'foo=' +
+ native_str(encoded.decode('latin1')) + 'bar')]
+ self.assertEqual(
+ expected, list(headers.get_all()))
+ except Exception:
+ gen_log.warning("failed while trying %r in %s",
+ newline, encoding)
+ raise
+
+ def test_optional_cr(self):
+ # Both CRLF and LF should be accepted as separators. CR should not be
+ # part of the data when followed by LF, but it is a normal char
+ # otherwise (or should bare CR be an error?)
+ headers = HTTPHeaders.parse(
+ 'CRLF: crlf\r\nLF: lf\nCR: cr\rMore: more\r\n')
+ self.assertEqual(sorted(headers.get_all()),
+ [('Cr', 'cr\rMore: more'),
+ ('Crlf', 'crlf'),
+ ('Lf', 'lf'),
+ ])
+
+ def test_copy(self):
+ all_pairs = [('A', '1'), ('A', '2'), ('B', 'c')]
+ h1 = HTTPHeaders()
+ for k, v in all_pairs:
+ h1.add(k, v)
+ h2 = h1.copy()
+ h3 = copy.copy(h1)
+ h4 = copy.deepcopy(h1)
+ for headers in [h1, h2, h3, h4]:
+ # All the copies are identical, no matter how they were
+ # constructed.
+ self.assertEqual(list(sorted(headers.get_all())), all_pairs)
+ for headers in [h2, h3, h4]:
+ # Neither the dict or its member lists are reused.
+ self.assertIsNot(headers, h1)
+ self.assertIsNot(headers.get_list('A'), h1.get_list('A'))
+
+
class FormatTimestampTest(unittest.TestCase):
# Make sure that all the input types are supported.
@@ -264,6 +335,10 @@ class HTTPServerRequestTest(unittest.TestCase):
# more required parameters slip in.
HTTPServerRequest(uri='/')
+ def test_body_is_a_byte_string(self):
+ requets = HTTPServerRequest(uri='/')
+ self.assertIsInstance(requets.body, bytes)
+
class ParseRequestStartLineTest(unittest.TestCase):
METHOD = "GET"
diff --git a/tornado/test/import_test.py b/tornado/test/import_test.py
index de7cc0b9..1be6427f 100644
--- a/tornado/test/import_test.py
+++ b/tornado/test/import_test.py
@@ -1,3 +1,4 @@
+# flake8: noqa
from __future__ import absolute_import, division, print_function, with_statement
from tornado.test.util import unittest
diff --git a/tornado/test/ioloop_test.py b/tornado/test/ioloop_test.py
index 7eb7594f..f3a0cbdc 100644
--- a/tornado/test/ioloop_test.py
+++ b/tornado/test/ioloop_test.py
@@ -11,8 +11,9 @@ import threading
import time
from tornado import gen
-from tornado.ioloop import IOLoop, TimeoutError
+from tornado.ioloop import IOLoop, TimeoutError, PollIOLoop, PeriodicCallback
from tornado.log import app_log
+from tornado.platform.select import _Select
from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext
from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog
from tornado.test.util import unittest, skipIfNonUnix, skipOnTravis
@@ -23,6 +24,42 @@ except ImportError:
futures = None
+class FakeTimeSelect(_Select):
+ def __init__(self):
+ self._time = 1000
+ super(FakeTimeSelect, self).__init__()
+
+ def time(self):
+ return self._time
+
+ def sleep(self, t):
+ self._time += t
+
+ def poll(self, timeout):
+ events = super(FakeTimeSelect, self).poll(0)
+ if events:
+ return events
+ self._time += timeout
+ return []
+
+
+class FakeTimeIOLoop(PollIOLoop):
+ """IOLoop implementation with a fake and deterministic clock.
+
+ The clock advances as needed to trigger timeouts immediately.
+ For use when testing code that involves the passage of time
+ and no external dependencies.
+ """
+ def initialize(self):
+ self.fts = FakeTimeSelect()
+ super(FakeTimeIOLoop, self).initialize(impl=self.fts,
+ time_func=self.fts.time)
+
+ def sleep(self, t):
+ """Simulate a blocking sleep by advancing the clock."""
+ self.fts.sleep(t)
+
+
class TestIOLoop(AsyncTestCase):
@skipOnTravis
def test_add_callback_wakeup(self):
@@ -180,10 +217,12 @@ class TestIOLoop(AsyncTestCase):
# t2 should be cancelled by t1, even though it is already scheduled to
# be run before the ioloop even looks at it.
now = self.io_loop.time()
+
def t1():
calls[0] = True
self.io_loop.remove_timeout(t2_handle)
self.io_loop.add_timeout(now + 0.01, t1)
+
def t2():
calls[1] = True
t2_handle = self.io_loop.add_timeout(now + 0.02, t2)
@@ -252,6 +291,7 @@ class TestIOLoop(AsyncTestCase):
"""The handler callback receives the same fd object it passed in."""
server_sock, port = bind_unused_port()
fds = []
+
def handle_connection(fd, events):
fds.append(fd)
conn, addr = server_sock.accept()
@@ -274,6 +314,7 @@ class TestIOLoop(AsyncTestCase):
def test_mixed_fd_fileobj(self):
server_sock, port = bind_unused_port()
+
def f(fd, events):
pass
self.io_loop.add_handler(server_sock, f, IOLoop.READ)
@@ -288,6 +329,7 @@ class TestIOLoop(AsyncTestCase):
"""Calling start() twice should raise an error, not deadlock."""
returned_from_start = [False]
got_exception = [False]
+
def callback():
try:
self.io_loop.start()
@@ -305,7 +347,7 @@ class TestIOLoop(AsyncTestCase):
# Use a NullContext to keep the exception from being caught by
# AsyncTestCase.
with NullContext():
- self.io_loop.add_callback(lambda: 1/0)
+ self.io_loop.add_callback(lambda: 1 / 0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
@@ -316,7 +358,7 @@ class TestIOLoop(AsyncTestCase):
@gen.coroutine
def callback():
self.io_loop.add_callback(self.stop)
- 1/0
+ 1 / 0
self.io_loop.add_callback(callback)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
@@ -324,12 +366,12 @@ class TestIOLoop(AsyncTestCase):
def test_spawn_callback(self):
# An added callback runs in the test's stack_context, so will be
# re-arised in wait().
- self.io_loop.add_callback(lambda: 1/0)
+ self.io_loop.add_callback(lambda: 1 / 0)
with self.assertRaises(ZeroDivisionError):
self.wait()
# A spawned callback is run directly on the IOLoop, so it will be
# logged without stopping the test.
- self.io_loop.spawn_callback(lambda: 1/0)
+ self.io_loop.spawn_callback(lambda: 1 / 0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
@@ -344,6 +386,7 @@ class TestIOLoop(AsyncTestCase):
# After reading from one fd, remove the other from the IOLoop.
chunks = []
+
def handle_read(fd, events):
chunks.append(fd.recv(1024))
if fd is client:
@@ -352,7 +395,7 @@ class TestIOLoop(AsyncTestCase):
self.io_loop.remove_handler(client)
self.io_loop.add_handler(client, handle_read, self.io_loop.READ)
self.io_loop.add_handler(server, handle_read, self.io_loop.READ)
- self.io_loop.call_later(0.01, self.stop)
+ self.io_loop.call_later(0.03, self.stop)
self.wait()
# Only one fd was read; the other was cleanly removed.
@@ -520,5 +563,47 @@ class TestIOLoopRunSync(unittest.TestCase):
self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01)
+class TestPeriodicCallback(unittest.TestCase):
+ def setUp(self):
+ self.io_loop = FakeTimeIOLoop()
+ self.io_loop.make_current()
+
+ def tearDown(self):
+ self.io_loop.close()
+
+ def test_basic(self):
+ calls = []
+
+ def cb():
+ calls.append(self.io_loop.time())
+ pc = PeriodicCallback(cb, 10000)
+ pc.start()
+ self.io_loop.call_later(50, self.io_loop.stop)
+ self.io_loop.start()
+ self.assertEqual(calls, [1010, 1020, 1030, 1040, 1050])
+
+ def test_overrun(self):
+ sleep_durations = [9, 9, 10, 11, 20, 20, 35, 35, 0, 0]
+ expected = [
+ 1010, 1020, 1030, # first 3 calls on schedule
+ 1050, 1070, # next 2 delayed one cycle
+ 1100, 1130, # next 2 delayed 2 cycles
+ 1170, 1210, # next 2 delayed 3 cycles
+ 1220, 1230, # then back on schedule.
+ ]
+ calls = []
+
+ def cb():
+ calls.append(self.io_loop.time())
+ if not sleep_durations:
+ self.io_loop.stop()
+ return
+ self.io_loop.sleep(sleep_durations.pop(0))
+ pc = PeriodicCallback(cb, 10000)
+ pc.start()
+ self.io_loop.start()
+ self.assertEqual(calls, expected)
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py
index f51caeaf..45df6b50 100644
--- a/tornado/test/iostream_test.py
+++ b/tornado/test/iostream_test.py
@@ -7,10 +7,10 @@ from tornado.httputil import HTTPHeaders
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket
from tornado.stack_context import NullContext
+from tornado.tcpserver import TCPServer
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
-from tornado.test.util import unittest, skipIfNonUnix
+from tornado.test.util import unittest, skipIfNonUnix, refusing_port
from tornado.web import RequestHandler, Application
-import lib.certifi
import errno
import logging
import os
@@ -51,18 +51,18 @@ class TestIOStreamWebMixin(object):
def test_read_until_close(self):
stream = self._make_client_iostream()
- stream.connect(('localhost', self.get_http_port()), callback=self.stop)
+ stream.connect(('127.0.0.1', self.get_http_port()), callback=self.stop)
self.wait()
stream.write(b"GET / HTTP/1.0\r\n\r\n")
stream.read_until_close(self.stop)
data = self.wait()
- self.assertTrue(data.startswith(b"HTTP/1.0 200"))
+ self.assertTrue(data.startswith(b"HTTP/1.1 200"))
self.assertTrue(data.endswith(b"Hello"))
def test_read_zero_bytes(self):
self.stream = self._make_client_iostream()
- self.stream.connect(("localhost", self.get_http_port()),
+ self.stream.connect(("127.0.0.1", self.get_http_port()),
callback=self.stop)
self.wait()
self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
@@ -70,7 +70,7 @@ class TestIOStreamWebMixin(object):
# normal read
self.stream.read_bytes(9, self.stop)
data = self.wait()
- self.assertEqual(data, b"HTTP/1.0 ")
+ self.assertEqual(data, b"HTTP/1.1 ")
# zero bytes
self.stream.read_bytes(0, self.stop)
@@ -91,7 +91,7 @@ class TestIOStreamWebMixin(object):
def connected_callback():
connected[0] = True
self.stop()
- stream.connect(("localhost", self.get_http_port()),
+ stream.connect(("127.0.0.1", self.get_http_port()),
callback=connected_callback)
# unlike the previous tests, try to write before the connection
# is complete.
@@ -121,11 +121,11 @@ class TestIOStreamWebMixin(object):
"""Basic test of IOStream's ability to return Futures."""
stream = self._make_client_iostream()
connect_result = yield stream.connect(
- ("localhost", self.get_http_port()))
+ ("127.0.0.1", self.get_http_port()))
self.assertIs(connect_result, stream)
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
first_line = yield stream.read_until(b"\r\n")
- self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n")
+ self.assertEqual(first_line, b"HTTP/1.1 200 OK\r\n")
# callback=None is equivalent to no callback.
header_data = yield stream.read_until(b"\r\n\r\n", callback=None)
headers = HTTPHeaders.parse(header_data.decode('latin1'))
@@ -137,7 +137,7 @@ class TestIOStreamWebMixin(object):
@gen_test
def test_future_close_while_reading(self):
stream = self._make_client_iostream()
- yield stream.connect(("localhost", self.get_http_port()))
+ yield stream.connect(("127.0.0.1", self.get_http_port()))
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
with self.assertRaises(StreamClosedError):
yield stream.read_bytes(1024 * 1024)
@@ -147,7 +147,7 @@ class TestIOStreamWebMixin(object):
def test_future_read_until_close(self):
# Ensure that the data comes through before the StreamClosedError.
stream = self._make_client_iostream()
- yield stream.connect(("localhost", self.get_http_port()))
+ yield stream.connect(("127.0.0.1", self.get_http_port()))
yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n")
yield stream.read_until(b"\r\n\r\n")
body = yield stream.read_until_close()
@@ -217,17 +217,18 @@ class TestIOStreamMixin(object):
# When a connection is refused, the connect callback should not
# be run. (The kqueue IOLoop used to behave differently from the
# epoll IOLoop in this respect)
- server_socket, port = bind_unused_port()
- server_socket.close()
+ cleanup_func, port = refusing_port()
+ self.addCleanup(cleanup_func)
stream = IOStream(socket.socket(), self.io_loop)
self.connect_called = False
def connect_callback():
self.connect_called = True
+ self.stop()
stream.set_close_callback(self.stop)
# log messages vary by platform and ioloop implementation
with ExpectLog(gen_log, ".*", required=False):
- stream.connect(("localhost", port), connect_callback)
+ stream.connect(("127.0.0.1", port), connect_callback)
self.wait()
self.assertFalse(self.connect_called)
self.assertTrue(isinstance(stream.error, socket.error), stream.error)
@@ -248,7 +249,8 @@ class TestIOStreamMixin(object):
# opendns and some ISPs return bogus addresses for nonexistent
# domains instead of the proper error codes).
with ExpectLog(gen_log, "Connect error"):
- stream.connect(('an invalid domain', 54321))
+ stream.connect(('an invalid domain', 54321), callback=self.stop)
+ self.wait()
self.assertTrue(isinstance(stream.error, socket.gaierror), stream.error)
def test_read_callback_error(self):
@@ -308,6 +310,7 @@ class TestIOStreamMixin(object):
def streaming_callback(data):
chunks.append(data)
self.stop()
+
def close_callback(data):
assert not data, data
closed[0] = True
@@ -325,6 +328,31 @@ class TestIOStreamMixin(object):
server.close()
client.close()
+ def test_streaming_until_close_future(self):
+ server, client = self.make_iostream_pair()
+ try:
+ chunks = []
+
+ @gen.coroutine
+ def client_task():
+ yield client.read_until_close(streaming_callback=chunks.append)
+
+ @gen.coroutine
+ def server_task():
+ yield server.write(b"1234")
+ yield gen.sleep(0.01)
+ yield server.write(b"5678")
+ server.close()
+
+ @gen.coroutine
+ def f():
+ yield [client_task(), server_task()]
+ self.io_loop.run_sync(f)
+ self.assertEqual(chunks, [b"1234", b"5678"])
+ finally:
+ server.close()
+ client.close()
+
def test_delayed_close_callback(self):
# The scenario: Server closes the connection while there is a pending
# read that can be served out of buffered data. The client does not
@@ -353,6 +381,7 @@ class TestIOStreamMixin(object):
def test_future_delayed_close_callback(self):
# Same as test_delayed_close_callback, but with the future interface.
server, client = self.make_iostream_pair()
+
# We can't call make_iostream_pair inside a gen_test function
# because the ioloop is not reentrant.
@gen_test
@@ -532,6 +561,7 @@ class TestIOStreamMixin(object):
# and IOStream._maybe_add_error_listener.
server, client = self.make_iostream_pair()
closed = [False]
+
def close_callback():
closed[0] = True
self.stop()
@@ -724,6 +754,26 @@ class TestIOStreamMixin(object):
server.close()
client.close()
+ def test_flow_control(self):
+ MB = 1024 * 1024
+ server, client = self.make_iostream_pair(max_buffer_size=5 * MB)
+ try:
+ # Client writes more than the server will accept.
+ client.write(b"a" * 10 * MB)
+ # The server pauses while reading.
+ server.read_bytes(MB, self.stop)
+ self.wait()
+ self.io_loop.call_later(0.1, self.stop)
+ self.wait()
+ # The client's writes have been blocked; the server can
+ # continue to read gradually.
+ for i in range(9):
+ server.read_bytes(MB, self.stop)
+ self.wait()
+ finally:
+ server.close()
+ client.close()
+
class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
def _make_client_iostream(self):
@@ -732,7 +782,8 @@ class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
class TestIOStreamWebHTTPS(TestIOStreamWebMixin, AsyncHTTPSTestCase):
def _make_client_iostream(self):
- return SSLIOStream(socket.socket(), io_loop=self.io_loop)
+ return SSLIOStream(socket.socket(), io_loop=self.io_loop,
+ ssl_options=dict(cert_reqs=ssl.CERT_NONE))
class TestIOStream(TestIOStreamMixin, AsyncTestCase):
@@ -752,7 +803,9 @@ class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase):
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
def _make_client_iostream(self, connection, **kwargs):
- return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
+ return SSLIOStream(connection, io_loop=self.io_loop,
+ ssl_options=dict(cert_reqs=ssl.CERT_NONE),
+ **kwargs)
# This will run some tests that are basically redundant but it's the
@@ -820,10 +873,10 @@ class TestIOStreamStartTLS(AsyncTestCase):
recv_line = yield self.client_stream.read_until(b"\r\n")
self.assertEqual(line, recv_line)
- def client_start_tls(self, ssl_options=None):
+ def client_start_tls(self, ssl_options=None, server_hostname=None):
client_stream = self.client_stream
self.client_stream = None
- return client_stream.start_tls(False, ssl_options)
+ return client_stream.start_tls(False, ssl_options, server_hostname)
def server_start_tls(self, ssl_options=None):
server_stream = self.server_stream
@@ -842,7 +895,7 @@ class TestIOStreamStartTLS(AsyncTestCase):
yield self.server_send_line(b"250 STARTTLS\r\n")
yield self.client_send_line(b"STARTTLS\r\n")
yield self.server_send_line(b"220 Go ahead\r\n")
- client_future = self.client_start_tls()
+ client_future = self.client_start_tls(dict(cert_reqs=ssl.CERT_NONE))
server_future = self.server_start_tls(_server_ssl_options())
self.client_stream = yield client_future
self.server_stream = yield server_future
@@ -853,12 +906,123 @@ class TestIOStreamStartTLS(AsyncTestCase):
@gen_test
def test_handshake_fail(self):
- self.server_start_tls(_server_ssl_options())
- client_future = self.client_start_tls(
- dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where()))
+ server_future = self.server_start_tls(_server_ssl_options())
+ # Certificates are verified with the default configuration.
+ client_future = self.client_start_tls(server_hostname="localhost")
with ExpectLog(gen_log, "SSL Error"):
with self.assertRaises(ssl.SSLError):
yield client_future
+ with self.assertRaises((ssl.SSLError, socket.error)):
+ yield server_future
+
+ @unittest.skipIf(not hasattr(ssl, 'create_default_context'),
+ 'ssl.create_default_context not present')
+ @gen_test
+ def test_check_hostname(self):
+ # Test that server_hostname parameter to start_tls is being used.
+ # The check_hostname functionality is only available in python 2.7 and
+ # up and in python 3.4 and up.
+ server_future = self.server_start_tls(_server_ssl_options())
+ client_future = self.client_start_tls(
+ ssl.create_default_context(),
+ server_hostname=b'127.0.0.1')
+ with ExpectLog(gen_log, "SSL Error"):
+ with self.assertRaises(ssl.SSLError):
+ yield client_future
+ with self.assertRaises((ssl.SSLError, socket.error)):
+ yield server_future
+
+
+class WaitForHandshakeTest(AsyncTestCase):
+ @gen.coroutine
+ def connect_to_server(self, server_cls):
+ server = client = None
+ try:
+ sock, port = bind_unused_port()
+ server = server_cls(ssl_options=_server_ssl_options())
+ server.add_socket(sock)
+
+ client = SSLIOStream(socket.socket(),
+ ssl_options=dict(cert_reqs=ssl.CERT_NONE))
+ yield client.connect(('127.0.0.1', port))
+ self.assertIsNotNone(client.socket.cipher())
+ finally:
+ if server is not None:
+ server.stop()
+ if client is not None:
+ client.close()
+
+ @gen_test
+ def test_wait_for_handshake_callback(self):
+ test = self
+ handshake_future = Future()
+
+ class TestServer(TCPServer):
+ def handle_stream(self, stream, address):
+ # The handshake has not yet completed.
+ test.assertIsNone(stream.socket.cipher())
+ self.stream = stream
+ stream.wait_for_handshake(self.handshake_done)
+
+ def handshake_done(self):
+ # Now the handshake is done and ssl information is available.
+ test.assertIsNotNone(self.stream.socket.cipher())
+ handshake_future.set_result(None)
+
+ yield self.connect_to_server(TestServer)
+ yield handshake_future
+
+ @gen_test
+ def test_wait_for_handshake_future(self):
+ test = self
+ handshake_future = Future()
+
+ class TestServer(TCPServer):
+ def handle_stream(self, stream, address):
+ test.assertIsNone(stream.socket.cipher())
+ test.io_loop.spawn_callback(self.handle_connection, stream)
+
+ @gen.coroutine
+ def handle_connection(self, stream):
+ yield stream.wait_for_handshake()
+ handshake_future.set_result(None)
+
+ yield self.connect_to_server(TestServer)
+ yield handshake_future
+
+ @gen_test
+ def test_wait_for_handshake_already_waiting_error(self):
+ test = self
+ handshake_future = Future()
+
+ class TestServer(TCPServer):
+ def handle_stream(self, stream, address):
+ stream.wait_for_handshake(self.handshake_done)
+ test.assertRaises(RuntimeError, stream.wait_for_handshake)
+
+ def handshake_done(self):
+ handshake_future.set_result(None)
+
+ yield self.connect_to_server(TestServer)
+ yield handshake_future
+
+ @gen_test
+ def test_wait_for_handshake_already_connected(self):
+ handshake_future = Future()
+
+ class TestServer(TCPServer):
+ def handle_stream(self, stream, address):
+ self.stream = stream
+ stream.wait_for_handshake(self.handshake_done)
+
+ def handshake_done(self):
+ self.stream.wait_for_handshake(self.handshake2_done)
+
+ def handshake2_done(self):
+ handshake_future.set_result(None)
+
+ yield self.connect_to_server(TestServer)
+ yield handshake_future
@skipIfNonUnix
diff --git a/tornado/test/locale_test.py b/tornado/test/locale_test.py
index d12ad52f..31c57a61 100644
--- a/tornado/test/locale_test.py
+++ b/tornado/test/locale_test.py
@@ -41,6 +41,12 @@ class TranslationLoaderTest(unittest.TestCase):
locale = tornado.locale.get("fr_FR")
self.assertTrue(isinstance(locale, tornado.locale.GettextLocale))
self.assertEqual(locale.translate("school"), u("\u00e9cole"))
+ self.assertEqual(locale.pgettext("law", "right"), u("le droit"))
+ self.assertEqual(locale.pgettext("good", "right"), u("le bien"))
+ self.assertEqual(locale.pgettext("organization", "club", "clubs", 1), u("le club"))
+ self.assertEqual(locale.pgettext("organization", "club", "clubs", 2), u("les clubs"))
+ self.assertEqual(locale.pgettext("stick", "club", "clubs", 1), u("le b\xe2ton"))
+ self.assertEqual(locale.pgettext("stick", "club", "clubs", 2), u("les b\xe2tons"))
class LocaleDataTest(unittest.TestCase):
@@ -58,7 +64,7 @@ class EnglishTest(unittest.TestCase):
self.assertEqual(locale.format_date(date, full_format=True),
'April 28, 2013 at 6:35 pm')
- self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(seconds=2), full_format=False),
+ self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(seconds=2), full_format=False),
'2 seconds ago')
self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(minutes=2), full_format=False),
'2 minutes ago')
diff --git a/tornado/test/locks_test.py b/tornado/test/locks_test.py
new file mode 100644
index 00000000..90bdafaa
--- /dev/null
+++ b/tornado/test/locks_test.py
@@ -0,0 +1,480 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from datetime import timedelta
+
+from tornado import gen, locks
+from tornado.gen import TimeoutError
+from tornado.testing import gen_test, AsyncTestCase
+from tornado.test.util import unittest
+
+
+class ConditionTest(AsyncTestCase):
+ def setUp(self):
+ super(ConditionTest, self).setUp()
+ self.history = []
+
+ def record_done(self, future, key):
+ """Record the resolution of a Future returned by Condition.wait."""
+ def callback(_):
+ if not future.result():
+ # wait() resolved to False, meaning it timed out.
+ self.history.append('timeout')
+ else:
+ self.history.append(key)
+ future.add_done_callback(callback)
+
+ def test_repr(self):
+ c = locks.Condition()
+ self.assertIn('Condition', repr(c))
+ self.assertNotIn('waiters', repr(c))
+ c.wait()
+ self.assertIn('waiters', repr(c))
+
+ @gen_test
+ def test_notify(self):
+ c = locks.Condition()
+ self.io_loop.call_later(0.01, c.notify)
+ yield c.wait()
+
+ def test_notify_1(self):
+ c = locks.Condition()
+ self.record_done(c.wait(), 'wait1')
+ self.record_done(c.wait(), 'wait2')
+ c.notify(1)
+ self.history.append('notify1')
+ c.notify(1)
+ self.history.append('notify2')
+ self.assertEqual(['wait1', 'notify1', 'wait2', 'notify2'],
+ self.history)
+
+ def test_notify_n(self):
+ c = locks.Condition()
+ for i in range(6):
+ self.record_done(c.wait(), i)
+
+ c.notify(3)
+
+ # Callbacks execute in the order they were registered.
+ self.assertEqual(list(range(3)), self.history)
+ c.notify(1)
+ self.assertEqual(list(range(4)), self.history)
+ c.notify(2)
+ self.assertEqual(list(range(6)), self.history)
+
+ def test_notify_all(self):
+ c = locks.Condition()
+ for i in range(4):
+ self.record_done(c.wait(), i)
+
+ c.notify_all()
+ self.history.append('notify_all')
+
+ # Callbacks execute in the order they were registered.
+ self.assertEqual(
+ list(range(4)) + ['notify_all'],
+ self.history)
+
+ @gen_test
+ def test_wait_timeout(self):
+ c = locks.Condition()
+ wait = c.wait(timedelta(seconds=0.01))
+ self.io_loop.call_later(0.02, c.notify) # Too late.
+ yield gen.sleep(0.03)
+ self.assertFalse((yield wait))
+
+ @gen_test
+ def test_wait_timeout_preempted(self):
+ c = locks.Condition()
+
+ # This fires before the wait times out.
+ self.io_loop.call_later(0.01, c.notify)
+ wait = c.wait(timedelta(seconds=0.02))
+ yield gen.sleep(0.03)
+ yield wait # No TimeoutError.
+
+ @gen_test
+ def test_notify_n_with_timeout(self):
+ # Register callbacks 0, 1, 2, and 3. Callback 1 has a timeout.
+ # Wait for that timeout to expire, then do notify(2) and make
+ # sure everyone runs. Verifies that a timed-out callback does
+ # not count against the 'n' argument to notify().
+ c = locks.Condition()
+ self.record_done(c.wait(), 0)
+ self.record_done(c.wait(timedelta(seconds=0.01)), 1)
+ self.record_done(c.wait(), 2)
+ self.record_done(c.wait(), 3)
+
+ # Wait for callback 1 to time out.
+ yield gen.sleep(0.02)
+ self.assertEqual(['timeout'], self.history)
+
+ c.notify(2)
+ yield gen.sleep(0.01)
+ self.assertEqual(['timeout', 0, 2], self.history)
+ self.assertEqual(['timeout', 0, 2], self.history)
+ c.notify()
+ self.assertEqual(['timeout', 0, 2, 3], self.history)
+
+ @gen_test
+ def test_notify_all_with_timeout(self):
+ c = locks.Condition()
+ self.record_done(c.wait(), 0)
+ self.record_done(c.wait(timedelta(seconds=0.01)), 1)
+ self.record_done(c.wait(), 2)
+
+ # Wait for callback 1 to time out.
+ yield gen.sleep(0.02)
+ self.assertEqual(['timeout'], self.history)
+
+ c.notify_all()
+ self.assertEqual(['timeout', 0, 2], self.history)
+
+ @gen_test
+ def test_nested_notify(self):
+ # Ensure no notifications lost, even if notify() is reentered by a
+ # waiter calling notify().
+ c = locks.Condition()
+
+ # Three waiters.
+ futures = [c.wait() for _ in range(3)]
+
+ # First and second futures resolved. Second future reenters notify(),
+ # resolving third future.
+ futures[1].add_done_callback(lambda _: c.notify())
+ c.notify(2)
+ self.assertTrue(all(f.done() for f in futures))
+
+ @gen_test
+ def test_garbage_collection(self):
+ # Test that timed-out waiters are occasionally cleaned from the queue.
+ c = locks.Condition()
+ for _ in range(101):
+ c.wait(timedelta(seconds=0.01))
+
+ future = c.wait()
+ self.assertEqual(102, len(c._waiters))
+
+ # Let first 101 waiters time out, triggering a collection.
+ yield gen.sleep(0.02)
+ self.assertEqual(1, len(c._waiters))
+
+ # Final waiter is still active.
+ self.assertFalse(future.done())
+ c.notify()
+ self.assertTrue(future.done())
+
+
+class EventTest(AsyncTestCase):
+ def test_repr(self):
+ event = locks.Event()
+ self.assertTrue('clear' in str(event))
+ self.assertFalse('set' in str(event))
+ event.set()
+ self.assertFalse('clear' in str(event))
+ self.assertTrue('set' in str(event))
+
+ def test_event(self):
+ e = locks.Event()
+ future_0 = e.wait()
+ e.set()
+ future_1 = e.wait()
+ e.clear()
+ future_2 = e.wait()
+
+ self.assertTrue(future_0.done())
+ self.assertTrue(future_1.done())
+ self.assertFalse(future_2.done())
+
+ @gen_test
+ def test_event_timeout(self):
+ e = locks.Event()
+ with self.assertRaises(TimeoutError):
+ yield e.wait(timedelta(seconds=0.01))
+
+ # After a timed-out waiter, normal operation works.
+ self.io_loop.add_timeout(timedelta(seconds=0.01), e.set)
+ yield e.wait(timedelta(seconds=1))
+
+ def test_event_set_multiple(self):
+ e = locks.Event()
+ e.set()
+ e.set()
+ self.assertTrue(e.is_set())
+
+ def test_event_wait_clear(self):
+ e = locks.Event()
+ f0 = e.wait()
+ e.clear()
+ f1 = e.wait()
+ e.set()
+ self.assertTrue(f0.done())
+ self.assertTrue(f1.done())
+
+
+class SemaphoreTest(AsyncTestCase):
+ def test_negative_value(self):
+ self.assertRaises(ValueError, locks.Semaphore, value=-1)
+
+ def test_repr(self):
+ sem = locks.Semaphore()
+ self.assertIn('Semaphore', repr(sem))
+ self.assertIn('unlocked,value:1', repr(sem))
+ sem.acquire()
+ self.assertIn('locked', repr(sem))
+ self.assertNotIn('waiters', repr(sem))
+ sem.acquire()
+ self.assertIn('waiters', repr(sem))
+
+ def test_acquire(self):
+ sem = locks.Semaphore()
+ f0 = sem.acquire()
+ self.assertTrue(f0.done())
+
+ # Wait for release().
+ f1 = sem.acquire()
+ self.assertFalse(f1.done())
+ f2 = sem.acquire()
+ sem.release()
+ self.assertTrue(f1.done())
+ self.assertFalse(f2.done())
+ sem.release()
+ self.assertTrue(f2.done())
+
+ sem.release()
+ # Now acquire() is instant.
+ self.assertTrue(sem.acquire().done())
+ self.assertEqual(0, len(sem._waiters))
+
+ @gen_test
+ def test_acquire_timeout(self):
+ sem = locks.Semaphore(2)
+ yield sem.acquire()
+ yield sem.acquire()
+ acquire = sem.acquire(timedelta(seconds=0.01))
+ self.io_loop.call_later(0.02, sem.release) # Too late.
+ yield gen.sleep(0.3)
+ with self.assertRaises(gen.TimeoutError):
+ yield acquire
+
+ sem.acquire()
+ f = sem.acquire()
+ self.assertFalse(f.done())
+ sem.release()
+ self.assertTrue(f.done())
+
+ @gen_test
+ def test_acquire_timeout_preempted(self):
+ sem = locks.Semaphore(1)
+ yield sem.acquire()
+
+ # This fires before the wait times out.
+ self.io_loop.call_later(0.01, sem.release)
+ acquire = sem.acquire(timedelta(seconds=0.02))
+ yield gen.sleep(0.03)
+ yield acquire # No TimeoutError.
+
+ def test_release_unacquired(self):
+ # Unbounded releases are allowed, and increment the semaphore's value.
+ sem = locks.Semaphore()
+ sem.release()
+ sem.release()
+
+ # Now the counter is 3. We can acquire three times before blocking.
+ self.assertTrue(sem.acquire().done())
+ self.assertTrue(sem.acquire().done())
+ self.assertTrue(sem.acquire().done())
+ self.assertFalse(sem.acquire().done())
+
+ @gen_test
+ def test_garbage_collection(self):
+ # Test that timed-out waiters are occasionally cleaned from the queue.
+ sem = locks.Semaphore(value=0)
+ futures = [sem.acquire(timedelta(seconds=0.01)) for _ in range(101)]
+
+ future = sem.acquire()
+ self.assertEqual(102, len(sem._waiters))
+
+ # Let first 101 waiters time out, triggering a collection.
+ yield gen.sleep(0.02)
+ self.assertEqual(1, len(sem._waiters))
+
+ # Final waiter is still active.
+ self.assertFalse(future.done())
+ sem.release()
+ self.assertTrue(future.done())
+
+ # Prevent "Future exception was never retrieved" messages.
+ for future in futures:
+ self.assertRaises(TimeoutError, future.result)
+
+
+class SemaphoreContextManagerTest(AsyncTestCase):
+ @gen_test
+ def test_context_manager(self):
+ sem = locks.Semaphore()
+ with (yield sem.acquire()) as yielded:
+ self.assertTrue(yielded is None)
+
+ # Semaphore was released and can be acquired again.
+ self.assertTrue(sem.acquire().done())
+
+ @gen_test
+ def test_context_manager_exception(self):
+ sem = locks.Semaphore()
+ with self.assertRaises(ZeroDivisionError):
+ with (yield sem.acquire()):
+ 1 / 0
+
+ # Semaphore was released and can be acquired again.
+ self.assertTrue(sem.acquire().done())
+
+ @gen_test
+ def test_context_manager_timeout(self):
+ sem = locks.Semaphore()
+ with (yield sem.acquire(timedelta(seconds=0.01))):
+ pass
+
+ # Semaphore was released and can be acquired again.
+ self.assertTrue(sem.acquire().done())
+
+ @gen_test
+ def test_context_manager_timeout_error(self):
+ sem = locks.Semaphore(value=0)
+ with self.assertRaises(gen.TimeoutError):
+ with (yield sem.acquire(timedelta(seconds=0.01))):
+ pass
+
+ # Counter is still 0.
+ self.assertFalse(sem.acquire().done())
+
+ @gen_test
+ def test_context_manager_contended(self):
+ sem = locks.Semaphore()
+ history = []
+
+ @gen.coroutine
+ def f(index):
+ with (yield sem.acquire()):
+ history.append('acquired %d' % index)
+ yield gen.sleep(0.01)
+ history.append('release %d' % index)
+
+ yield [f(i) for i in range(2)]
+
+ expected_history = []
+ for i in range(2):
+ expected_history.extend(['acquired %d' % i, 'release %d' % i])
+
+ self.assertEqual(expected_history, history)
+
+ @gen_test
+ def test_yield_sem(self):
+ # Ensure we catch a "with (yield sem)", which should be
+ # "with (yield sem.acquire())".
+ with self.assertRaises(gen.BadYieldError):
+ with (yield locks.Semaphore()):
+ pass
+
+ def test_context_manager_misuse(self):
+ # Ensure we catch a "with sem", which should be
+ # "with (yield sem.acquire())".
+ with self.assertRaises(RuntimeError):
+ with locks.Semaphore():
+ pass
+
+
+class BoundedSemaphoreTest(AsyncTestCase):
+ def test_release_unacquired(self):
+ sem = locks.BoundedSemaphore()
+ self.assertRaises(ValueError, sem.release)
+ # Value is 0.
+ sem.acquire()
+ # Block on acquire().
+ future = sem.acquire()
+ self.assertFalse(future.done())
+ sem.release()
+ self.assertTrue(future.done())
+ # Value is 1.
+ sem.release()
+ self.assertRaises(ValueError, sem.release)
+
+
+class LockTests(AsyncTestCase):
+ def test_repr(self):
+ lock = locks.Lock()
+ # No errors.
+ repr(lock)
+ lock.acquire()
+ repr(lock)
+
+ def test_acquire_release(self):
+ lock = locks.Lock()
+ self.assertTrue(lock.acquire().done())
+ future = lock.acquire()
+ self.assertFalse(future.done())
+ lock.release()
+ self.assertTrue(future.done())
+
+ @gen_test
+ def test_acquire_fifo(self):
+ lock = locks.Lock()
+ self.assertTrue(lock.acquire().done())
+ N = 5
+ history = []
+
+ @gen.coroutine
+ def f(idx):
+ with (yield lock.acquire()):
+ history.append(idx)
+
+ futures = [f(i) for i in range(N)]
+ self.assertFalse(any(future.done() for future in futures))
+ lock.release()
+ yield futures
+ self.assertEqual(list(range(N)), history)
+
+ @gen_test
+ def test_acquire_timeout(self):
+ lock = locks.Lock()
+ lock.acquire()
+ with self.assertRaises(gen.TimeoutError):
+ yield lock.acquire(timeout=timedelta(seconds=0.01))
+
+ # Still locked.
+ self.assertFalse(lock.acquire().done())
+
+ def test_multi_release(self):
+ lock = locks.Lock()
+ self.assertRaises(RuntimeError, lock.release)
+ lock.acquire()
+ lock.release()
+ self.assertRaises(RuntimeError, lock.release)
+
+ @gen_test
+ def test_yield_lock(self):
+ # Ensure we catch a "with (yield lock)", which should be
+ # "with (yield lock.acquire())".
+ with self.assertRaises(gen.BadYieldError):
+ with (yield locks.Lock()):
+ pass
+
+ def test_context_manager_misuse(self):
+ # Ensure we catch a "with lock", which should be
+ # "with (yield lock.acquire())".
+ with self.assertRaises(RuntimeError):
+ with locks.Lock():
+ pass
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tornado/test/netutil_test.py b/tornado/test/netutil_test.py
index 1df1e320..7d9cad34 100644
--- a/tornado/test/netutil_test.py
+++ b/tornado/test/netutil_test.py
@@ -67,10 +67,12 @@ class _ResolverErrorTestMixin(object):
yield self.resolver.resolve('an invalid domain', 80,
socket.AF_UNSPEC)
+
def _failing_getaddrinfo(*args):
"""Dummy implementation of getaddrinfo for use in mocks"""
raise socket.gaierror("mock: lookup failed")
+
@skipIfNoNetwork
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
diff --git a/tornado/test/process_test.py b/tornado/test/process_test.py
index de727607..58cc410b 100644
--- a/tornado/test/process_test.py
+++ b/tornado/test/process_test.py
@@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop
from tornado.log import gen_log
from tornado.process import fork_processes, task_id, Subprocess
from tornado.simple_httpclient import SimpleAsyncHTTPClient
-from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase
+from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase, gen_test
from tornado.test.util import unittest, skipIfNonUnix
from tornado.web import RequestHandler, Application
@@ -85,7 +85,7 @@ class ProcessTest(unittest.TestCase):
self.assertEqual(id, task_id())
server = HTTPServer(self.get_app())
server.add_sockets([sock])
- IOLoop.instance().start()
+ IOLoop.current().start()
elif id == 2:
self.assertEqual(id, task_id())
sock.close()
@@ -200,6 +200,16 @@ class SubprocessTest(AsyncTestCase):
self.assertEqual(ret, 0)
self.assertEqual(subproc.returncode, ret)
+ @gen_test
+ def test_sigchild_future(self):
+ skip_if_twisted()
+ Subprocess.initialize()
+ self.addCleanup(Subprocess.uninitialize)
+ subproc = Subprocess([sys.executable, '-c', 'pass'])
+ ret = yield subproc.wait_for_exit()
+ self.assertEqual(ret, 0)
+ self.assertEqual(subproc.returncode, ret)
+
def test_sigchild_signal(self):
skip_if_twisted()
Subprocess.initialize(io_loop=self.io_loop)
@@ -212,3 +222,22 @@ class SubprocessTest(AsyncTestCase):
ret = self.wait()
self.assertEqual(subproc.returncode, ret)
self.assertEqual(ret, -signal.SIGTERM)
+
+ @gen_test
+ def test_wait_for_exit_raise(self):
+ skip_if_twisted()
+ Subprocess.initialize()
+ self.addCleanup(Subprocess.uninitialize)
+ subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)'])
+ with self.assertRaises(subprocess.CalledProcessError) as cm:
+ yield subproc.wait_for_exit()
+ self.assertEqual(cm.exception.returncode, 1)
+
+ @gen_test
+ def test_wait_for_exit_raise_disabled(self):
+ skip_if_twisted()
+ Subprocess.initialize()
+ self.addCleanup(Subprocess.uninitialize)
+ subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)'])
+ ret = yield subproc.wait_for_exit(raise_error=False)
+ self.assertEqual(ret, 1)
diff --git a/tornado/test/queues_test.py b/tornado/test/queues_test.py
new file mode 100644
index 00000000..f2ffb646
--- /dev/null
+++ b/tornado/test/queues_test.py
@@ -0,0 +1,403 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from datetime import timedelta
+from random import random
+
+from tornado import gen, queues
+from tornado.gen import TimeoutError
+from tornado.testing import gen_test, AsyncTestCase
+from tornado.test.util import unittest
+
+
+class QueueBasicTest(AsyncTestCase):
+ def test_repr_and_str(self):
+ q = queues.Queue(maxsize=1)
+ self.assertIn(hex(id(q)), repr(q))
+ self.assertNotIn(hex(id(q)), str(q))
+ q.get()
+
+ for q_str in repr(q), str(q):
+ self.assertTrue(q_str.startswith('= logging.ERROR:
+ self.error_count += 1
+ elif record.levelno >= logging.WARNING:
+ self.warning_count += 1
+ return True
+
+
def main():
# The -W command-line option does not work in a virtualenv with
# python 3 (as of virtualenv 1.7), so configure warnings
@@ -92,12 +112,21 @@ def main():
# 2.7 and 3.2
warnings.filterwarnings("ignore", category=DeprecationWarning,
message="Please use assert.* instead")
+ # unittest2 0.6 on py26 reports these as PendingDeprecationWarnings
+ # instead of DeprecationWarnings.
+ warnings.filterwarnings("ignore", category=PendingDeprecationWarning,
+ message="Please use assert.* instead")
+ # Twisted 15.0.0 triggers some warnings on py3 with -bb.
+ warnings.filterwarnings("ignore", category=BytesWarning,
+ module=r"twisted\..*")
logging.getLogger("tornado.access").setLevel(logging.CRITICAL)
define('httpclient', type=str, default=None,
callback=lambda s: AsyncHTTPClient.configure(
s, defaults=dict(allow_ipv6=False)))
+ define('httpserver', type=str, default=None,
+ callback=HTTPServer.configure)
define('ioloop', type=str, default=None)
define('ioloop_time_monotonic', default=False)
define('resolver', type=str, default=None,
@@ -121,6 +150,10 @@ def main():
IOLoop.configure(options.ioloop, **kwargs)
add_parse_callback(configure_ioloop)
+ log_counter = LogCounter()
+ add_parse_callback(
+ lambda: logging.getLogger().handlers[0].addFilter(log_counter))
+
import tornado.testing
kwargs = {}
if sys.version_info >= (3, 2):
@@ -131,7 +164,16 @@ def main():
# detail. http://bugs.python.org/issue15626
kwargs['warnings'] = False
kwargs['testRunner'] = TornadoTextTestRunner
- tornado.testing.main(**kwargs)
+ try:
+ tornado.testing.main(**kwargs)
+ finally:
+ # The tests should run clean; consider it a failure if they logged
+ # any warnings or errors. We'd like to ban info logs too, but
+ # we can't count them cleanly due to interactions with LogTrapTestCase.
+ if log_counter.warning_count > 0 or log_counter.error_count > 0:
+ logging.error("logged %d warnings and %d errors",
+ log_counter.warning_count, log_counter.error_count)
+ sys.exit(1)
if __name__ == '__main__':
main()
diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py
index e3fab57a..c0de22b7 100644
--- a/tornado/test/simple_httpclient_test.py
+++ b/tornado/test/simple_httpclient_test.py
@@ -8,19 +8,20 @@ import logging
import os
import re
import socket
+import ssl
import sys
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
-from tornado.httputil import HTTPHeaders
+from tornado.httputil import HTTPHeaders, ResponseStartLine
from tornado.ioloop import IOLoop
from tornado.log import gen_log
from tornado.netutil import Resolver, bind_sockets
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler
from tornado.test import httpclient_test
-from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
-from tornado.test.util import skipOnTravis, skipIfNoIPv6
+from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
+from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port, unittest
from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body
@@ -97,15 +98,18 @@ class HostEchoHandler(RequestHandler):
class NoContentLengthHandler(RequestHandler):
- @gen.coroutine
+ @asynchronous
def get(self):
- # Emulate the old HTTP/1.0 behavior of returning a body with no
- # content-length. Tornado handles content-length at the framework
- # level so we have to go around it.
- stream = self.request.connection.stream
- yield stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
- b"hello")
- stream.close()
+ if self.request.version.startswith('HTTP/1'):
+ # Emulate the old HTTP/1.0 behavior of returning a body with no
+ # content-length. Tornado handles content-length at the framework
+ # level so we have to go around it.
+ stream = self.request.connection.detach()
+ stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
+ b"hello")
+ stream.close()
+ else:
+ self.finish('HTTP/1 required')
class EchoPostHandler(RequestHandler):
@@ -191,9 +195,6 @@ class SimpleHTTPClientTestMixin(object):
response = self.wait()
response.rethrow()
- def test_default_certificates_exist(self):
- open(_default_ca_certs()).close()
-
def test_gzip(self):
# All the tests in this file should be using gzip, but this test
# ensures that it is in fact getting compressed.
@@ -235,9 +236,16 @@ class SimpleHTTPClientTestMixin(object):
@skipOnTravis
def test_request_timeout(self):
- response = self.fetch('/trigger?wake=false', request_timeout=0.1)
+ timeout = 0.1
+ timeout_min, timeout_max = 0.099, 0.15
+ if os.name == 'nt':
+ timeout = 0.5
+ timeout_min, timeout_max = 0.4, 0.6
+
+ response = self.fetch('/trigger?wake=false', request_timeout=timeout)
self.assertEqual(response.code, 599)
- self.assertTrue(0.099 < response.request_time < 0.15, response.request_time)
+ self.assertTrue(timeout_min < response.request_time < timeout_max,
+ response.request_time)
self.assertEqual(str(response.error), "HTTP 599: Timeout")
# trigger the hanging request to let it clean up after itself
self.triggers.popleft()()
@@ -315,10 +323,10 @@ class SimpleHTTPClientTestMixin(object):
self.assertTrue(host_re.match(response.body), response.body)
def test_connection_refused(self):
- server_socket, port = bind_unused_port()
- server_socket.close()
+ cleanup_func, port = refusing_port()
+ self.addCleanup(cleanup_func)
with ExpectLog(gen_log, ".*", required=False):
- self.http_client.fetch("http://localhost:%d/" % port, self.stop)
+ self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
response = self.wait()
self.assertEqual(599, response.code)
@@ -352,7 +360,10 @@ class SimpleHTTPClientTestMixin(object):
def test_no_content_length(self):
response = self.fetch("/no_content_length")
- self.assertEquals(b"hello", response.body)
+ if response.body == b"HTTP/1 required":
+ self.skipTest("requires HTTP/1.x")
+ else:
+ self.assertEquals(b"hello", response.body)
def sync_body_producer(self, write):
write(b'1234')
@@ -425,6 +436,33 @@ class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
defaults=dict(validate_cert=False),
**kwargs)
+ def test_ssl_options(self):
+ resp = self.fetch("/hello", ssl_options={})
+ self.assertEqual(resp.body, b"Hello world!")
+
+ @unittest.skipIf(not hasattr(ssl, 'SSLContext'),
+ 'ssl.SSLContext not present')
+ def test_ssl_context(self):
+ resp = self.fetch("/hello",
+ ssl_options=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
+ self.assertEqual(resp.body, b"Hello world!")
+
+ def test_ssl_options_handshake_fail(self):
+ with ExpectLog(gen_log, "SSL Error|Uncaught exception",
+ required=False):
+ resp = self.fetch(
+ "/hello", ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED))
+ self.assertRaises(ssl.SSLError, resp.rethrow)
+
+ @unittest.skipIf(not hasattr(ssl, 'SSLContext'),
+ 'ssl.SSLContext not present')
+ def test_ssl_context_handshake_fail(self):
+ with ExpectLog(gen_log, "SSL Error|Uncaught exception"):
+ ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ ctx.verify_mode = ssl.CERT_REQUIRED
+ resp = self.fetch("/hello", ssl_options=ctx)
+ self.assertRaises(ssl.SSLError, resp.rethrow)
+
class CreateAsyncHTTPClientTestCase(AsyncTestCase):
def setUp(self):
@@ -460,6 +498,12 @@ class CreateAsyncHTTPClientTestCase(AsyncTestCase):
class HTTP100ContinueTestCase(AsyncHTTPTestCase):
def respond_100(self, request):
+ self.http1 = request.version.startswith('HTTP/1.')
+ if not self.http1:
+ request.connection.write_headers(ResponseStartLine('', 200, 'OK'),
+ HTTPHeaders())
+ request.connection.finish()
+ return
self.request = request
self.request.connection.stream.write(
b"HTTP/1.1 100 CONTINUE\r\n\r\n",
@@ -476,11 +520,20 @@ class HTTP100ContinueTestCase(AsyncHTTPTestCase):
def test_100_continue(self):
res = self.fetch('/')
+ if not self.http1:
+ self.skipTest("requires HTTP/1.x")
self.assertEqual(res.body, b'A')
class HTTP204NoContentTestCase(AsyncHTTPTestCase):
def respond_204(self, request):
+ self.http1 = request.version.startswith('HTTP/1.')
+ if not self.http1:
+ # Close the request cleanly in HTTP/2; it will be skipped anyway.
+ request.connection.write_headers(ResponseStartLine('', 200, 'OK'),
+ HTTPHeaders())
+ request.connection.finish()
+ return
# A 204 response never has a body, even if doesn't have a content-length
# (which would otherwise mean read-until-close). Tornado always
# sends a content-length, so we simulate here a server that sends
@@ -488,14 +541,18 @@ class HTTP204NoContentTestCase(AsyncHTTPTestCase):
#
# Tests of a 204 response with a Content-Length header are included
# in SimpleHTTPClientTestMixin.
- request.connection.stream.write(
+ stream = request.connection.detach()
+ stream.write(
b"HTTP/1.1 204 No content\r\n\r\n")
+ stream.close()
def get_app(self):
return self.respond_204
def test_204_no_content(self):
resp = self.fetch('/')
+ if not self.http1:
+ self.skipTest("requires HTTP/1.x")
self.assertEqual(resp.code, 204)
self.assertEqual(resp.body, b'')
@@ -574,3 +631,49 @@ class MaxHeaderSizeTest(AsyncHTTPTestCase):
with ExpectLog(gen_log, "Unsatisfiable read"):
response = self.fetch('/large')
self.assertEqual(response.code, 599)
+
+
+class MaxBodySizeTest(AsyncHTTPTestCase):
+ def get_app(self):
+ class SmallBody(RequestHandler):
+ def get(self):
+ self.write("a"*1024*64)
+
+ class LargeBody(RequestHandler):
+ def get(self):
+ self.write("a"*1024*100)
+
+ return Application([('/small', SmallBody),
+ ('/large', LargeBody)])
+
+ def get_http_client(self):
+ return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_body_size=1024*64)
+
+ def test_small_body(self):
+ response = self.fetch('/small')
+ response.rethrow()
+ self.assertEqual(response.body, b'a'*1024*64)
+
+ def test_large_body(self):
+ with ExpectLog(gen_log, "Malformed HTTP message from None: Content-Length too long"):
+ response = self.fetch('/large')
+ self.assertEqual(response.code, 599)
+
+
+class MaxBufferSizeTest(AsyncHTTPTestCase):
+ def get_app(self):
+
+ class LargeBody(RequestHandler):
+ def get(self):
+ self.write("a"*1024*100)
+
+ return Application([('/large', LargeBody)])
+
+ def get_http_client(self):
+ # 100KB body with 64KB buffer
+ return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_body_size=1024*100, max_buffer_size=1024*64)
+
+ def test_large_body(self):
+ response = self.fetch('/large')
+ response.rethrow()
+ self.assertEqual(response.body, b'a'*1024*100)
diff --git a/tornado/test/tcpclient_test.py b/tornado/test/tcpclient_test.py
index 5df4a7ab..1a4201e6 100644
--- a/tornado/test/tcpclient_test.py
+++ b/tornado/test/tcpclient_test.py
@@ -24,8 +24,8 @@ from tornado.concurrent import Future
from tornado.netutil import bind_sockets, Resolver
from tornado.tcpclient import TCPClient, _Connector
from tornado.tcpserver import TCPServer
-from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
-from tornado.test.util import skipIfNoIPv6, unittest
+from tornado.testing import AsyncTestCase, gen_test
+from tornado.test.util import skipIfNoIPv6, unittest, refusing_port
# Fake address families for testing. Used in place of AF_INET
# and AF_INET6 because some installations do not have AF_INET6.
@@ -120,8 +120,8 @@ class TCPClientTest(AsyncTestCase):
@gen_test
def test_refused_ipv4(self):
- sock, port = bind_unused_port()
- sock.close()
+ cleanup_func, port = refusing_port()
+ self.addCleanup(cleanup_func)
with self.assertRaises(IOError):
yield self.client.connect('127.0.0.1', port)
diff --git a/tornado/test/tcpserver_test.py b/tornado/test/tcpserver_test.py
new file mode 100644
index 00000000..84c95076
--- /dev/null
+++ b/tornado/test/tcpserver_test.py
@@ -0,0 +1,38 @@
+import socket
+
+from tornado import gen
+from tornado.iostream import IOStream
+from tornado.log import app_log
+from tornado.stack_context import NullContext
+from tornado.tcpserver import TCPServer
+from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
+
+
+class TCPServerTest(AsyncTestCase):
+ @gen_test
+ def test_handle_stream_coroutine_logging(self):
+ # handle_stream may be a coroutine and any exception in its
+ # Future will be logged.
+ class TestServer(TCPServer):
+ @gen.coroutine
+ def handle_stream(self, stream, address):
+ yield gen.moment
+ stream.close()
+ 1/0
+
+ server = client = None
+ try:
+ sock, port = bind_unused_port()
+ with NullContext():
+ server = TestServer()
+ server.add_socket(sock)
+ client = IOStream(socket.socket())
+ with ExpectLog(app_log, "Exception in callback"):
+ yield client.connect(('localhost', port))
+ yield client.read_until_close()
+ yield gen.moment
+ finally:
+ if server is not None:
+ server.stop()
+ if client is not None:
+ client.close()
diff --git a/tornado/test/twisted_test.py b/tornado/test/twisted_test.py
index 2922a61e..8ace993d 100644
--- a/tornado/test/twisted_test.py
+++ b/tornado/test/twisted_test.py
@@ -19,15 +19,18 @@ Unittest for the twisted-style reactor.
from __future__ import absolute_import, division, print_function, with_statement
+import logging
import os
import shutil
import signal
+import sys
import tempfile
import threading
+import warnings
try:
import fcntl
- from twisted.internet.defer import Deferred
+ from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
from twisted.internet.interfaces import IReadDescriptor, IWriteDescriptor
from twisted.internet.protocol import Protocol
from twisted.python import log
@@ -40,10 +43,12 @@ except ImportError:
# The core of Twisted 12.3.0 is available on python 3, but twisted.web is not
# so test for it separately.
try:
- from twisted.web.client import Agent
+ from twisted.web.client import Agent, readBody
from twisted.web.resource import Resource
from twisted.web.server import Site
- have_twisted_web = True
+ # As of Twisted 15.0.0, twisted.web is present but fails our
+ # tests due to internal str/bytes errors.
+ have_twisted_web = sys.version_info < (3,)
except ImportError:
have_twisted_web = False
@@ -52,6 +57,8 @@ try:
except ImportError:
import _thread as thread # py3
+from tornado.escape import utf8
+from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
@@ -65,6 +72,9 @@ from tornado.web import RequestHandler, Application
skipIfNoTwisted = unittest.skipUnless(have_twisted,
"twisted module not present")
+skipIfNoSingleDispatch = unittest.skipIf(
+ gen.singledispatch is None, "singledispatch module not present")
+
def save_signal_handlers():
saved = {}
@@ -407,7 +417,7 @@ class CompatibilityTests(unittest.TestCase):
# http://twistedmatrix.com/documents/current/web/howto/client.html
chunks = []
client = Agent(self.reactor)
- d = client.request('GET', url)
+ d = client.request(b'GET', utf8(url))
class Accumulator(Protocol):
def __init__(self, finished):
@@ -425,37 +435,98 @@ class CompatibilityTests(unittest.TestCase):
return finished
d.addCallback(callback)
- def shutdown(ignored):
- self.stop_loop()
+ def shutdown(failure):
+ if hasattr(self, 'stop_loop'):
+ self.stop_loop()
+ elif failure is not None:
+ # loop hasn't been initialized yet; try our best to
+ # get an error message out. (the runner() interaction
+ # should probably be refactored).
+ try:
+ failure.raiseException()
+ except:
+ logging.error('exception before starting loop', exc_info=True)
d.addBoth(shutdown)
runner()
self.assertTrue(chunks)
return ''.join(chunks)
+ def twisted_coroutine_fetch(self, url, runner):
+ body = [None]
+
+ @gen.coroutine
+ def f():
+ # This is simpler than the non-coroutine version, but it cheats
+ # by reading the body in one blob instead of streaming it with
+ # a Protocol.
+ client = Agent(self.reactor)
+ response = yield client.request(b'GET', utf8(url))
+ with warnings.catch_warnings():
+ # readBody has a buggy DeprecationWarning in Twisted 15.0:
+ # https://twistedmatrix.com/trac/changeset/43379
+ warnings.simplefilter('ignore', category=DeprecationWarning)
+ body[0] = yield readBody(response)
+ self.stop_loop()
+ self.io_loop.add_callback(f)
+ runner()
+ return body[0]
+
def testTwistedServerTornadoClientIOLoop(self):
self.start_twisted_server()
response = self.tornado_fetch(
- 'http://localhost:%d' % self.twisted_port, self.run_ioloop)
+ 'http://127.0.0.1:%d' % self.twisted_port, self.run_ioloop)
self.assertEqual(response.body, 'Hello from twisted!')
def testTwistedServerTornadoClientReactor(self):
self.start_twisted_server()
response = self.tornado_fetch(
- 'http://localhost:%d' % self.twisted_port, self.run_reactor)
+ 'http://127.0.0.1:%d' % self.twisted_port, self.run_reactor)
self.assertEqual(response.body, 'Hello from twisted!')
def testTornadoServerTwistedClientIOLoop(self):
self.start_tornado_server()
response = self.twisted_fetch(
- 'http://localhost:%d' % self.tornado_port, self.run_ioloop)
+ 'http://127.0.0.1:%d' % self.tornado_port, self.run_ioloop)
self.assertEqual(response, 'Hello from tornado!')
def testTornadoServerTwistedClientReactor(self):
self.start_tornado_server()
response = self.twisted_fetch(
- 'http://localhost:%d' % self.tornado_port, self.run_reactor)
+ 'http://127.0.0.1:%d' % self.tornado_port, self.run_reactor)
self.assertEqual(response, 'Hello from tornado!')
+ @skipIfNoSingleDispatch
+ def testTornadoServerTwistedCoroutineClientIOLoop(self):
+ self.start_tornado_server()
+ response = self.twisted_coroutine_fetch(
+ 'http://127.0.0.1:%d' % self.tornado_port, self.run_ioloop)
+ self.assertEqual(response, 'Hello from tornado!')
+
+
+@skipIfNoTwisted
+@skipIfNoSingleDispatch
+class ConvertDeferredTest(unittest.TestCase):
+ def test_success(self):
+ @inlineCallbacks
+ def fn():
+ if False:
+ # inlineCallbacks doesn't work with regular functions;
+ # must have a yield even if it's unreachable.
+ yield
+ returnValue(42)
+ f = gen.convert_yielded(fn())
+ self.assertEqual(f.result(), 42)
+
+ def test_failure(self):
+ @inlineCallbacks
+ def fn():
+ if False:
+ yield
+ 1 / 0
+ f = gen.convert_yielded(fn())
+ with self.assertRaises(ZeroDivisionError):
+ f.result()
+
if have_twisted:
# Import and run as much of twisted's test suite as possible.
@@ -483,7 +554,7 @@ if have_twisted:
'test_changeUID',
],
# Process tests appear to work on OSX 10.7, but not 10.6
- #'twisted.internet.test.test_process.PTYProcessTestsBuilder': [
+ # 'twisted.internet.test.test_process.PTYProcessTestsBuilder': [
# 'test_systemCallUninterruptedByChildExit',
# ],
'twisted.internet.test.test_tcp.TCPClientTestsBuilder': [
@@ -502,7 +573,7 @@ if have_twisted:
'twisted.internet.test.test_threads.ThreadTestsBuilder': [],
'twisted.internet.test.test_time.TimeTestsBuilder': [],
# Extra third-party dependencies (pyOpenSSL)
- #'twisted.internet.test.test_tls.SSLClientTestsMixin': [],
+ # 'twisted.internet.test.test_tls.SSLClientTestsMixin': [],
'twisted.internet.test.test_udp.UDPServerTestsBuilder': [],
'twisted.internet.test.test_unix.UNIXTestsBuilder': [
# Platform-specific. These tests would be skipped automatically
@@ -588,13 +659,13 @@ if have_twisted:
correctly. In some tests another TornadoReactor is layered on top
of the whole stack.
"""
- def initialize(self):
+ def initialize(self, **kwargs):
# When configured to use LayeredTwistedIOLoop we can't easily
# get the next-best IOLoop implementation, so use the lowest common
# denominator.
self.real_io_loop = SelectIOLoop()
reactor = TornadoReactor(io_loop=self.real_io_loop)
- super(LayeredTwistedIOLoop, self).initialize(reactor=reactor)
+ super(LayeredTwistedIOLoop, self).initialize(reactor=reactor, **kwargs)
self.add_callback(self.make_current)
def close(self, all_fds=False):
diff --git a/tornado/test/util.py b/tornado/test/util.py
index d31bbba3..9dd9c0ce 100644
--- a/tornado/test/util.py
+++ b/tornado/test/util.py
@@ -4,6 +4,8 @@ import os
import socket
import sys
+from tornado.testing import bind_unused_port
+
# Encapsulate the choice of unittest or unittest2 here.
# To be used as 'from tornado.test.util import unittest'.
if sys.version_info < (2, 7):
@@ -28,3 +30,23 @@ skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ,
'network access disabled')
skipIfNoIPv6 = unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present')
+
+
+def refusing_port():
+ """Returns a local port number that will refuse all connections.
+
+ Return value is (cleanup_func, port); the cleanup function
+ must be called to free the port to be reused.
+ """
+ # On travis-ci, port numbers are reassigned frequently. To avoid
+ # collisions with other tests, we use an open client-side socket's
+ # ephemeral port number to ensure that nothing can listen on that
+ # port.
+ server_socket, port = bind_unused_port()
+ server_socket.setblocking(1)
+ client_socket = socket.socket()
+ client_socket.connect(("127.0.0.1", port))
+ conn, client_addr = server_socket.accept()
+ conn.close()
+ server_socket.close()
+ return (client_socket.close, client_addr[1])
diff --git a/tornado/test/util_test.py b/tornado/test/util_test.py
index 1cd78fe4..0936c89a 100644
--- a/tornado/test/util_test.py
+++ b/tornado/test/util_test.py
@@ -3,8 +3,9 @@ from __future__ import absolute_import, division, print_function, with_statement
import sys
import datetime
+import tornado.escape
from tornado.escape import utf8
-from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer, timedelta_to_seconds
+from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer, timedelta_to_seconds, import_object
from tornado.test.util import unittest
try:
@@ -45,13 +46,15 @@ class TestConfigurable(Configurable):
class TestConfig1(TestConfigurable):
- def initialize(self, a=None):
+ def initialize(self, pos_arg=None, a=None):
self.a = a
+ self.pos_arg = pos_arg
class TestConfig2(TestConfigurable):
- def initialize(self, b=None):
+ def initialize(self, pos_arg=None, b=None):
self.b = b
+ self.pos_arg = pos_arg
class ConfigurableTest(unittest.TestCase):
@@ -101,9 +104,10 @@ class ConfigurableTest(unittest.TestCase):
self.assertIsInstance(obj, TestConfig1)
self.assertEqual(obj.a, 3)
- obj = TestConfigurable(a=4)
+ obj = TestConfigurable(42, a=4)
self.assertIsInstance(obj, TestConfig1)
self.assertEqual(obj.a, 4)
+ self.assertEqual(obj.pos_arg, 42)
self.checkSubclasses()
# args bound in configure don't apply when using the subclass directly
@@ -116,9 +120,10 @@ class ConfigurableTest(unittest.TestCase):
self.assertIsInstance(obj, TestConfig2)
self.assertEqual(obj.b, 5)
- obj = TestConfigurable(b=6)
+ obj = TestConfigurable(42, b=6)
self.assertIsInstance(obj, TestConfig2)
self.assertEqual(obj.b, 6)
+ self.assertEqual(obj.pos_arg, 42)
self.checkSubclasses()
# args bound in configure don't apply when using the subclass directly
@@ -177,3 +182,20 @@ class TimedeltaToSecondsTest(unittest.TestCase):
def test_timedelta_to_seconds(self):
time_delta = datetime.timedelta(hours=1)
self.assertEqual(timedelta_to_seconds(time_delta), 3600.0)
+
+
+class ImportObjectTest(unittest.TestCase):
+ def test_import_member(self):
+ self.assertIs(import_object('tornado.escape.utf8'), utf8)
+
+ def test_import_member_unicode(self):
+ self.assertIs(import_object(u('tornado.escape.utf8')), utf8)
+
+ def test_import_module(self):
+ self.assertIs(import_object('tornado.escape'), tornado.escape)
+
+ def test_import_module_unicode(self):
+ # The internal implementation of __import__ differs depending on
+ # whether the thing being imported is a module or not.
+ # This variant requires a byte string in python 2.
+ self.assertIs(import_object(u('tornado.escape')), tornado.escape)
diff --git a/tornado/test/web_test.py b/tornado/test/web_test.py
index 55c9c9e8..96edd6c2 100644
--- a/tornado/test/web_test.py
+++ b/tornado/test/web_test.py
@@ -11,7 +11,7 @@ from tornado.template import DictLoader
from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test
from tornado.test.util import unittest
from tornado.util import u, ObjectDict, unicode_type, timedelta_to_seconds
-from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler
+from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler, get_signature_key_version
import binascii
import contextlib
@@ -71,10 +71,14 @@ class HelloHandler(RequestHandler):
class CookieTestRequestHandler(RequestHandler):
# stub out enough methods to make the secure_cookie functions work
- def __init__(self):
+ def __init__(self, cookie_secret='0123456789', key_version=None):
# don't call super.__init__
self._cookies = {}
- self.application = ObjectDict(settings=dict(cookie_secret='0123456789'))
+ if key_version is None:
+ self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret))
+ else:
+ self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret,
+ key_version=key_version))
def get_cookie(self, name):
return self._cookies.get(name)
@@ -128,6 +132,51 @@ class SecureCookieV1Test(unittest.TestCase):
self.assertEqual(handler.get_secure_cookie('foo', min_version=1), b'\xe9')
+# See SignedValueTest below for more.
+class SecureCookieV2Test(unittest.TestCase):
+ KEY_VERSIONS = {
+ 0: 'ajklasdf0ojaisdf',
+ 1: 'aslkjasaolwkjsdf'
+ }
+
+ def test_round_trip(self):
+ handler = CookieTestRequestHandler()
+ handler.set_secure_cookie('foo', b'bar', version=2)
+ self.assertEqual(handler.get_secure_cookie('foo', min_version=2), b'bar')
+
+ def test_key_version_roundtrip(self):
+ handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
+ key_version=0)
+ handler.set_secure_cookie('foo', b'bar')
+ self.assertEqual(handler.get_secure_cookie('foo'), b'bar')
+
+ def test_key_version_roundtrip_differing_version(self):
+ handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
+ key_version=1)
+ handler.set_secure_cookie('foo', b'bar')
+ self.assertEqual(handler.get_secure_cookie('foo'), b'bar')
+
+ def test_key_version_increment_version(self):
+ handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
+ key_version=0)
+ handler.set_secure_cookie('foo', b'bar')
+ new_handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
+ key_version=1)
+ new_handler._cookies = handler._cookies
+ self.assertEqual(new_handler.get_secure_cookie('foo'), b'bar')
+
+ def test_key_version_invalidate_version(self):
+ handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
+ key_version=0)
+ handler.set_secure_cookie('foo', b'bar')
+ new_key_versions = self.KEY_VERSIONS.copy()
+ new_key_versions.pop(0)
+ new_handler = CookieTestRequestHandler(cookie_secret=new_key_versions,
+ key_version=1)
+ new_handler._cookies = handler._cookies
+ self.assertEqual(new_handler.get_secure_cookie('foo'), None)
+
+
class CookieTest(WebTestCase):
def get_handlers(self):
class SetCookieHandler(RequestHandler):
@@ -171,6 +220,13 @@ class CookieTest(WebTestCase):
def get(self):
self.set_cookie("foo", "bar", expires_days=10)
+ class SetCookieFalsyFlags(RequestHandler):
+ def get(self):
+ self.set_cookie("a", "1", secure=True)
+ self.set_cookie("b", "1", secure=False)
+ self.set_cookie("c", "1", httponly=True)
+ self.set_cookie("d", "1", httponly=False)
+
return [("/set", SetCookieHandler),
("/get", GetCookieHandler),
("/set_domain", SetCookieDomainHandler),
@@ -178,6 +234,7 @@ class CookieTest(WebTestCase):
("/set_overwrite", SetCookieOverwriteHandler),
("/set_max_age", SetCookieMaxAgeHandler),
("/set_expires_days", SetCookieExpiresDaysHandler),
+ ("/set_falsy_flags", SetCookieFalsyFlags)
]
def test_set_cookie(self):
@@ -237,7 +294,7 @@ class CookieTest(WebTestCase):
headers = response.headers.get_list("Set-Cookie")
self.assertEqual(sorted(headers),
["foo=bar; Max-Age=10; Path=/"])
-
+
def test_set_cookie_expires_days(self):
response = self.fetch("/set_expires_days")
header = response.headers.get("Set-Cookie")
@@ -248,7 +305,17 @@ class CookieTest(WebTestCase):
header_expires = datetime.datetime(
*email.utils.parsedate(match.groupdict()["expires"])[:6])
self.assertTrue(abs(timedelta_to_seconds(expires - header_expires)) < 10)
-
+
+ def test_set_cookie_false_flags(self):
+ response = self.fetch("/set_falsy_flags")
+ headers = sorted(response.headers.get_list("Set-Cookie"))
+ # The secure and httponly headers are capitalized in py35 and
+ # lowercase in older versions.
+ self.assertEqual(headers[0].lower(), 'a=1; path=/; secure')
+ self.assertEqual(headers[1].lower(), 'b=1; path=/')
+ self.assertEqual(headers[2].lower(), 'c=1; httponly; path=/')
+ self.assertEqual(headers[3].lower(), 'd=1; path=/')
+
class AuthRedirectRequestHandler(RequestHandler):
def initialize(self, login_url):
@@ -305,7 +372,7 @@ class ConnectionCloseTest(WebTestCase):
def test_connection_close(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
- s.connect(("localhost", self.get_http_port()))
+ s.connect(("127.0.0.1", self.get_http_port()))
self.stream = IOStream(s, io_loop=self.io_loop)
self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
self.wait()
@@ -379,6 +446,12 @@ class RequestEncodingTest(WebTestCase):
path_args=["a/b", "c/d"],
args={}))
+ def test_error(self):
+ # Percent signs (encoded as %25) should not mess up printf-style
+ # messages in logs
+ with ExpectLog(gen_log, ".*Invalid unicode"):
+ self.fetch("/group/?arg=%25%e9")
+
class TypeCheckHandler(RequestHandler):
def prepare(self):
@@ -579,6 +652,7 @@ class WSGISafeWebTest(WebTestCase):
url("/redirect", RedirectHandler),
url("/web_redirect_permanent", WebRedirectHandler, {"url": "/web_redirect_newpath"}),
url("/web_redirect", WebRedirectHandler, {"url": "/web_redirect_newpath", "permanent": False}),
+ url("//web_redirect_double_slash", WebRedirectHandler, {"url": '/web_redirect_newpath'}),
url("/header_injection", HeaderInjectionHandler),
url("/get_argument", GetArgumentHandler),
url("/get_arguments", GetArgumentsHandler),
@@ -712,6 +786,11 @@ js_embed()
self.assertEqual(response.code, 302)
self.assertEqual(response.headers['Location'], '/web_redirect_newpath')
+ def test_web_redirect_double_slash(self):
+ response = self.fetch("//web_redirect_double_slash", follow_redirects=False)
+ self.assertEqual(response.code, 301)
+ self.assertEqual(response.headers['Location'], '/web_redirect_newpath')
+
def test_header_injection(self):
response = self.fetch("/header_injection")
self.assertEqual(response.body, b"ok")
@@ -1517,6 +1596,22 @@ class ExceptionHandlerTest(SimpleHandlerTestCase):
self.assertEqual(response.code, 403)
+@wsgi_safe
+class BuggyLoggingTest(SimpleHandlerTestCase):
+ class Handler(RequestHandler):
+ def get(self):
+ 1/0
+
+ def log_exception(self, typ, value, tb):
+ 1/0
+
+ def test_buggy_log_exception(self):
+ # Something gets logged even though the application's
+ # logger is broken.
+ with ExpectLog(app_log, '.*'):
+ self.fetch('/')
+
+
@wsgi_safe
class UIMethodUIModuleTest(SimpleHandlerTestCase):
"""Test that UI methods and modules are created correctly and
@@ -1533,6 +1628,7 @@ class UIMethodUIModuleTest(SimpleHandlerTestCase):
def my_ui_method(handler, x):
return "In my_ui_method(%s) with handler value %s." % (
x, handler.value())
+
class MyModule(UIModule):
def render(self, x):
return "In MyModule(%s) with handler value %s." % (
@@ -1907,7 +2003,7 @@ class StreamingRequestBodyTest(WebTestCase):
def connect(self, url, connection_close):
# Use a raw connection so we can control the sending of data.
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
- s.connect(("localhost", self.get_http_port()))
+ s.connect(("127.0.0.1", self.get_http_port()))
stream = IOStream(s, io_loop=self.io_loop)
stream.write(b"GET " + url + b" HTTP/1.1\r\n")
if connection_close:
@@ -1988,8 +2084,10 @@ class StreamingRequestFlowControlTest(WebTestCase):
@gen.coroutine
def prepare(self):
- with self.in_method('prepare'):
- yield gen.Task(IOLoop.current().add_callback)
+ # Note that asynchronous prepare() does not block data_received,
+ # so we don't use in_method here.
+ self.methods.append('prepare')
+ yield gen.Task(IOLoop.current().add_callback)
@gen.coroutine
def data_received(self, data):
@@ -2051,9 +2149,10 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase):
# When the content-length is too high, the connection is simply
# closed without completing the response. An error is logged on
# the server.
- with ExpectLog(app_log, "Uncaught exception"):
+ with ExpectLog(app_log, "(Uncaught exception|Exception in callback)"):
with ExpectLog(gen_log,
- "Cannot send error response after headers written"):
+ "(Cannot send error response after headers written"
+ "|Failed to flush partial response)"):
response = self.fetch("/high")
self.assertEqual(response.code, 599)
self.assertEqual(str(self.server_error),
@@ -2063,9 +2162,10 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase):
# When the content-length is too low, the connection is closed
# without writing the last chunk, so the client never sees the request
# complete (which would be a framing error).
- with ExpectLog(app_log, "Uncaught exception"):
+ with ExpectLog(app_log, "(Uncaught exception|Exception in callback)"):
with ExpectLog(gen_log,
- "Cannot send error response after headers written"):
+ "(Cannot send error response after headers written"
+ "|Failed to flush partial response)"):
response = self.fetch("/low")
self.assertEqual(response.code, 599)
self.assertEqual(str(self.server_error),
@@ -2075,21 +2175,28 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase):
class ClientCloseTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
- # Simulate a connection closed by the client during
- # request processing. The client will see an error, but the
- # server should respond gracefully (without logging errors
- # because we were unable to write out as many bytes as
- # Content-Length said we would)
- self.request.connection.stream.close()
- self.write('hello')
+ if self.request.version.startswith('HTTP/1'):
+ # Simulate a connection closed by the client during
+ # request processing. The client will see an error, but the
+ # server should respond gracefully (without logging errors
+ # because we were unable to write out as many bytes as
+ # Content-Length said we would)
+ self.request.connection.stream.close()
+ self.write('hello')
+ else:
+ # TODO: add a HTTP2-compatible version of this test.
+ self.write('requires HTTP/1.x')
def test_client_close(self):
response = self.fetch('/')
+ if response.body == b'requires HTTP/1.x':
+ self.skipTest('requires HTTP/1.x')
self.assertEqual(response.code, 599)
class SignedValueTest(unittest.TestCase):
SECRET = "It's a secret to everybody"
+ SECRET_DICT = {0: "asdfbasdf", 1: "12312312", 2: "2342342"}
def past(self):
return self.present() - 86400 * 32
@@ -2151,6 +2258,7 @@ class SignedValueTest(unittest.TestCase):
def test_payload_tampering(self):
# These cookies are variants of the one in test_known_values.
sig = "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152"
+
def validate(prefix):
return (b'value' ==
decode_signed_value(SignedValueTest.SECRET, "key",
@@ -2165,6 +2273,7 @@ class SignedValueTest(unittest.TestCase):
def test_signature_tampering(self):
prefix = "2|1:0|10:1300000000|3:key|8:dmFsdWU=|"
+
def validate(sig):
return (b'value' ==
decode_signed_value(SignedValueTest.SECRET, "key",
@@ -2194,6 +2303,43 @@ class SignedValueTest(unittest.TestCase):
clock=self.present)
self.assertEqual(value, decoded)
+ def test_key_versioning_read_write_default_key(self):
+ value = b"\xe9"
+ signed = create_signed_value(SignedValueTest.SECRET_DICT,
+ "key", value, clock=self.present,
+ key_version=0)
+ decoded = decode_signed_value(SignedValueTest.SECRET_DICT,
+ "key", signed, clock=self.present)
+ self.assertEqual(value, decoded)
+
+ def test_key_versioning_read_write_non_default_key(self):
+ value = b"\xe9"
+ signed = create_signed_value(SignedValueTest.SECRET_DICT,
+ "key", value, clock=self.present,
+ key_version=1)
+ decoded = decode_signed_value(SignedValueTest.SECRET_DICT,
+ "key", signed, clock=self.present)
+ self.assertEqual(value, decoded)
+
+ def test_key_versioning_invalid_key(self):
+ value = b"\xe9"
+ signed = create_signed_value(SignedValueTest.SECRET_DICT,
+ "key", value, clock=self.present,
+ key_version=0)
+ newkeys = SignedValueTest.SECRET_DICT.copy()
+ newkeys.pop(0)
+ decoded = decode_signed_value(newkeys,
+ "key", signed, clock=self.present)
+ self.assertEqual(None, decoded)
+
+ def test_key_version_retrieval(self):
+ value = b"\xe9"
+ signed = create_signed_value(SignedValueTest.SECRET_DICT,
+ "key", value, clock=self.present,
+ key_version=1)
+ key_version = get_signature_key_version(signed)
+ self.assertEqual(1, key_version)
+
@wsgi_safe
class XSRFTest(SimpleHandlerTestCase):
@@ -2296,7 +2442,7 @@ class XSRFTest(SimpleHandlerTestCase):
token2 = self.get_token()
# Each token can be used to authenticate its own request.
for token in (self.xsrf_token, token2):
- response = self.fetch(
+ response = self.fetch(
"/", method="POST",
body=urllib_parse.urlencode(dict(_xsrf=token)),
headers=self.cookie_headers(token))
@@ -2372,6 +2518,7 @@ class FinishExceptionTest(SimpleHandlerTestCase):
self.assertEqual(b'authentication required', response.body)
+@wsgi_safe
class DecoratorTest(WebTestCase):
def get_handlers(self):
class RemoveSlashHandler(RequestHandler):
@@ -2405,3 +2552,85 @@ class DecoratorTest(WebTestCase):
response = self.fetch("/addslash?foo=bar", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], "/addslash/?foo=bar")
+
+
+@wsgi_safe
+class CacheTest(WebTestCase):
+ def get_handlers(self):
+ class EtagHandler(RequestHandler):
+ def get(self, computed_etag):
+ self.write(computed_etag)
+
+ def compute_etag(self):
+ return self._write_buffer[0]
+
+ return [
+ ('/etag/(.*)', EtagHandler)
+ ]
+
+ def test_wildcard_etag(self):
+ computed_etag = '"xyzzy"'
+ etags = '*'
+ self._test_etag(computed_etag, etags, 304)
+
+ def test_strong_etag_match(self):
+ computed_etag = '"xyzzy"'
+ etags = '"xyzzy"'
+ self._test_etag(computed_etag, etags, 304)
+
+ def test_multiple_strong_etag_match(self):
+ computed_etag = '"xyzzy1"'
+ etags = '"xyzzy1", "xyzzy2"'
+ self._test_etag(computed_etag, etags, 304)
+
+ def test_strong_etag_not_match(self):
+ computed_etag = '"xyzzy"'
+ etags = '"xyzzy1"'
+ self._test_etag(computed_etag, etags, 200)
+
+ def test_multiple_strong_etag_not_match(self):
+ computed_etag = '"xyzzy"'
+ etags = '"xyzzy1", "xyzzy2"'
+ self._test_etag(computed_etag, etags, 200)
+
+ def test_weak_etag_match(self):
+ computed_etag = '"xyzzy1"'
+ etags = 'W/"xyzzy1"'
+ self._test_etag(computed_etag, etags, 304)
+
+ def test_multiple_weak_etag_match(self):
+ computed_etag = '"xyzzy2"'
+ etags = 'W/"xyzzy1", W/"xyzzy2"'
+ self._test_etag(computed_etag, etags, 304)
+
+ def test_weak_etag_not_match(self):
+ computed_etag = '"xyzzy2"'
+ etags = 'W/"xyzzy1"'
+ self._test_etag(computed_etag, etags, 200)
+
+ def test_multiple_weak_etag_not_match(self):
+ computed_etag = '"xyzzy3"'
+ etags = 'W/"xyzzy1", W/"xyzzy2"'
+ self._test_etag(computed_etag, etags, 200)
+
+ def _test_etag(self, computed_etag, etags, status_code):
+ response = self.fetch(
+ '/etag/' + computed_etag,
+ headers={'If-None-Match': etags}
+ )
+ self.assertEqual(response.code, status_code)
+
+
+@wsgi_safe
+class RequestSummaryTest(SimpleHandlerTestCase):
+ class Handler(RequestHandler):
+ def get(self):
+ # remote_ip is optional, although it's set by
+ # both HTTPServer and WSGIAdapter.
+ # Clobber it to make sure it doesn't break logging.
+ self.request.remote_ip = None
+ self.finish(self._request_summary())
+
+ def test_missing_remote_ip(self):
+ resp = self.fetch("/")
+ self.assertEqual(resp.body, b"GET / (None)")
diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py
index e1e3ea70..6b182d07 100644
--- a/tornado/test/websocket_test.py
+++ b/tornado/test/websocket_test.py
@@ -12,7 +12,7 @@ from tornado.web import Application, RequestHandler
from tornado.util import u
try:
- import tornado.websocket
+ import tornado.websocket # noqa
from tornado.util import _websocket_mask_python
except ImportError:
# The unittest module presents misleading errors on ImportError
@@ -53,7 +53,7 @@ class EchoHandler(TestWebSocketHandler):
class ErrorInOnMessageHandler(TestWebSocketHandler):
def on_message(self, message):
- 1/0
+ 1 / 0
class HeaderHandler(TestWebSocketHandler):
@@ -75,6 +75,7 @@ class NonWebSocketHandler(RequestHandler):
class CloseReasonHandler(TestWebSocketHandler):
def open(self):
+ self.on_close_called = False
self.close(1001, "goodbye")
@@ -91,7 +92,7 @@ class WebSocketBaseTestCase(AsyncHTTPTestCase):
@gen.coroutine
def ws_connect(self, path, compression_options=None):
ws = yield websocket_connect(
- 'ws://localhost:%d%s' % (self.get_http_port(), path),
+ 'ws://127.0.0.1:%d%s' % (self.get_http_port(), path),
compression_options=compression_options)
raise gen.Return(ws)
@@ -105,6 +106,7 @@ class WebSocketBaseTestCase(AsyncHTTPTestCase):
ws.close()
yield self.close_future
+
class WebSocketTest(WebSocketBaseTestCase):
def get_app(self):
self.close_future = Future()
@@ -135,7 +137,7 @@ class WebSocketTest(WebSocketBaseTestCase):
def test_websocket_callbacks(self):
websocket_connect(
- 'ws://localhost:%d/echo' % self.get_http_port(),
+ 'ws://127.0.0.1:%d/echo' % self.get_http_port(),
io_loop=self.io_loop, callback=self.stop)
ws = self.wait().result()
ws.write_message('hello')
@@ -189,14 +191,14 @@ class WebSocketTest(WebSocketBaseTestCase):
with self.assertRaises(IOError):
with ExpectLog(gen_log, ".*"):
yield websocket_connect(
- 'ws://localhost:%d/' % port,
+ 'ws://127.0.0.1:%d/' % port,
io_loop=self.io_loop,
connect_timeout=3600)
@gen_test
def test_websocket_close_buffered_data(self):
ws = yield websocket_connect(
- 'ws://localhost:%d/echo' % self.get_http_port())
+ 'ws://127.0.0.1:%d/echo' % self.get_http_port())
ws.write_message('hello')
ws.write_message('world')
# Close the underlying stream.
@@ -207,7 +209,7 @@ class WebSocketTest(WebSocketBaseTestCase):
def test_websocket_headers(self):
# Ensure that arbitrary headers can be passed through websocket_connect.
ws = yield websocket_connect(
- HTTPRequest('ws://localhost:%d/header' % self.get_http_port(),
+ HTTPRequest('ws://127.0.0.1:%d/header' % self.get_http_port(),
headers={'X-Test': 'hello'}))
response = yield ws.read_message()
self.assertEqual(response, 'hello')
@@ -221,6 +223,8 @@ class WebSocketTest(WebSocketBaseTestCase):
self.assertIs(msg, None)
self.assertEqual(ws.close_code, 1001)
self.assertEqual(ws.close_reason, "goodbye")
+ # The on_close callback is called no matter which side closed.
+ yield self.close_future
@gen_test
def test_client_close_reason(self):
@@ -243,11 +247,11 @@ class WebSocketTest(WebSocketBaseTestCase):
def test_check_origin_valid_no_path(self):
port = self.get_http_port()
- url = 'ws://localhost:%d/echo' % port
- headers = {'Origin': 'http://localhost:%d' % port}
+ url = 'ws://127.0.0.1:%d/echo' % port
+ headers = {'Origin': 'http://127.0.0.1:%d' % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
- io_loop=self.io_loop)
+ io_loop=self.io_loop)
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
@@ -257,11 +261,11 @@ class WebSocketTest(WebSocketBaseTestCase):
def test_check_origin_valid_with_path(self):
port = self.get_http_port()
- url = 'ws://localhost:%d/echo' % port
- headers = {'Origin': 'http://localhost:%d/something' % port}
+ url = 'ws://127.0.0.1:%d/echo' % port
+ headers = {'Origin': 'http://127.0.0.1:%d/something' % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
- io_loop=self.io_loop)
+ io_loop=self.io_loop)
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
@@ -271,8 +275,8 @@ class WebSocketTest(WebSocketBaseTestCase):
def test_check_origin_invalid_partial_url(self):
port = self.get_http_port()
- url = 'ws://localhost:%d/echo' % port
- headers = {'Origin': 'localhost:%d' % port}
+ url = 'ws://127.0.0.1:%d/echo' % port
+ headers = {'Origin': '127.0.0.1:%d' % port}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers),
@@ -283,8 +287,8 @@ class WebSocketTest(WebSocketBaseTestCase):
def test_check_origin_invalid(self):
port = self.get_http_port()
- url = 'ws://localhost:%d/echo' % port
- # Host is localhost, which should not be accessible from some other
+ url = 'ws://127.0.0.1:%d/echo' % port
+ # Host is 127.0.0.1, which should not be accessible from some other
# domain
headers = {'Origin': 'http://somewhereelse.com'}
diff --git a/tornado/testing.py b/tornado/testing.py
index 4d85abe9..93f0dbe1 100644
--- a/tornado/testing.py
+++ b/tornado/testing.py
@@ -19,6 +19,7 @@ try:
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.ioloop import IOLoop, TimeoutError
from tornado import netutil
+ from tornado.process import Subprocess
except ImportError:
# These modules are not importable on app engine. Parts of this module
# won't work, but e.g. LogTrapTestCase and main() will.
@@ -28,6 +29,7 @@ except ImportError:
IOLoop = None
netutil = None
SimpleAsyncHTTPClient = None
+ Subprocess = None
from tornado.log import gen_log, app_log
from tornado.stack_context import ExceptionStackContext
from tornado.util import raise_exc_info, basestring_type
@@ -214,6 +216,8 @@ class AsyncTestCase(unittest.TestCase):
self.io_loop.make_current()
def tearDown(self):
+ # Clean up Subprocess, so it can be used again with a new ioloop.
+ Subprocess.uninitialize()
self.io_loop.clear_current()
if (not IOLoop.initialized() or
self.io_loop is not IOLoop.instance()):
@@ -413,10 +417,8 @@ class AsyncHTTPSTestCase(AsyncHTTPTestCase):
Interface is generally the same as `AsyncHTTPTestCase`.
"""
def get_http_client(self):
- # Some versions of libcurl have deadlock bugs with ssl,
- # so always run these tests with SimpleAsyncHTTPClient.
- return SimpleAsyncHTTPClient(io_loop=self.io_loop, force_instance=True,
- defaults=dict(validate_cert=False))
+ return AsyncHTTPClient(io_loop=self.io_loop, force_instance=True,
+ defaults=dict(validate_cert=False))
def get_httpserver_options(self):
return dict(ssl_options=self.get_ssl_options())
@@ -539,6 +541,9 @@ class LogTrapTestCase(unittest.TestCase):
`logging.basicConfig` and the "pretty logging" configured by
`tornado.options`. It is not compatible with other log buffering
mechanisms, such as those provided by some test runners.
+
+ .. deprecated:: 4.1
+ Use the unittest module's ``--buffer`` option instead, or `.ExpectLog`.
"""
def run(self, result=None):
logger = logging.getLogger()
diff --git a/tornado/util.py b/tornado/util.py
index 34c4b072..606ced19 100644
--- a/tornado/util.py
+++ b/tornado/util.py
@@ -78,6 +78,25 @@ class GzipDecompressor(object):
return self.decompressobj.flush()
+# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for
+# literal strings, and alternative solutions like "from __future__ import
+# unicode_literals" have other problems (see PEP 414). u() can be applied
+# to ascii strings that include \u escapes (but they must not contain
+# literal non-ascii characters).
+if not isinstance(b'', type('')):
+ def u(s):
+ return s
+ unicode_type = str
+ basestring_type = str
+else:
+ def u(s):
+ return s.decode('unicode_escape')
+ # These names don't exist in py3, so use noqa comments to disable
+ # warnings in flake8.
+ unicode_type = unicode # noqa
+ basestring_type = basestring # noqa
+
+
def import_object(name):
"""Imports an object by name.
@@ -96,6 +115,9 @@ def import_object(name):
...
ImportError: No module named missing_module
"""
+ if isinstance(name, unicode_type) and str is not unicode_type:
+ # On python 2 a byte string is required.
+ name = name.encode('utf-8')
if name.count('.') == 0:
return __import__(name, None, None)
@@ -107,22 +129,6 @@ def import_object(name):
raise ImportError("No module named %s" % parts[-1])
-# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for
-# literal strings, and alternative solutions like "from __future__ import
-# unicode_literals" have other problems (see PEP 414). u() can be applied
-# to ascii strings that include \u escapes (but they must not contain
-# literal non-ascii characters).
-if type('') is not type(b''):
- def u(s):
- return s
- unicode_type = str
- basestring_type = str
-else:
- def u(s):
- return s.decode('unicode_escape')
- unicode_type = unicode
- basestring_type = basestring
-
# Deprecated alias that was used before we dropped py25 support.
# Left here in case anyone outside Tornado is using it.
bytes_type = bytes
@@ -192,21 +198,21 @@ class Configurable(object):
__impl_class = None
__impl_kwargs = None
- def __new__(cls, **kwargs):
+ def __new__(cls, *args, **kwargs):
base = cls.configurable_base()
- args = {}
+ init_kwargs = {}
if cls is base:
impl = cls.configured_class()
if base.__impl_kwargs:
- args.update(base.__impl_kwargs)
+ init_kwargs.update(base.__impl_kwargs)
else:
impl = cls
- args.update(kwargs)
+ init_kwargs.update(kwargs)
instance = super(Configurable, cls).__new__(impl)
# initialize vs __init__ chosen for compatibility with AsyncHTTPClient
# singleton magic. If we get rid of that we can switch to __init__
# here too.
- instance.initialize(**args)
+ instance.initialize(*args, **init_kwargs)
return instance
@classmethod
@@ -227,6 +233,9 @@ class Configurable(object):
"""Initialize a `Configurable` subclass instance.
Configurable classes should use `initialize` instead of ``__init__``.
+
+ .. versionchanged:: 4.2
+ Now accepts positional arguments in addition to keyword arguments.
"""
@classmethod
@@ -339,7 +348,7 @@ def _websocket_mask_python(mask, data):
return unmasked.tostring()
if (os.environ.get('TORNADO_NO_EXTENSION') or
- os.environ.get('TORNADO_EXTENSION') == '0'):
+ os.environ.get('TORNADO_EXTENSION') == '0'):
# These environment variables exist to make it easier to do performance
# comparisons; they are not guaranteed to remain supported in the future.
_websocket_mask = _websocket_mask_python
diff --git a/tornado/web.py b/tornado/web.py
index 9edd719a..9bc12933 100644
--- a/tornado/web.py
+++ b/tornado/web.py
@@ -19,7 +19,9 @@ features that allow it to scale to large numbers of open connections,
making it ideal for `long polling
`_.
-Here is a simple "Hello, world" example app::
+Here is a simple "Hello, world" example app:
+
+.. testcode::
import tornado.ioloop
import tornado.web
@@ -33,7 +35,11 @@ Here is a simple "Hello, world" example app::
(r"/", MainHandler),
])
application.listen(8888)
- tornado.ioloop.IOLoop.instance().start()
+ tornado.ioloop.IOLoop.current().start()
+
+.. testoutput::
+ :hide:
+
See the :doc:`guide` for additional information.
@@ -50,7 +56,8 @@ request.
"""
-from __future__ import absolute_import, division, print_function, with_statement
+from __future__ import (absolute_import, division,
+ print_function, with_statement)
import base64
@@ -84,7 +91,9 @@ from tornado.log import access_log, app_log, gen_log
from tornado import stack_context
from tornado import template
from tornado.escape import utf8, _unicode
-from tornado.util import import_object, ObjectDict, raise_exc_info, unicode_type, _websocket_mask
+from tornado.util import (import_object, ObjectDict, raise_exc_info,
+ unicode_type, _websocket_mask)
+from tornado.httputil import split_host_and_port
try:
@@ -130,12 +139,11 @@ May be overridden by passing a ``version`` keyword argument.
DEFAULT_SIGNED_VALUE_MIN_VERSION = 1
"""The oldest signed value accepted by `.RequestHandler.get_secure_cookie`.
-May be overrided by passing a ``min_version`` keyword argument.
+May be overridden by passing a ``min_version`` keyword argument.
.. versionadded:: 3.2.1
"""
-
class RequestHandler(object):
"""Subclass this class and define `get()` or `post()` to make a handler.
@@ -267,6 +275,7 @@ class RequestHandler(object):
if _has_stream_request_body(self.__class__):
if not self.request.body.done():
self.request.body.set_exception(iostream.StreamClosedError())
+ self.request.body.exception()
def clear(self):
"""Resets all headers and content for this response."""
@@ -382,6 +391,12 @@ class RequestHandler(object):
The returned values are always unicode.
"""
+
+ # Make sure `get_arguments` isn't accidentally being called with a
+ # positional argument that's assumed to be a default (like in
+ # `get_argument`.)
+ assert isinstance(strip, bool)
+
return self._get_arguments(name, self.request.arguments, strip)
def get_body_argument(self, name, default=_ARG_DEFAULT, strip=True):
@@ -398,7 +413,8 @@ class RequestHandler(object):
.. versionadded:: 3.2
"""
- return self._get_argument(name, default, self.request.body_arguments, strip)
+ return self._get_argument(name, default, self.request.body_arguments,
+ strip)
def get_body_arguments(self, name, strip=True):
"""Returns a list of the body arguments with the given name.
@@ -425,7 +441,8 @@ class RequestHandler(object):
.. versionadded:: 3.2
"""
- return self._get_argument(name, default, self.request.query_arguments, strip)
+ return self._get_argument(name, default,
+ self.request.query_arguments, strip)
def get_query_arguments(self, name, strip=True):
"""Returns a list of the query arguments with the given name.
@@ -480,7 +497,8 @@ class RequestHandler(object):
@property
def cookies(self):
- """An alias for `self.request.cookies <.httputil.HTTPServerRequest.cookies>`."""
+ """An alias for
+ `self.request.cookies <.httputil.HTTPServerRequest.cookies>`."""
return self.request.cookies
def get_cookie(self, name, default=None):
@@ -522,6 +540,12 @@ class RequestHandler(object):
for k, v in kwargs.items():
if k == 'max_age':
k = 'max-age'
+
+ # skip falsy values for httponly and secure flags because
+ # SimpleCookie sets them regardless
+ if k in ['httponly', 'secure'] and not v:
+ continue
+
morsel[k] = v
def clear_cookie(self, name, path="/", domain=None):
@@ -588,8 +612,15 @@ class RequestHandler(object):
and made it the default.
"""
self.require_setting("cookie_secret", "secure cookies")
- return create_signed_value(self.application.settings["cookie_secret"],
- name, value, version=version)
+ secret = self.application.settings["cookie_secret"]
+ key_version = None
+ if isinstance(secret, dict):
+ if self.application.settings.get("key_version") is None:
+ raise Exception("key_version setting must be used for secret_key dicts")
+ key_version = self.application.settings["key_version"]
+
+ return create_signed_value(secret, name, value, version=version,
+ key_version=key_version)
def get_secure_cookie(self, name, value=None, max_age_days=31,
min_version=None):
@@ -610,6 +641,17 @@ class RequestHandler(object):
name, value, max_age_days=max_age_days,
min_version=min_version)
+ def get_secure_cookie_key_version(self, name, value=None):
+ """Returns the signing key version of the secure cookie.
+
+ The version is returned as int.
+ """
+ self.require_setting("cookie_secret", "secure cookies")
+ if value is None:
+ value = self.get_cookie(name)
+ return get_signature_key_version(value)
+
+
def redirect(self, url, permanent=False, status=None):
"""Sends a redirect to the given (optionally relative) URL.
@@ -625,8 +667,7 @@ class RequestHandler(object):
else:
assert isinstance(status, int) and 300 <= status <= 399
self.set_status(status)
- self.set_header("Location", urlparse.urljoin(utf8(self.request.uri),
- utf8(url)))
+ self.set_header("Location", utf8(url))
self.finish()
def write(self, chunk):
@@ -646,16 +687,14 @@ class RequestHandler(object):
https://github.com/facebook/tornado/issues/1009
"""
if self._finished:
- raise RuntimeError("Cannot write() after finish(). May be caused "
- "by using async operations without the "
- "@asynchronous decorator.")
+ raise RuntimeError("Cannot write() after finish()")
if not isinstance(chunk, (bytes, unicode_type, dict)):
- raise TypeError("write() only accepts bytes, unicode, and dict objects")
+ message = "write() only accepts bytes, unicode, and dict objects"
+ if isinstance(chunk, list):
+ message += ". Lists not accepted for security reasons; see http://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.write"
+ raise TypeError(message)
if isinstance(chunk, dict):
- if 'unwrap_json' in chunk:
- chunk = chunk['unwrap_json']
- else:
- chunk = escape.json_encode(chunk)
+ chunk = escape.json_encode(chunk)
self.set_header("Content-Type", "application/json; charset=UTF-8")
chunk = utf8(chunk)
self._write_buffer.append(chunk)
@@ -786,6 +825,7 @@ class RequestHandler(object):
current_user=self.current_user,
locale=self.locale,
_=self.locale.translate,
+ pgettext=self.locale.pgettext,
static_url=self.static_url,
xsrf_form_html=self.xsrf_form_html,
reverse_url=self.reverse_url
@@ -830,7 +870,8 @@ class RequestHandler(object):
for transform in self._transforms:
self._status_code, self._headers, chunk = \
transform.transform_first_chunk(
- self._status_code, self._headers, chunk, include_footers)
+ self._status_code, self._headers,
+ chunk, include_footers)
# Ignore the chunk and only write the headers for HEAD requests
if self.request.method == "HEAD":
chunk = None
@@ -842,7 +883,7 @@ class RequestHandler(object):
for cookie in self._new_cookie.values():
self.add_header("Set-Cookie", cookie.OutputString(None))
- start_line = httputil.ResponseStartLine(self.request.version,
+ start_line = httputil.ResponseStartLine('',
self._status_code,
self._reason)
return self.request.connection.write_headers(
@@ -861,9 +902,7 @@ class RequestHandler(object):
def finish(self, chunk=None):
"""Finishes this response, ending the HTTP request."""
if self._finished:
- raise RuntimeError("finish() called twice. May be caused "
- "by using async operations without the "
- "@asynchronous decorator.")
+ raise RuntimeError("finish() called twice")
if chunk is not None:
self.write(chunk)
@@ -915,7 +954,15 @@ class RequestHandler(object):
if self._headers_written:
gen_log.error("Cannot send error response after headers written")
if not self._finished:
- self.finish()
+ # If we get an error between writing headers and finishing,
+ # we are unlikely to be able to finish due to a
+ # Content-Length mismatch. Try anyway to release the
+ # socket.
+ try:
+ self.finish()
+ except Exception:
+ gen_log.error("Failed to flush partial response",
+ exc_info=True)
return
self.clear()
@@ -1122,28 +1169,37 @@ class RequestHandler(object):
"""Convert a cookie string into a the tuple form returned by
_get_raw_xsrf_token.
"""
- m = _signed_value_version_re.match(utf8(cookie))
- if m:
- version = int(m.group(1))
- if version == 2:
- _, mask, masked_token, timestamp = cookie.split("|")
- mask = binascii.a2b_hex(utf8(mask))
- token = _websocket_mask(
- mask, binascii.a2b_hex(utf8(masked_token)))
- timestamp = int(timestamp)
- return version, token, timestamp
+
+ try:
+ m = _signed_value_version_re.match(utf8(cookie))
+
+ if m:
+ version = int(m.group(1))
+ if version == 2:
+ _, mask, masked_token, timestamp = cookie.split("|")
+
+ mask = binascii.a2b_hex(utf8(mask))
+ token = _websocket_mask(
+ mask, binascii.a2b_hex(utf8(masked_token)))
+ timestamp = int(timestamp)
+ return version, token, timestamp
+ else:
+ # Treat unknown versions as not present instead of failing.
+ raise Exception("Unknown xsrf cookie version")
else:
- # Treat unknown versions as not present instead of failing.
- return None, None, None
- else:
- version = 1
- try:
- token = binascii.a2b_hex(utf8(cookie))
- except (binascii.Error, TypeError):
- token = utf8(cookie)
- # We don't have a usable timestamp in older versions.
- timestamp = int(time.time())
- return (version, token, timestamp)
+ version = 1
+ try:
+ token = binascii.a2b_hex(utf8(cookie))
+ except (binascii.Error, TypeError):
+ token = utf8(cookie)
+ # We don't have a usable timestamp in older versions.
+ timestamp = int(time.time())
+ return (version, token, timestamp)
+ except Exception:
+ # Catch exceptions and return nothing instead of failing.
+ gen_log.debug("Uncaught exception in _decode_xsrf_token",
+ exc_info=True)
+ return None, None, None
def check_xsrf_cookie(self):
"""Verifies that the ``_xsrf`` cookie matches the ``_xsrf`` argument.
@@ -1282,9 +1338,27 @@ class RequestHandler(object):
before completing the request. The ``Etag`` header should be set
(perhaps with `set_etag_header`) before calling this method.
"""
- etag = self._headers.get("Etag")
- inm = utf8(self.request.headers.get("If-None-Match", ""))
- return bool(etag and inm and inm.find(etag) >= 0)
+ computed_etag = utf8(self._headers.get("Etag", ""))
+ # Find all weak and strong etag values from If-None-Match header
+ # because RFC 7232 allows multiple etag values in a single header.
+ etags = re.findall(
+ br'\*|(?:W/)?"[^"]*"',
+ utf8(self.request.headers.get("If-None-Match", ""))
+ )
+ if not computed_etag or not etags:
+ return False
+
+ match = False
+ if etags[0] == b'*':
+ match = True
+ else:
+ # Use a weak comparison when comparing entity-tags.
+ val = lambda x: x[2:] if x.startswith(b'W/') else x
+ for etag in etags:
+ if val(etag) == val(computed_etag):
+ match = True
+ break
+ return match
def _stack_context_handle_exception(self, type, value, traceback):
try:
@@ -1344,7 +1418,10 @@ class RequestHandler(object):
if self._auto_finish and not self._finished:
self.finish()
except Exception as e:
- self._handle_request_exception(e)
+ try:
+ self._handle_request_exception(e)
+ except Exception:
+ app_log.error("Exception in exception handler", exc_info=True)
if (self._prepared_future is not None and
not self._prepared_future.done()):
# In case we failed before setting _prepared_future, do it
@@ -1369,8 +1446,8 @@ class RequestHandler(object):
self.application.log_request(self)
def _request_summary(self):
- return self.request.method + " " + self.request.uri + \
- " (" + self.request.remote_ip + ")"
+ return "%s %s (%s)" % (self.request.method, self.request.uri,
+ self.request.remote_ip)
def _handle_request_exception(self, e):
if isinstance(e, Finish):
@@ -1378,7 +1455,12 @@ class RequestHandler(object):
if not self._finished:
self.finish()
return
- self.log_exception(*sys.exc_info())
+ try:
+ self.log_exception(*sys.exc_info())
+ except Exception:
+ # An error here should still get a best-effort send_error()
+ # to avoid leaking the connection.
+ app_log.error("Error in exception logger", exc_info=True)
if self._finished:
# Extra errors after the request has been finished should
# be logged, but there is no reason to continue to try and
@@ -1441,10 +1523,11 @@ class RequestHandler(object):
def asynchronous(method):
"""Wrap request handler methods with this if they are asynchronous.
- This decorator is unnecessary if the method is also decorated with
- ``@gen.coroutine`` (it is legal but unnecessary to use the two
- decorators together, in which case ``@asynchronous`` must be
- first).
+ This decorator is for callback-style asynchronous methods; for
+ coroutines, use the ``@gen.coroutine`` decorator without
+ ``@asynchronous``. (It is legal for legacy reasons to use the two
+ decorators together provided ``@asynchronous`` is first, but
+ ``@asynchronous`` will be ignored in this case)
This decorator should only be applied to the :ref:`HTTP verb
methods `; its behavior is undefined for any other method.
@@ -1457,10 +1540,12 @@ def asynchronous(method):
method returns. It is up to the request handler to call
`self.finish() ` to finish the HTTP
request. Without this decorator, the request is automatically
- finished when the ``get()`` or ``post()`` method returns. Example::
+ finished when the ``get()`` or ``post()`` method returns. Example:
- class MyRequestHandler(web.RequestHandler):
- @web.asynchronous
+ .. testcode::
+
+ class MyRequestHandler(RequestHandler):
+ @asynchronous
def get(self):
http = httpclient.AsyncHTTPClient()
http.fetch("http://friendfeed.com/", self._on_download)
@@ -1469,18 +1554,23 @@ def asynchronous(method):
self.write("Downloaded!")
self.finish()
+ .. testoutput::
+ :hide:
+
.. versionadded:: 3.1
The ability to use ``@gen.coroutine`` without ``@asynchronous``.
+
"""
# Delay the IOLoop import because it's not available on app engine.
from tornado.ioloop import IOLoop
+
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
self._auto_finish = False
with stack_context.ExceptionStackContext(
self._stack_context_handle_exception):
result = method(self, *args, **kwargs)
- if isinstance(result, Future):
+ if is_future(result):
# If @asynchronous is used with @gen.coroutine, (but
# not @gen.engine), we can automatically finish the
# request when the future resolves. Additionally,
@@ -1521,7 +1611,7 @@ def stream_request_body(cls):
the entire body has been read.
There is a subtle interaction between ``data_received`` and asynchronous
- ``prepare``: The first call to ``data_recieved`` may occur at any point
+ ``prepare``: The first call to ``data_received`` may occur at any point
after the call to ``prepare`` has returned *or yielded*.
"""
if not issubclass(cls, RequestHandler):
@@ -1591,7 +1681,7 @@ class Application(httputil.HTTPServerConnectionDelegate):
])
http_server = httpserver.HTTPServer(application)
http_server.listen(8080)
- ioloop.IOLoop.instance().start()
+ ioloop.IOLoop.current().start()
The constructor for this class takes in a list of `URLSpec` objects
or (regexp, request_class) tuples. When we receive requests, we
@@ -1689,7 +1779,7 @@ class Application(httputil.HTTPServerConnectionDelegate):
`.TCPServer.bind`/`.TCPServer.start` methods directly.
Note that after calling this method you still need to call
- ``IOLoop.instance().start()`` to start the server.
+ ``IOLoop.current().start()`` to start the server.
"""
# import is here rather than top level because HTTPServer
# is not importable on appengine
@@ -1732,7 +1822,7 @@ class Application(httputil.HTTPServerConnectionDelegate):
self.transforms.append(transform_class)
def _get_host_handlers(self, request):
- host = request.host.lower().split(':')[0]
+ host = split_host_and_port(request.host.lower())[0]
matches = []
for pattern, handlers in self.handlers:
if pattern.match(host):
@@ -1773,9 +1863,9 @@ class Application(httputil.HTTPServerConnectionDelegate):
except TypeError:
pass
- def start_request(self, connection):
+ def start_request(self, server_conn, request_conn):
# Modern HTTPServer interface
- return _RequestDispatcher(self, connection)
+ return _RequestDispatcher(self, request_conn)
def __call__(self, request):
# Legacy HTTPServer interface
@@ -1831,7 +1921,8 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate):
def headers_received(self, start_line, headers):
self.set_request(httputil.HTTPServerRequest(
- connection=self.connection, start_line=start_line, headers=headers))
+ connection=self.connection, start_line=start_line,
+ headers=headers))
if self.stream_request_body:
self.request.body = Future()
return self.execute()
@@ -1848,7 +1939,9 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate):
handlers = app._get_host_handlers(self.request)
if not handlers:
self.handler_class = RedirectHandler
- self.handler_kwargs = dict(url="http://" + app.default_host + "/")
+ self.handler_kwargs = dict(url="%s://%s/"
+ % (self.request.protocol,
+ app.default_host))
return
for spec in handlers:
match = spec.regex.match(self.request.path)
@@ -1914,11 +2007,14 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate):
if self.stream_request_body:
self.handler._prepared_future = Future()
# Note that if an exception escapes handler._execute it will be
- # trapped in the Future it returns (which we are ignoring here).
+ # trapped in the Future it returns (which we are ignoring here,
+ # leaving it to be logged when the Future is GC'd).
# However, that shouldn't happen because _execute has a blanket
# except handler, and we cannot easily access the IOLoop here to
- # call add_future.
- self.handler._execute(transforms, *self.path_args, **self.path_kwargs)
+ # call add_future (because of the requirement to remain compatible
+ # with WSGI)
+ f = self.handler._execute(transforms, *self.path_args,
+ **self.path_kwargs)
# If we are streaming the request body, then execute() is finished
# when the handler has prepared to receive the body. If not,
# it doesn't matter when execute() finishes (so we return None)
@@ -1952,6 +2048,8 @@ class HTTPError(Exception):
self.log_message = log_message
self.args = args
self.reason = kwargs.get('reason', None)
+ if log_message and not args:
+ self.log_message = log_message.replace('%', '%%')
def __str__(self):
message = "HTTP %d: %s" % (
@@ -2212,7 +2310,8 @@ class StaticFileHandler(RequestHandler):
if content_type:
self.set_header("Content-Type", content_type)
- cache_time = self.get_cache_time(self.path, self.modified, content_type)
+ cache_time = self.get_cache_time(self.path, self.modified,
+ content_type)
if cache_time > 0:
self.set_header("Expires", datetime.datetime.utcnow() +
datetime.timedelta(seconds=cache_time))
@@ -2381,7 +2480,8 @@ class StaticFileHandler(RequestHandler):
.. versionadded:: 3.1
"""
stat_result = self._stat()
- modified = datetime.datetime.utcfromtimestamp(stat_result[stat.ST_MTIME])
+ modified = datetime.datetime.utcfromtimestamp(
+ stat_result[stat.ST_MTIME])
return modified
def get_content_type(self):
@@ -2624,6 +2724,8 @@ class UIModule(object):
UI modules often execute additional queries, and they can include
additional CSS and JavaScript that will be included in the output
page, which is automatically inserted on page render.
+
+ Subclasses of UIModule must override the `render` method.
"""
def __init__(self, handler):
self.handler = handler
@@ -2636,31 +2738,45 @@ class UIModule(object):
return self.handler.current_user
def render(self, *args, **kwargs):
- """Overridden in subclasses to return this module's output."""
+ """Override in subclasses to return this module's output."""
raise NotImplementedError()
def embedded_javascript(self):
- """Returns a JavaScript string that will be embedded in the page."""
+ """Override to return a JavaScript string
+ to be embedded in the page."""
return None
def javascript_files(self):
- """Returns a list of JavaScript files required by this module."""
+ """Override to return a list of JavaScript files needed by this module.
+
+ If the return values are relative paths, they will be passed to
+ `RequestHandler.static_url`; otherwise they will be used as-is.
+ """
return None
def embedded_css(self):
- """Returns a CSS string that will be embedded in the page."""
+ """Override to return a CSS string
+ that will be embedded in the page."""
return None
def css_files(self):
- """Returns a list of CSS files required by this module."""
+ """Override to returns a list of CSS files required by this module.
+
+ If the return values are relative paths, they will be passed to
+ `RequestHandler.static_url`; otherwise they will be used as-is.
+ """
return None
def html_head(self):
- """Returns a CSS string that will be put in the element"""
+ """Override to return an HTML string that will be put in the
+ element.
+ """
return None
def html_body(self):
- """Returns an HTML string that will be put in the element"""
+ """Override to return an HTML string that will be put at the end of
+ the element.
+ """
return None
def render_string(self, path, **kwargs):
@@ -2862,11 +2978,13 @@ else:
return result == 0
-def create_signed_value(secret, name, value, version=None, clock=None):
+def create_signed_value(secret, name, value, version=None, clock=None,
+ key_version=None):
if version is None:
version = DEFAULT_SIGNED_VALUE_VERSION
if clock is None:
clock = time.time
+
timestamp = utf8(str(int(clock())))
value = base64.b64encode(utf8(value))
if version == 1:
@@ -2883,7 +3001,7 @@ def create_signed_value(secret, name, value, version=None, clock=None):
#
# The fields are:
# - format version (i.e. 2; no length prefix)
- # - key version (currently 0; reserved for future key rotation features)
+ # - key version (integer, default is 0)
# - timestamp (integer seconds since epoch)
# - name (not encoded; assumed to be ~alphanumeric)
# - value (base64-encoded)
@@ -2891,34 +3009,32 @@ def create_signed_value(secret, name, value, version=None, clock=None):
def format_field(s):
return utf8("%d:" % len(s)) + utf8(s)
to_sign = b"|".join([
- b"2|1:0",
+ b"2",
+ format_field(str(key_version or 0)),
format_field(timestamp),
format_field(name),
format_field(value),
b''])
+
+ if isinstance(secret, dict):
+ assert key_version is not None, 'Key version must be set when sign key dict is used'
+ assert version >= 2, 'Version must be at least 2 for key version support'
+ secret = secret[key_version]
+
signature = _create_signature_v2(secret, to_sign)
return to_sign + signature
else:
raise ValueError("Unsupported version %d" % version)
-# A leading version number in decimal with no leading zeros, followed by a pipe.
+# A leading version number in decimal
+# with no leading zeros, followed by a pipe.
_signed_value_version_re = re.compile(br"^([1-9][0-9]*)\|(.*)$")
-def decode_signed_value(secret, name, value, max_age_days=31, clock=None, min_version=None):
- if clock is None:
- clock = time.time
- if min_version is None:
- min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION
- if min_version > 2:
- raise ValueError("Unsupported min_version %d" % min_version)
- if not value:
- return None
-
- # Figure out what version this is. Version 1 did not include an
+def _get_version(value):
+ # Figures out what version value is. Version 1 did not include an
# explicit version field and started with arbitrary base64 data,
# which makes this tricky.
- value = utf8(value)
m = _signed_value_version_re.match(value)
if m is None:
version = 1
@@ -2935,13 +3051,31 @@ def decode_signed_value(secret, name, value, max_age_days=31, clock=None, min_ve
version = 1
except ValueError:
version = 1
+ return version
+
+
+def decode_signed_value(secret, name, value, max_age_days=31,
+ clock=None, min_version=None):
+ if clock is None:
+ clock = time.time
+ if min_version is None:
+ min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION
+ if min_version > 2:
+ raise ValueError("Unsupported min_version %d" % min_version)
+ if not value:
+ return None
+
+ value = utf8(value)
+ version = _get_version(value)
if version < min_version:
return None
if version == 1:
- return _decode_signed_value_v1(secret, name, value, max_age_days, clock)
+ return _decode_signed_value_v1(secret, name, value,
+ max_age_days, clock)
elif version == 2:
- return _decode_signed_value_v2(secret, name, value, max_age_days, clock)
+ return _decode_signed_value_v2(secret, name, value,
+ max_age_days, clock)
else:
return None
@@ -2964,7 +3098,8 @@ def _decode_signed_value_v1(secret, name, value, max_age_days, clock):
# digits from the payload to the timestamp without altering the
# signature. For backwards compatibility, sanity-check timestamp
# here instead of modifying _cookie_signature.
- gen_log.warning("Cookie timestamp in future; possible tampering %r", value)
+ gen_log.warning("Cookie timestamp in future; possible tampering %r",
+ value)
return None
if parts[1].startswith(b"0"):
gen_log.warning("Tampered cookie %r", value)
@@ -2975,7 +3110,7 @@ def _decode_signed_value_v1(secret, name, value, max_age_days, clock):
return None
-def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
+def _decode_fields_v2(value):
def _consume_field(s):
length, _, rest = s.partition(b':')
n = int(length)
@@ -2986,16 +3121,28 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
raise ValueError("malformed v2 signed value field")
rest = rest[n + 1:]
return field_value, rest
+
rest = value[2:] # remove version number
+ key_version, rest = _consume_field(rest)
+ timestamp, rest = _consume_field(rest)
+ name_field, rest = _consume_field(rest)
+ value_field, passed_sig = _consume_field(rest)
+ return int(key_version), timestamp, name_field, value_field, passed_sig
+
+
+def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
try:
- key_version, rest = _consume_field(rest)
- timestamp, rest = _consume_field(rest)
- name_field, rest = _consume_field(rest)
- value_field, rest = _consume_field(rest)
+ key_version, timestamp, name_field, value_field, passed_sig = _decode_fields_v2(value)
except ValueError:
return None
- passed_sig = rest
signed_string = value[:-len(passed_sig)]
+
+ if isinstance(secret, dict):
+ try:
+ secret = secret[key_version]
+ except KeyError:
+ return None
+
expected_sig = _create_signature_v2(secret, signed_string)
if not _time_independent_equals(passed_sig, expected_sig):
return None
@@ -3011,6 +3158,19 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
return None
+def get_signature_key_version(value):
+ value = utf8(value)
+ version = _get_version(value)
+ if version < 2:
+ return None
+ try:
+ key_version, _, _, _, _ = _decode_fields_v2(value)
+ except ValueError:
+ return None
+
+ return key_version
+
+
def _create_signature_v1(secret, *parts):
hash = hmac.new(utf8(secret), digestmod=hashlib.sha1)
for part in parts:
diff --git a/tornado/websocket.py b/tornado/websocket.py
index d960b0e4..adf238be 100644
--- a/tornado/websocket.py
+++ b/tornado/websocket.py
@@ -16,7 +16,8 @@ the protocol (known as "draft 76") and are not compatible with this module.
Removed support for the draft 76 protocol version.
"""
-from __future__ import absolute_import, division, print_function, with_statement
+from __future__ import (absolute_import, division,
+ print_function, with_statement)
# Author: Jacob Kristhammar, 2010
import base64
@@ -39,9 +40,9 @@ from tornado.tcpclient import TCPClient
from tornado.util import _websocket_mask
try:
- from urllib.parse import urlparse # py2
+ from urllib.parse import urlparse # py2
except ImportError:
- from urlparse import urlparse # py3
+ from urlparse import urlparse # py3
try:
xrange # py2
@@ -74,17 +75,22 @@ class WebSocketHandler(tornado.web.RequestHandler):
http://tools.ietf.org/html/rfc6455.
Here is an example WebSocket handler that echos back all received messages
- back to the client::
+ back to the client:
- class EchoWebSocket(websocket.WebSocketHandler):
+ .. testcode::
+
+ class EchoWebSocket(tornado.websocket.WebSocketHandler):
def open(self):
- print "WebSocket opened"
+ print("WebSocket opened")
def on_message(self, message):
self.write_message(u"You said: " + message)
def on_close(self):
- print "WebSocket closed"
+ print("WebSocket closed")
+
+ .. testoutput::
+ :hide:
WebSockets are not standard HTTP connections. The "handshake" is
HTTP, but after the handshake, the protocol is
@@ -129,6 +135,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
self.close_code = None
self.close_reason = None
self.stream = None
+ self._on_close_called = False
@tornado.web.asynchronous
def get(self, *args, **kwargs):
@@ -138,16 +145,22 @@ class WebSocketHandler(tornado.web.RequestHandler):
# Upgrade header should be present and should be equal to WebSocket
if self.request.headers.get("Upgrade", "").lower() != 'websocket':
self.set_status(400)
- self.finish("Can \"Upgrade\" only to \"WebSocket\".")
+ log_msg = "Can \"Upgrade\" only to \"WebSocket\"."
+ self.finish(log_msg)
+ gen_log.debug(log_msg)
return
- # Connection header should be upgrade. Some proxy servers/load balancers
+ # Connection header should be upgrade.
+ # Some proxy servers/load balancers
# might mess with it.
headers = self.request.headers
- connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
+ connection = map(lambda s: s.strip().lower(),
+ headers.get("Connection", "").split(","))
if 'upgrade' not in connection:
self.set_status(400)
- self.finish("\"Connection\" must be \"Upgrade\".")
+ log_msg = "\"Connection\" must be \"Upgrade\"."
+ self.finish(log_msg)
+ gen_log.debug(log_msg)
return
# Handle WebSocket Origin naming convention differences
@@ -159,30 +172,29 @@ class WebSocketHandler(tornado.web.RequestHandler):
else:
origin = self.request.headers.get("Sec-Websocket-Origin", None)
-
# If there was an origin header, check to make sure it matches
# according to check_origin. When the origin is None, we assume it
# did not come from a browser and that it can be passed on.
if origin is not None and not self.check_origin(origin):
self.set_status(403)
- self.finish("Cross origin websockets not allowed")
+ log_msg = "Cross origin websockets not allowed"
+ self.finish(log_msg)
+ gen_log.debug(log_msg)
return
self.stream = self.request.connection.detach()
self.stream.set_close_callback(self.on_connection_close)
- if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
- self.ws_connection = WebSocketProtocol13(
- self, compression_options=self.get_compression_options())
+ self.ws_connection = self.get_websocket_protocol()
+ if self.ws_connection:
self.ws_connection.accept_connection()
else:
if not self.stream.closed():
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 426 Upgrade Required\r\n"
- "Sec-WebSocket-Version: 8\r\n\r\n"))
+ "Sec-WebSocket-Version: 7, 8, 13\r\n\r\n"))
self.stream.close()
-
def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket.
@@ -229,7 +241,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
"""
return None
- def open(self):
+ def open(self, *args, **kwargs):
"""Invoked when a new WebSocket is opened.
The arguments to `open` are extracted from the `tornado.web.URLSpec`
@@ -350,6 +362,8 @@ class WebSocketHandler(tornado.web.RequestHandler):
if self.ws_connection:
self.ws_connection.on_connection_close()
self.ws_connection = None
+ if not self._on_close_called:
+ self._on_close_called = True
self.on_close()
def send_error(self, *args, **kwargs):
@@ -362,6 +376,13 @@ class WebSocketHandler(tornado.web.RequestHandler):
# we can close the connection more gracefully.
self.stream.close()
+ def get_websocket_protocol(self):
+ websocket_version = self.request.headers.get("Sec-WebSocket-Version")
+ if websocket_version in ("7", "8", "13"):
+ return WebSocketProtocol13(
+ self, compression_options=self.get_compression_options())
+
+
def _wrap_method(method):
def _disallow_for_websocket(self, *args, **kwargs):
if self.stream is None:
@@ -499,7 +520,8 @@ class WebSocketProtocol13(WebSocketProtocol):
self._handle_websocket_headers()
self._accept_connection()
except ValueError:
- gen_log.debug("Malformed WebSocket request received", exc_info=True)
+ gen_log.debug("Malformed WebSocket request received",
+ exc_info=True)
self._abort()
return
@@ -535,18 +557,19 @@ class WebSocketProtocol13(WebSocketProtocol):
selected = self.handler.select_subprotocol(subprotocols)
if selected:
assert selected in subprotocols
- subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
+ subprotocol_header = ("Sec-WebSocket-Protocol: %s\r\n"
+ % selected)
extension_header = ''
extensions = self._parse_extensions_header(self.request.headers)
for ext in extensions:
if (ext[0] == 'permessage-deflate' and
- self._compression_options is not None):
+ self._compression_options is not None):
# TODO: negotiate parameters if compression_options
# specifies limits.
self._create_compressors('server', ext[1])
if ('client_max_window_bits' in ext[1] and
- ext[1]['client_max_window_bits'] is None):
+ ext[1]['client_max_window_bits'] is None):
# Don't echo an offered client_max_window_bits
# parameter with no value.
del ext[1]['client_max_window_bits']
@@ -591,7 +614,7 @@ class WebSocketProtocol13(WebSocketProtocol):
extensions = self._parse_extensions_header(headers)
for ext in extensions:
if (ext[0] == 'permessage-deflate' and
- self._compression_options is not None):
+ self._compression_options is not None):
self._create_compressors('client', ext[1])
else:
raise ValueError("unsupported extension %r", ext)
@@ -703,7 +726,8 @@ class WebSocketProtocol13(WebSocketProtocol):
if self._masked_frame:
self.stream.read_bytes(4, self._on_masking_key)
else:
- self.stream.read_bytes(self._frame_length, self._on_frame_data)
+ self.stream.read_bytes(self._frame_length,
+ self._on_frame_data)
elif payloadlen == 126:
self.stream.read_bytes(2, self._on_frame_length_16)
elif payloadlen == 127:
@@ -737,7 +761,8 @@ class WebSocketProtocol13(WebSocketProtocol):
self._wire_bytes_in += len(data)
self._frame_mask = data
try:
- self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
+ self.stream.read_bytes(self._frame_length,
+ self._on_masked_frame_data)
except StreamClosedError:
self._abort()
@@ -852,12 +877,15 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
This class should not be instantiated directly; use the
`websocket_connect` function instead.
"""
- def __init__(self, io_loop, request, compression_options=None):
+ def __init__(self, io_loop, request, on_message_callback=None,
+ compression_options=None):
self.compression_options = compression_options
self.connect_future = TracebackFuture()
+ self.protocol = None
self.read_future = None
self.read_queue = collections.deque()
self.key = base64.b64encode(os.urandom(16))
+ self._on_message_callback = on_message_callback
scheme, sep, rest = request.url.partition(':')
scheme = {'ws': 'http', 'wss': 'https'}[scheme]
@@ -880,7 +908,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
self.tcp_client = TCPClient(io_loop=io_loop)
super(WebSocketClientConnection, self).__init__(
io_loop, None, request, lambda: None, self._on_http_response,
- 104857600, self.tcp_client, 65536)
+ 104857600, self.tcp_client, 65536, 104857600)
def close(self, code=None, reason=None):
"""Closes the websocket connection.
@@ -919,9 +947,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
start_line, headers)
self.headers = headers
- self.protocol = WebSocketProtocol13(
- self, mask_outgoing=True,
- compression_options=self.compression_options)
+ self.protocol = self.get_websocket_protocol()
self.protocol._process_server_headers(self.key, self.headers)
self.protocol._receive_frame()
@@ -946,6 +972,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
def read_message(self, callback=None):
"""Reads a message from the WebSocket server.
+ If on_message_callback was specified at WebSocket
+ initialization, this function will never return messages
+
Returns a future whose result is the message, or None
if the connection is closed. If a callback argument
is given it will be called with the future when it is
@@ -962,7 +991,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
return future
def on_message(self, message):
- if self.read_future is not None:
+ if self._on_message_callback:
+ self._on_message_callback(message)
+ elif self.read_future is not None:
self.read_future.set_result(message)
self.read_future = None
else:
@@ -971,9 +1002,13 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
def on_pong(self, data):
pass
+ def get_websocket_protocol(self):
+ return WebSocketProtocol13(self, mask_outgoing=True,
+ compression_options=self.compression_options)
+
def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
- compression_options=None):
+ on_message_callback=None, compression_options=None):
"""Client-side websocket support.
Takes a url and returns a Future whose result is a
@@ -982,11 +1017,26 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
``compression_options`` is interpreted in the same way as the
return value of `.WebSocketHandler.get_compression_options`.
+ The connection supports two styles of operation. In the coroutine
+ style, the application typically calls
+ `~.WebSocketClientConnection.read_message` in a loop::
+
+ conn = yield websocket_connection(loop)
+ while True:
+ msg = yield conn.read_message()
+ if msg is None: break
+ # Do something with msg
+
+ In the callback style, pass an ``on_message_callback`` to
+ ``websocket_connect``. In both styles, a message of ``None``
+ indicates that the connection has been closed.
+
.. versionchanged:: 3.2
Also accepts ``HTTPRequest`` objects in place of urls.
.. versionchanged:: 4.1
- Added ``compression_options``.
+ Added ``compression_options`` and ``on_message_callback``.
+ The ``io_loop`` argument is deprecated.
"""
if io_loop is None:
io_loop = IOLoop.current()
@@ -1000,7 +1050,9 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
request = httpclient._RequestProxy(
request, httpclient.HTTPRequest._DEFAULTS)
- conn = WebSocketClientConnection(io_loop, request, compression_options)
+ conn = WebSocketClientConnection(io_loop, request,
+ on_message_callback=on_message_callback,
+ compression_options=compression_options)
if callback is not None:
io_loop.add_future(conn.connect_future, callback)
return conn.connect_future
diff --git a/tornado/wsgi.py b/tornado/wsgi.py
index f3aa6650..59e6c559 100644
--- a/tornado/wsgi.py
+++ b/tornado/wsgi.py
@@ -207,7 +207,7 @@ class WSGIAdapter(object):
body = environ["wsgi.input"].read(
int(headers["Content-Length"]))
else:
- body = ""
+ body = b""
protocol = environ["wsgi.url_scheme"]
remote_ip = environ.get("REMOTE_ADDR", "")
if environ.get("HTTP_HOST"):
@@ -253,7 +253,7 @@ class WSGIContainer(object):
container = tornado.wsgi.WSGIContainer(simple_app)
http_server = tornado.httpserver.HTTPServer(container)
http_server.listen(8888)
- tornado.ioloop.IOLoop.instance().start()
+ tornado.ioloop.IOLoop.current().start()
This class is intended to let other frameworks (Django, web.py, etc)
run on the Tornado HTTP server and I/O loop.
@@ -284,7 +284,8 @@ class WSGIContainer(object):
if not data:
raise Exception("WSGI app did not call start_response")
- status_code = int(data["status"].split()[0])
+ status_code, reason = data["status"].split(' ', 1)
+ status_code = int(status_code)
headers = data["headers"]
header_set = set(k.lower() for (k, v) in headers)
body = escape.utf8(body)
@@ -296,13 +297,12 @@ class WSGIContainer(object):
if "server" not in header_set:
headers.append(("Server", "TornadoServer/%s" % tornado.version))
- parts = [escape.utf8("HTTP/1.1 " + data["status"] + "\r\n")]
+ start_line = httputil.ResponseStartLine("HTTP/1.1", status_code, reason)
+ header_obj = httputil.HTTPHeaders()
for key, value in headers:
- parts.append(escape.utf8(key) + b": " + escape.utf8(value) + b"\r\n")
- parts.append(b"\r\n")
- parts.append(body)
- request.write(b"".join(parts))
- request.finish()
+ header_obj.add(key, value)
+ request.connection.write_headers(start_line, header_obj, chunk=body)
+ request.connection.finish()
self._log(status_code, request)
@staticmethod