diff --git a/CHANGES.md b/CHANGES.md index f6ac5ad0..b8108a21 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,6 @@ ### 0.x.x (2015-xx-xx xx:xx:xx UTC) +* Update Tornado webserver to 4.2.dev1 (609dbb9) * Change network names to only display on top line of Day by Day layout on Episode View * Reposition country part of network name into the hover over in Day by Day layout * Add ToTV provider diff --git a/tornado/__init__.py b/tornado/__init__.py index 0e39f842..01b926be 100644 --- a/tornado/__init__.py +++ b/tornado/__init__.py @@ -25,5 +25,5 @@ from __future__ import absolute_import, division, print_function, with_statement # is zero for an official release, positive for a development branch, # or negative for a release candidate or beta (after the base version # number has been incremented) -version = "4.1.dev1" -version_info = (4, 1, 0, -100) +version = "4.2.dev1" +version_info = (4, 2, 0, -100) diff --git a/tornado/auth.py b/tornado/auth.py index ac2fd0d1..800b10af 100644 --- a/tornado/auth.py +++ b/tornado/auth.py @@ -32,7 +32,9 @@ They all take slightly different arguments due to the fact all these services implement authentication and authorization slightly differently. See the individual service classes below for complete documentation. -Example usage for Google OpenID:: +Example usage for Google OAuth: + +.. testcode:: class GoogleOAuth2LoginHandler(tornado.web.RequestHandler, tornado.auth.GoogleOAuth2Mixin): @@ -51,6 +53,10 @@ Example usage for Google OpenID:: response_type='code', extra_params={'approval_prompt': 'auto'}) +.. testoutput:: + :hide: + + .. versionchanged:: 4.0 All of the callback interfaces in this module are now guaranteed to run their callback with an argument of ``None`` on error. 
@@ -69,7 +75,7 @@ import hmac import time import uuid -from tornado.concurrent import TracebackFuture, chain_future, return_future +from tornado.concurrent import TracebackFuture, return_future from tornado import gen from tornado import httpclient from tornado import escape @@ -123,6 +129,7 @@ def _auth_return_future(f): if callback is not None: future.add_done_callback( functools.partial(_auth_future_to_callback, callback)) + def handle_exception(typ, value, tb): if future.done(): return False @@ -138,9 +145,6 @@ def _auth_return_future(f): class OpenIdMixin(object): """Abstract implementation of OpenID and Attribute Exchange. - See `GoogleMixin` below for a customized example (which also - includes OAuth support). - Class attributes: * ``_OPENID_ENDPOINT``: the identity provider's URI. @@ -312,8 +316,7 @@ class OpenIdMixin(object): class OAuthMixin(object): """Abstract implementation of OAuth 1.0 and 1.0a. - See `TwitterMixin` and `FriendFeedMixin` below for example implementations, - or `GoogleMixin` for an OAuth/OpenID hybrid. + See `TwitterMixin` below for an example implementation. Class attributes: @@ -565,7 +568,8 @@ class OAuthMixin(object): class OAuth2Mixin(object): """Abstract implementation of OAuth 2.0. - See `FacebookGraphMixin` below for an example implementation. + See `FacebookGraphMixin` or `GoogleOAuth2Mixin` below for example + implementations. Class attributes: @@ -629,7 +633,9 @@ class TwitterMixin(OAuthMixin): URL you registered as your application's callback URL. When your application is set up, you can use this mixin like this - to authenticate the user with Twitter and get access to their stream:: + to authenticate the user with Twitter and get access to their stream: + + .. testcode:: class TwitterLoginHandler(tornado.web.RequestHandler, tornado.auth.TwitterMixin): @@ -641,6 +647,9 @@ class TwitterMixin(OAuthMixin): else: yield self.authorize_redirect() + .. 
testoutput:: + :hide: + The user object returned by `~OAuthMixin.get_authenticated_user` includes the attributes ``username``, ``name``, ``access_token``, and all of the custom Twitter user attributes described at @@ -689,7 +698,9 @@ class TwitterMixin(OAuthMixin): `~OAuthMixin.get_authenticated_user`. The user returned through that process includes an 'access_token' attribute that can be used to make authenticated requests via this method. Example - usage:: + usage: + + .. testcode:: class MainHandler(tornado.web.RequestHandler, tornado.auth.TwitterMixin): @@ -706,6 +717,9 @@ class TwitterMixin(OAuthMixin): return self.finish("Posted a message!") + .. testoutput:: + :hide: + """ if path.startswith('http:') or path.startswith('https:'): # Raw urls are useful for e.g. search which doesn't follow the @@ -757,223 +771,6 @@ class TwitterMixin(OAuthMixin): raise gen.Return(user) -class FriendFeedMixin(OAuthMixin): - """FriendFeed OAuth authentication. - - To authenticate with FriendFeed, register your application with - FriendFeed at http://friendfeed.com/api/applications. Then copy - your Consumer Key and Consumer Secret to the application - `~tornado.web.Application.settings` ``friendfeed_consumer_key`` - and ``friendfeed_consumer_secret``. Use this mixin on the handler - for the URL you registered as your application's Callback URL. - - When your application is set up, you can use this mixin like this - to authenticate the user with FriendFeed and get access to their feed:: - - class FriendFeedLoginHandler(tornado.web.RequestHandler, - tornado.auth.FriendFeedMixin): - @tornado.gen.coroutine - def get(self): - if self.get_argument("oauth_token", None): - user = yield self.get_authenticated_user() - # Save the user using e.g. set_secure_cookie() - else: - yield self.authorize_redirect() - - The user object returned by `~OAuthMixin.get_authenticated_user()` includes the - attributes ``username``, ``name``, and ``description`` in addition to - ``access_token``. 
You should save the access token with the user; - it is required to make requests on behalf of the user later with - `friendfeed_request()`. - """ - _OAUTH_VERSION = "1.0" - _OAUTH_REQUEST_TOKEN_URL = "https://friendfeed.com/account/oauth/request_token" - _OAUTH_ACCESS_TOKEN_URL = "https://friendfeed.com/account/oauth/access_token" - _OAUTH_AUTHORIZE_URL = "https://friendfeed.com/account/oauth/authorize" - _OAUTH_NO_CALLBACKS = True - _OAUTH_VERSION = "1.0" - - @_auth_return_future - def friendfeed_request(self, path, callback, access_token=None, - post_args=None, **args): - """Fetches the given relative API path, e.g., "/bret/friends" - - If the request is a POST, ``post_args`` should be provided. Query - string arguments should be given as keyword arguments. - - All the FriendFeed methods are documented at - http://friendfeed.com/api/documentation. - - Many methods require an OAuth access token which you can - obtain through `~OAuthMixin.authorize_redirect` and - `~OAuthMixin.get_authenticated_user`. The user returned - through that process includes an ``access_token`` attribute that - can be used to make authenticated requests via this - method. - - Example usage:: - - class MainHandler(tornado.web.RequestHandler, - tornado.auth.FriendFeedMixin): - @tornado.web.authenticated - @tornado.gen.coroutine - def get(self): - new_entry = yield self.friendfeed_request( - "/entry", - post_args={"body": "Testing Tornado Web Server"}, - access_token=self.current_user["access_token"]) - - if not new_entry: - # Call failed; perhaps missing permission? 
- yield self.authorize_redirect() - return - self.finish("Posted a message!") - - """ - # Add the OAuth resource request signature if we have credentials - url = "http://friendfeed-api.com/v2" + path - if access_token: - all_args = {} - all_args.update(args) - all_args.update(post_args or {}) - method = "POST" if post_args is not None else "GET" - oauth = self._oauth_request_parameters( - url, access_token, all_args, method=method) - args.update(oauth) - if args: - url += "?" + urllib_parse.urlencode(args) - callback = functools.partial(self._on_friendfeed_request, callback) - http = self.get_auth_http_client() - if post_args is not None: - http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args), - callback=callback) - else: - http.fetch(url, callback=callback) - - def _on_friendfeed_request(self, future, response): - if response.error: - future.set_exception(AuthError( - "Error response %s fetching %s" % (response.error, - response.request.url))) - return - future.set_result(escape.json_decode(response.body)) - - def _oauth_consumer_token(self): - self.require_setting("friendfeed_consumer_key", "FriendFeed OAuth") - self.require_setting("friendfeed_consumer_secret", "FriendFeed OAuth") - return dict( - key=self.settings["friendfeed_consumer_key"], - secret=self.settings["friendfeed_consumer_secret"]) - - @gen.coroutine - def _oauth_get_user_future(self, access_token, callback): - user = yield self.friendfeed_request( - "/feedinfo/" + access_token["username"], - include="id,name,description", access_token=access_token) - if user: - user["username"] = user["id"] - callback(user) - - def _parse_user_response(self, callback, user): - if user: - user["username"] = user["id"] - callback(user) - - -class GoogleMixin(OpenIdMixin, OAuthMixin): - """Google Open ID / OAuth authentication. - - .. deprecated:: 4.0 - New applications should use `GoogleOAuth2Mixin` - below instead of this class. 
As of May 19, 2014, Google has stopped - supporting registration-free authentication. - - No application registration is necessary to use Google for - authentication or to access Google resources on behalf of a user. - - Google implements both OpenID and OAuth in a hybrid mode. If you - just need the user's identity, use - `~OpenIdMixin.authenticate_redirect`. If you need to make - requests to Google on behalf of the user, use - `authorize_redirect`. On return, parse the response with - `~OpenIdMixin.get_authenticated_user`. We send a dict containing - the values for the user, including ``email``, ``name``, and - ``locale``. - - Example usage:: - - class GoogleLoginHandler(tornado.web.RequestHandler, - tornado.auth.GoogleMixin): - @tornado.gen.coroutine - def get(self): - if self.get_argument("openid.mode", None): - user = yield self.get_authenticated_user() - # Save the user with e.g. set_secure_cookie() - else: - yield self.authenticate_redirect() - """ - _OPENID_ENDPOINT = "https://www.google.com/accounts/o8/ud" - _OAUTH_ACCESS_TOKEN_URL = "https://www.google.com/accounts/OAuthGetAccessToken" - - @return_future - def authorize_redirect(self, oauth_scope, callback_uri=None, - ax_attrs=["name", "email", "language", "username"], - callback=None): - """Authenticates and authorizes for the given Google resource. - - Some of the available resources which can be used in the ``oauth_scope`` - argument are: - - * Gmail Contacts - http://www.google.com/m8/feeds/ - * Calendar - http://www.google.com/calendar/feeds/ - * Finance - http://finance.google.com/finance/feeds/ - - You can authorize multiple resources by separating the resource - URLs with a space. - - .. versionchanged:: 3.1 - Returns a `.Future` and takes an optional callback. These are - not strictly necessary as this method is synchronous, - but they are supplied for consistency with - `OAuthMixin.authorize_redirect`. 
- """ - callback_uri = callback_uri or self.request.uri - args = self._openid_args(callback_uri, ax_attrs=ax_attrs, - oauth_scope=oauth_scope) - self.redirect(self._OPENID_ENDPOINT + "?" + urllib_parse.urlencode(args)) - callback() - - @_auth_return_future - def get_authenticated_user(self, callback): - """Fetches the authenticated user data upon redirect.""" - # Look to see if we are doing combined OpenID/OAuth - oauth_ns = "" - for name, values in self.request.arguments.items(): - if name.startswith("openid.ns.") and \ - values[-1] == b"http://specs.openid.net/extensions/oauth/1.0": - oauth_ns = name[10:] - break - token = self.get_argument("openid." + oauth_ns + ".request_token", "") - if token: - http = self.get_auth_http_client() - token = dict(key=token, secret="") - http.fetch(self._oauth_access_token_url(token), - functools.partial(self._on_access_token, callback)) - else: - chain_future(OpenIdMixin.get_authenticated_user(self), - callback) - - def _oauth_consumer_token(self): - self.require_setting("google_consumer_key", "Google OAuth") - self.require_setting("google_consumer_secret", "Google OAuth") - return dict( - key=self.settings["google_consumer_key"], - secret=self.settings["google_consumer_secret"]) - - def _oauth_get_user_future(self, access_token): - return OpenIdMixin.get_authenticated_user(self) - - class GoogleOAuth2Mixin(OAuth2Mixin): """Google authentication using OAuth2. @@ -1001,7 +798,9 @@ class GoogleOAuth2Mixin(OAuth2Mixin): def get_authenticated_user(self, redirect_uri, code, callback): """Handles the login for the Google user, returning a user object. - Example usage:: + Example usage: + + .. testcode:: class GoogleOAuth2LoginHandler(tornado.web.RequestHandler, tornado.auth.GoogleOAuth2Mixin): @@ -1019,6 +818,10 @@ class GoogleOAuth2Mixin(OAuth2Mixin): scope=['profile', 'email'], response_type='code', extra_params={'approval_prompt': 'auto'}) + + .. 
testoutput:: + :hide: + """ http = self.get_auth_http_client() body = urllib_parse.urlencode({ @@ -1051,217 +854,6 @@ class GoogleOAuth2Mixin(OAuth2Mixin): return httpclient.AsyncHTTPClient() -class FacebookMixin(object): - """Facebook Connect authentication. - - .. deprecated:: 1.1 - New applications should use `FacebookGraphMixin` - below instead of this class. This class does not support the - Future-based interface seen on other classes in this module. - - To authenticate with Facebook, register your application with - Facebook at http://www.facebook.com/developers/apps.php. Then - copy your API Key and Application Secret to the application settings - ``facebook_api_key`` and ``facebook_secret``. - - When your application is set up, you can use this mixin like this - to authenticate the user with Facebook:: - - class FacebookHandler(tornado.web.RequestHandler, - tornado.auth.FacebookMixin): - @tornado.web.asynchronous - def get(self): - if self.get_argument("session", None): - self.get_authenticated_user(self._on_auth) - return - yield self.authenticate_redirect() - - def _on_auth(self, user): - if not user: - raise tornado.web.HTTPError(500, "Facebook auth failed") - # Save the user using, e.g., set_secure_cookie() - - The user object returned by `get_authenticated_user` includes the - attributes ``facebook_uid`` and ``name`` in addition to session attributes - like ``session_key``. You should save the session key with the user; it is - required to make requests on behalf of the user later with - `facebook_request`. - """ - @return_future - def authenticate_redirect(self, callback_uri=None, cancel_uri=None, - extended_permissions=None, callback=None): - """Authenticates/installs this app for the current user. - - .. versionchanged:: 3.1 - Returns a `.Future` and takes an optional callback. These are - not strictly necessary as this method is synchronous, - but they are supplied for consistency with - `OAuthMixin.authorize_redirect`. 
- """ - self.require_setting("facebook_api_key", "Facebook Connect") - callback_uri = callback_uri or self.request.uri - args = { - "api_key": self.settings["facebook_api_key"], - "v": "1.0", - "fbconnect": "true", - "display": "page", - "next": urlparse.urljoin(self.request.full_url(), callback_uri), - "return_session": "true", - } - if cancel_uri: - args["cancel_url"] = urlparse.urljoin( - self.request.full_url(), cancel_uri) - if extended_permissions: - if isinstance(extended_permissions, (unicode_type, bytes)): - extended_permissions = [extended_permissions] - args["req_perms"] = ",".join(extended_permissions) - self.redirect("http://www.facebook.com/login.php?" + - urllib_parse.urlencode(args)) - callback() - - def authorize_redirect(self, extended_permissions, callback_uri=None, - cancel_uri=None, callback=None): - """Redirects to an authorization request for the given FB resource. - - The available resource names are listed at - http://wiki.developers.facebook.com/index.php/Extended_permission. - The most common resource types include: - - * publish_stream - * read_stream - * email - * sms - - extended_permissions can be a single permission name or a list of - names. To get the session secret and session key, call - get_authenticated_user() just as you would with - authenticate_redirect(). - - .. versionchanged:: 3.1 - Returns a `.Future` and takes an optional callback. These are - not strictly necessary as this method is synchronous, - but they are supplied for consistency with - `OAuthMixin.authorize_redirect`. - """ - return self.authenticate_redirect(callback_uri, cancel_uri, - extended_permissions, - callback=callback) - - def get_authenticated_user(self, callback): - """Fetches the authenticated Facebook user. - - The authenticated user includes the special Facebook attributes - 'session_key' and 'facebook_uid' in addition to the standard - user attributes like 'name'. 
- """ - self.require_setting("facebook_api_key", "Facebook Connect") - session = escape.json_decode(self.get_argument("session")) - self.facebook_request( - method="facebook.users.getInfo", - callback=functools.partial( - self._on_get_user_info, callback, session), - session_key=session["session_key"], - uids=session["uid"], - fields="uid,first_name,last_name,name,locale,pic_square," - "profile_url,username") - - def facebook_request(self, method, callback, **args): - """Makes a Facebook API REST request. - - We automatically include the Facebook API key and signature, but - it is the callers responsibility to include 'session_key' and any - other required arguments to the method. - - The available Facebook methods are documented here: - http://wiki.developers.facebook.com/index.php/API - - Here is an example for the stream.get() method:: - - class MainHandler(tornado.web.RequestHandler, - tornado.auth.FacebookMixin): - @tornado.web.authenticated - @tornado.web.asynchronous - def get(self): - self.facebook_request( - method="stream.get", - callback=self._on_stream, - session_key=self.current_user["session_key"]) - - def _on_stream(self, stream): - if stream is None: - # Not authorized to read the stream yet? - self.redirect(self.authorize_redirect("read_stream")) - return - self.render("stream.html", stream=stream) - - """ - self.require_setting("facebook_api_key", "Facebook Connect") - self.require_setting("facebook_secret", "Facebook Connect") - if not method.startswith("facebook."): - method = "facebook." + method - args["api_key"] = self.settings["facebook_api_key"] - args["v"] = "1.0" - args["method"] = method - args["call_id"] = str(long(time.time() * 1e6)) - args["format"] = "json" - args["sig"] = self._signature(args) - url = "http://api.facebook.com/restserver.php?" 
+ \ - urllib_parse.urlencode(args) - http = self.get_auth_http_client() - http.fetch(url, callback=functools.partial( - self._parse_response, callback)) - - def _on_get_user_info(self, callback, session, users): - if users is None: - callback(None) - return - callback({ - "name": users[0]["name"], - "first_name": users[0]["first_name"], - "last_name": users[0]["last_name"], - "uid": users[0]["uid"], - "locale": users[0]["locale"], - "pic_square": users[0]["pic_square"], - "profile_url": users[0]["profile_url"], - "username": users[0].get("username"), - "session_key": session["session_key"], - "session_expires": session.get("expires"), - }) - - def _parse_response(self, callback, response): - if response.error: - gen_log.warning("HTTP error from Facebook: %s", response.error) - callback(None) - return - try: - json = escape.json_decode(response.body) - except Exception: - gen_log.warning("Invalid JSON from Facebook: %r", response.body) - callback(None) - return - if isinstance(json, dict) and json.get("error_code"): - gen_log.warning("Facebook error: %d: %r", json["error_code"], - json.get("error_msg")) - callback(None) - return - callback(json) - - def _signature(self, args): - parts = ["%s=%s" % (n, args[n]) for n in sorted(args.keys())] - body = "".join(parts) + self.settings["facebook_secret"] - if isinstance(body, unicode_type): - body = body.encode("utf-8") - return hashlib.md5(body).hexdigest() - - def get_auth_http_client(self): - """Returns the `.AsyncHTTPClient` instance to be used for auth requests. - - May be overridden by subclasses to use an HTTP client other than - the default. - """ - return httpclient.AsyncHTTPClient() - - class FacebookGraphMixin(OAuth2Mixin): """Facebook authentication using the new Graph API and OAuth2.""" _OAUTH_ACCESS_TOKEN_URL = "https://graph.facebook.com/oauth/access_token?" 
@@ -1274,9 +866,12 @@ class FacebookGraphMixin(OAuth2Mixin): code, callback, extra_fields=None): """Handles the login for the Facebook user, returning a user object. - Example usage:: + Example usage: - class FacebookGraphLoginHandler(LoginHandler, tornado.auth.FacebookGraphMixin): + .. testcode:: + + class FacebookGraphLoginHandler(tornado.web.RequestHandler, + tornado.auth.FacebookGraphMixin): @tornado.gen.coroutine def get(self): if self.get_argument("code", False): @@ -1291,6 +886,10 @@ class FacebookGraphMixin(OAuth2Mixin): redirect_uri='/auth/facebookgraph/', client_id=self.settings["facebook_api_key"], extra_params={"scope": "read_stream,offline_access"}) + + .. testoutput:: + :hide: + """ http = self.get_auth_http_client() args = { @@ -1307,7 +906,7 @@ class FacebookGraphMixin(OAuth2Mixin): http.fetch(self._oauth_request_token_url(**args), functools.partial(self._on_access_token, redirect_uri, client_id, - client_secret, callback, fields)) + client_secret, callback, fields)) def _on_access_token(self, redirect_uri, client_id, client_secret, future, fields, response): @@ -1358,7 +957,9 @@ class FacebookGraphMixin(OAuth2Mixin): process includes an ``access_token`` attribute that can be used to make authenticated requests via this method. - Example usage:: + Example usage: + + .. testcode:: class MainHandler(tornado.web.RequestHandler, tornado.auth.FacebookGraphMixin): @@ -1376,6 +977,9 @@ class FacebookGraphMixin(OAuth2Mixin): return self.finish("Posted a message!") + .. testoutput:: + :hide: + The given path is relative to ``self._FACEBOOK_BASE_URL``, by default "https://graph.facebook.com". diff --git a/tornado/autoreload.py b/tornado/autoreload.py index 3982579a..a52ddde4 100644 --- a/tornado/autoreload.py +++ b/tornado/autoreload.py @@ -100,6 +100,14 @@ try: except ImportError: signal = None +# os.execv is broken on Windows and can't properly parse command line +# arguments and executable name if they contain whitespaces. subprocess +# fixes that behavior. 
+# This distinction is also important because when we use execv, we want to +# close the IOLoop and all its file descriptors, to guard against any +# file descriptors that were not set CLOEXEC. When execv is not available, +# we must not close the IOLoop because we want the process to exit cleanly. +_has_execv = sys.platform != 'win32' _watched_files = set() _reload_hooks = [] @@ -108,14 +116,19 @@ _io_loops = weakref.WeakKeyDictionary() def start(io_loop=None, check_time=500): - """Begins watching source files for changes using the given `.IOLoop`. """ + """Begins watching source files for changes. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. + """ io_loop = io_loop or ioloop.IOLoop.current() if io_loop in _io_loops: return _io_loops[io_loop] = True if len(_io_loops) > 1: gen_log.warning("tornado.autoreload started more than once in the same process") - add_reload_hook(functools.partial(io_loop.close, all_fds=True)) + if _has_execv: + add_reload_hook(functools.partial(io_loop.close, all_fds=True)) modify_times = {} callback = functools.partial(_reload_on_update, modify_times) scheduler = ioloop.PeriodicCallback(callback, check_time, io_loop=io_loop) @@ -162,7 +175,7 @@ def _reload_on_update(modify_times): # processes restarted themselves, they'd all restart and then # all call fork_processes again. return - for module in sys.modules.values(): + for module in list(sys.modules.values()): # Some modules play games with sys.modules (e.g. email/__init__.py # in the standard library), and occasionally this can cause strange # failures in getattr. Just ignore anything that's not an ordinary @@ -211,10 +224,7 @@ def _reload(): not os.environ.get("PYTHONPATH", "").startswith(path_prefix)): os.environ["PYTHONPATH"] = (path_prefix + os.environ.get("PYTHONPATH", "")) - if sys.platform == 'win32': - # os.execv is broken on Windows and can't properly parse command line - # arguments and executable name if they contain whitespaces. 
subprocess - # fixes that behavior. + if not _has_execv: subprocess.Popen([sys.executable] + sys.argv) sys.exit(0) else: @@ -234,7 +244,10 @@ def _reload(): # this error specifically. os.spawnv(os.P_NOWAIT, sys.executable, [sys.executable] + sys.argv) - sys.exit(0) + # At this point the IOLoop has been closed and finally + # blocks will experience errors if we allow the stack to + # unwind, so just exit uncleanly. + os._exit(0) _USAGE = """\ Usage: diff --git a/tornado/concurrent.py b/tornado/concurrent.py index 6bab5d2e..51ae239f 100644 --- a/tornado/concurrent.py +++ b/tornado/concurrent.py @@ -25,11 +25,13 @@ module. from __future__ import absolute_import, division, print_function, with_statement import functools +import platform +import traceback import sys +from tornado.log import app_log from tornado.stack_context import ExceptionStackContext, wrap from tornado.util import raise_exc_info, ArgReplacer -from tornado.log import app_log try: from concurrent import futures @@ -37,9 +39,90 @@ except ImportError: futures = None +# Can the garbage collector handle cycles that include __del__ methods? +# This is true in cpython beginning with version 3.4 (PEP 442). +_GC_CYCLE_FINALIZERS = (platform.python_implementation() == 'CPython' and + sys.version_info >= (3, 4)) + + class ReturnValueIgnoredError(Exception): pass +# This class and associated code in the future object is derived +# from the Trollius project, a backport of asyncio to Python 2.x - 3.x + + +class _TracebackLogger(object): + """Helper to log a traceback upon destruction if not cleared. + + This solves a nasty problem with Futures and Tasks that have an + exception set: if nobody asks for the exception, the exception is + never logged. This violates the Zen of Python: 'Errors should + never pass silently. Unless explicitly silenced.' 
+ + However, we don't want to log the exception as soon as + set_exception() is called: if the calling code is written + properly, it will get the exception and handle it properly. But + we *do* want to log it if result() or exception() was never called + -- otherwise developers waste a lot of time wondering why their + buggy code fails silently. + + An earlier attempt added a __del__() method to the Future class + itself, but this backfired because the presence of __del__() + prevents garbage collection from breaking cycles. A way out of + this catch-22 is to avoid having a __del__() method on the Future + class itself, but instead to have a reference to a helper object + with a __del__() method that logs the traceback, where we ensure + that the helper object doesn't participate in cycles, and only the + Future has a reference to it. + + The helper object is added when set_exception() is called. When + the Future is collected, and the helper is present, the helper + object is also collected, and its __del__() method will log the + traceback. When the Future's result() or exception() method is + called (and a helper object is present), it removes the helper + object, after calling its clear() method to prevent it from + logging. + + One downside is that we do a fair amount of work to extract the + traceback from the exception, even when it is never logged. It + would seem cheaper to just store the exception object, but that + references the traceback, which references stack frames, which may + reference the Future, which references the _TracebackLogger, and + then the _TracebackLogger would be included in a cycle, which is + what we're trying to avoid! As an optimization, we don't + immediately format the exception; we only do the work when + activate() is called, which call is delayed until after all the + Future's callbacks have run. 
Since usually a Future has at least + one callback (typically set by 'yield From') and usually that + callback extracts the exception, thereby removing the need to + format the exception. + + PS. I don't claim credit for this solution. I first heard of it + in a discussion about closing files when they are collected. + """ + + __slots__ = ('exc_info', 'formatted_tb') + + def __init__(self, exc_info): + self.exc_info = exc_info + self.formatted_tb = None + + def activate(self): + exc_info = self.exc_info + if exc_info is not None: + self.exc_info = None + self.formatted_tb = traceback.format_exception(*exc_info) + + def clear(self): + self.exc_info = None + self.formatted_tb = None + + def __del__(self): + if self.formatted_tb: + app_log.error('Future exception was never retrieved: %s', + ''.join(self.formatted_tb).rstrip()) + class Future(object): """Placeholder for an asynchronous result. @@ -68,12 +151,23 @@ class Future(object): if that package was available and fall back to the thread-unsafe implementation if it was not. + .. versionchanged:: 4.1 + If a `.Future` contains an error but that error is never observed + (by calling ``result()``, ``exception()``, or ``exc_info()``), + a stack trace will be logged when the `.Future` is garbage collected. + This normally indicates an error in the application, but in cases + where it results in undesired logging it may be necessary to + suppress the logging by ensuring that the exception is observed: + ``f.add_done_callback(lambda f: f.exception())``. 
""" def __init__(self): self._done = False self._result = None - self._exception = None self._exc_info = None + + self._log_traceback = False # Used for Python >= 3.4 + self._tb_logger = None # Used for Python <= 3.3 + self._callbacks = [] def cancel(self): @@ -100,16 +194,21 @@ class Future(object): """Returns True if the future has finished running.""" return self._done + def _clear_tb_log(self): + self._log_traceback = False + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + def result(self, timeout=None): """If the operation succeeded, return its result. If it failed, re-raise its exception. """ + self._clear_tb_log() if self._result is not None: return self._result if self._exc_info is not None: raise_exc_info(self._exc_info) - elif self._exception is not None: - raise self._exception self._check_done() return self._result @@ -117,8 +216,9 @@ class Future(object): """If the operation raised an exception, return the `Exception` object. Otherwise returns None. """ - if self._exception is not None: - return self._exception + self._clear_tb_log() + if self._exc_info is not None: + return self._exc_info[1] else: self._check_done() return None @@ -147,14 +247,17 @@ class Future(object): def set_exception(self, exception): """Sets the exception of a ``Future.``""" - self._exception = exception - self._set_done() + self.set_exc_info( + (exception.__class__, + exception, + getattr(exception, '__traceback__', None))) def exc_info(self): """Returns a tuple in the same format as `sys.exc_info` or None. .. versionadded:: 4.0 """ + self._clear_tb_log() return self._exc_info def set_exc_info(self, exc_info): @@ -165,7 +268,18 @@ class Future(object): .. 
versionadded:: 4.0 """ self._exc_info = exc_info - self.set_exception(exc_info[1]) + self._log_traceback = True + if not _GC_CYCLE_FINALIZERS: + self._tb_logger = _TracebackLogger(exc_info) + + try: + self._set_done() + finally: + # Activate the logger after all callbacks have had a + # chance to call result() or exception(). + if self._log_traceback and self._tb_logger is not None: + self._tb_logger.activate() + self._exc_info = exc_info def _check_done(self): if not self._done: @@ -177,10 +291,25 @@ class Future(object): try: cb(self) except Exception: - app_log.exception('exception calling callback %r for %r', + app_log.exception('Exception in callback %r for %r', cb, self) self._callbacks = None + # On Python 3.3 or older, objects with a destructor part of a reference + # cycle are never destroyed. It's no longer the case on Python 3.4 thanks to + # the PEP 442. + if _GC_CYCLE_FINALIZERS: + def __del__(self): + if not self._log_traceback: + # set_exception() was not called, or result() or exception() + # has consumed the exception + return + + tb = traceback.format_exception(*self._exc_info) + + app_log.error('Future %r exception was never retrieved: %s', + self, ''.join(tb).rstrip()) + TracebackFuture = Future if futures is None: @@ -208,24 +337,42 @@ class DummyExecutor(object): dummy_executor = DummyExecutor() -def run_on_executor(fn): +def run_on_executor(*args, **kwargs): """Decorator to run a synchronous method asynchronously on an executor. The decorated method may be called with a ``callback`` keyword argument and returns a future. - This decorator should be used only on methods of objects with attributes - ``executor`` and ``io_loop``. + The `.IOLoop` and executor to be used are determined by the ``io_loop`` + and ``executor`` attributes of ``self``. To use different attributes, + pass keyword arguments to the decorator:: + + @run_on_executor(executor='_thread_pool') + def foo(self): + pass + + .. 
versionchanged:: 4.2 + Added keyword arguments to use alternative attributes. """ - @functools.wraps(fn) - def wrapper(self, *args, **kwargs): - callback = kwargs.pop("callback", None) - future = self.executor.submit(fn, self, *args, **kwargs) - if callback: - self.io_loop.add_future(future, - lambda future: callback(future.result())) - return future - return wrapper + def run_on_executor_decorator(fn): + executor = kwargs.get("executor", "executor") + io_loop = kwargs.get("io_loop", "io_loop") + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + callback = kwargs.pop("callback", None) + future = getattr(self, executor).submit(fn, self, *args, **kwargs) + if callback: + getattr(self, io_loop).add_future( + future, lambda future: callback(future.result())) + return future + return wrapper + if args and kwargs: + raise ValueError("cannot combine positional and keyword args") + if len(args) == 1: + return run_on_executor_decorator(args[0]) + elif len(args) != 0: + raise ValueError("expected 1 argument, got %d", len(args)) + return run_on_executor_decorator _NO_RESULT = object() @@ -250,7 +397,9 @@ def return_future(f): wait for the function to complete (perhaps by yielding it in a `.gen.engine` function, or passing it to `.IOLoop.add_future`). - Usage:: + Usage: + + .. testcode:: @return_future def future_func(arg1, arg2, callback): @@ -262,6 +411,8 @@ def return_future(f): yield future_func(arg1, arg2) callback() + .. + Note that ``@return_future`` and ``@gen.engine`` can be applied to the same function, provided ``@return_future`` appears first. However, consider using ``@gen.coroutine`` instead of this combination. @@ -293,7 +444,7 @@ def return_future(f): # If the initial synchronous part of f() raised an exception, # go ahead and raise it to the caller directly without waiting # for them to inspect the Future. - raise_exc_info(exc_info) + future.result() # If the caller passed in a callback, schedule it to be called # when the future resolves. 
It is important that this happens diff --git a/tornado/curl_httpclient.py b/tornado/curl_httpclient.py index 68047cc9..ae6f114a 100644 --- a/tornado/curl_httpclient.py +++ b/tornado/curl_httpclient.py @@ -28,12 +28,13 @@ from io import BytesIO from tornado import httputil from tornado import ioloop -from tornado.log import gen_log from tornado import stack_context from tornado.escape import utf8, native_str from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main +curl_log = logging.getLogger('tornado.curl_httpclient') + class CurlAsyncHTTPClient(AsyncHTTPClient): def initialize(self, io_loop, max_clients=10, defaults=None): @@ -207,9 +208,25 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): "callback": callback, "curl_start_time": time.time(), } - self._curl_setup_request(curl, request, curl.info["buffer"], - curl.info["headers"]) - self._multi.add_handle(curl) + try: + self._curl_setup_request( + curl, request, curl.info["buffer"], + curl.info["headers"]) + except Exception as e: + # If there was an error in setup, pass it on + # to the callback. Note that allowing the + # error to escape here will appear to work + # most of the time since we are still in the + # caller's original stack frame, but when + # _process_queue() is called from + # _finish_pending_requests the exceptions have + # nowhere to go. 
+ callback(HTTPResponse( + request=request, + code=599, + error=e)) + else: + self._multi.add_handle(curl) if not started: break @@ -257,7 +274,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): def _curl_create(self): curl = pycurl.Curl() - if gen_log.isEnabledFor(logging.DEBUG): + if curl_log.isEnabledFor(logging.DEBUG): curl.setopt(pycurl.VERBOSE, 1) curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug) return curl @@ -286,10 +303,10 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): curl.setopt(pycurl.HEADERFUNCTION, functools.partial(self._curl_header_callback, - headers, request.header_callback)) + headers, request.header_callback)) if request.streaming_callback: - write_function = lambda chunk: self.io_loop.add_callback( - request.streaming_callback, chunk) + def write_function(chunk): + self.io_loop.add_callback(request.streaming_callback, chunk) else: write_function = buffer.write if bytes is str: # py2 @@ -381,6 +398,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): % request.method) request_buffer = BytesIO(utf8(request.body)) + def ioctl(cmd): if cmd == curl.IOCMD_RESTARTREAD: request_buffer.seek(0) @@ -403,11 +421,11 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): raise ValueError("Unsupported auth_mode %s" % request.auth_mode) curl.setopt(pycurl.USERPWD, native_str(userpwd)) - gen_log.debug("%s %s (username: %r)", request.method, request.url, - request.auth_username) + curl_log.debug("%s %s (username: %r)", request.method, request.url, + request.auth_username) else: curl.unsetopt(pycurl.USERPWD) - gen_log.debug("%s %s", request.method, request.url) + curl_log.debug("%s %s", request.method, request.url) if request.client_cert is not None: curl.setopt(pycurl.SSLCERT, request.client_cert) @@ -415,6 +433,9 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): if request.client_key is not None: curl.setopt(pycurl.SSLKEY, request.client_key) + if request.ssl_options is not None: + raise ValueError("ssl_options not supported in curl_httpclient") + if threading.activeCount() 
> 1: # libcurl/pycurl is not thread-safe by default. When multiple threads # are used, signals should be disabled. This has the side effect @@ -448,12 +469,12 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): def _curl_debug(self, debug_type, debug_msg): debug_types = ('I', '<', '>', '<', '>') if debug_type == 0: - gen_log.debug('%s', debug_msg.strip()) + curl_log.debug('%s', debug_msg.strip()) elif debug_type in (1, 2): for line in debug_msg.splitlines(): - gen_log.debug('%s %s', debug_types[debug_type], line) + curl_log.debug('%s %s', debug_types[debug_type], line) elif debug_type == 4: - gen_log.debug('%s %r', debug_types[debug_type], debug_msg) + curl_log.debug('%s %r', debug_types[debug_type], debug_msg) class CurlError(HTTPError): diff --git a/tornado/escape.py b/tornado/escape.py index 24be2264..2852cf51 100644 --- a/tornado/escape.py +++ b/tornado/escape.py @@ -82,7 +82,7 @@ def json_encode(value): # JSON permits but does not require forward slashes to be escaped. # This is useful when json data is emitted in a tags from prematurely terminating - # the javscript. Some json libraries do this escaping by default, + # the javascript. Some json libraries do this escaping by default, # although python's standard library does not, so we do it here. # http://stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped return json.dumps(value).replace("`_ package on older +versions), additional types of objects may be yielded. Tornado includes +support for ``asyncio.Future`` and Twisted's ``Deferred`` class when +``tornado.platform.asyncio`` and ``tornado.platform.twisted`` are imported. +See the `convert_yielded` function to extend this mechanism. + .. versionchanged:: 3.2 Dict support added. + +.. versionchanged:: 4.1 + Support added for yielding ``asyncio`` Futures and Twisted Deferreds + via ``singledispatch``. 
+ """ from __future__ import absolute_import, division, print_function, with_statement @@ -53,10 +81,21 @@ import functools import itertools import sys import types +import weakref from tornado.concurrent import Future, TracebackFuture, is_future, chain_future from tornado.ioloop import IOLoop +from tornado.log import app_log from tornado import stack_context +from tornado.util import raise_exc_info + +try: + from functools import singledispatch # py34+ +except ImportError as e: + try: + from singledispatch import singledispatch # backport + except ImportError: + singledispatch = None class KeyReuseError(Exception): @@ -101,9 +140,11 @@ def engine(func): which use ``self.finish()`` in place of a callback argument. """ func = _make_coroutine_wrapper(func, replace_callback=False) + @functools.wraps(func) def wrapper(*args, **kwargs): future = func(*args, **kwargs) + def final_callback(future): if future.result() is not None: raise ReturnValueIgnoredError( @@ -241,6 +282,113 @@ class Return(Exception): self.value = value +class WaitIterator(object): + """Provides an iterator to yield the results of futures as they finish. + + Yielding a set of futures like this: + + ``results = yield [future1, future2]`` + + pauses the coroutine until both ``future1`` and ``future2`` + return, and then restarts the coroutine with the results of both + futures. If either future is an exception, the expression will + raise that exception and all the results will be lost. 
+ + If you need to get the result of each future as soon as possible, + or if you need the result of some futures even if others produce + errors, you can use ``WaitIterator``:: + + wait_iterator = gen.WaitIterator(future1, future2) + while not wait_iterator.done(): + try: + result = yield wait_iterator.next() + except Exception as e: + print("Error {} from {}".format(e, wait_iterator.current_future)) + else: + print("Result {} received from {} at {}".format( + result, wait_iterator.current_future, + wait_iterator.current_index)) + + Because results are returned as soon as they are available the + output from the iterator *will not be in the same order as the + input arguments*. If you need to know which future produced the + current result, you can use the attributes + ``WaitIterator.current_future``, or ``WaitIterator.current_index`` + to get the index of the future from the input list. (if keyword + arguments were used in the construction of the `WaitIterator`, + ``current_index`` will use the corresponding keyword). + + .. versionadded:: 4.1 + """ + def __init__(self, *args, **kwargs): + if args and kwargs: + raise ValueError( + "You must provide args or kwargs, not both") + + if kwargs: + self._unfinished = dict((f, k) for (k, f) in kwargs.items()) + futures = list(kwargs.values()) + else: + self._unfinished = dict((f, i) for (i, f) in enumerate(args)) + futures = args + + self._finished = collections.deque() + self.current_index = self.current_future = None + self._running_future = None + + # Use a weak reference to self to avoid cycles that may delay + # garbage collection. + self_ref = weakref.ref(self) + for future in futures: + future.add_done_callback(functools.partial( + self._done_callback, self_ref)) + + def done(self): + """Returns True if this iterator has no more results.""" + if self._finished or self._unfinished: + return False + # Clear the 'current' values when iteration is done. 
+ self.current_index = self.current_future = None + return True + + def next(self): + """Returns a `.Future` that will yield the next available result. + + Note that this `.Future` will not be the same object as any of + the inputs. + """ + self._running_future = TracebackFuture() + # As long as there is an active _running_future, we must + # ensure that the WaitIterator is not GC'd (due to the + # use of weak references in __init__). Add a callback that + # references self so there is a hard reference that will be + # cleared automatically when this Future finishes. + self._running_future.add_done_callback(lambda f: self) + + if self._finished: + self._return_result(self._finished.popleft()) + + return self._running_future + + @staticmethod + def _done_callback(self_ref, done): + self = self_ref() + if self is not None: + if self._running_future and not self._running_future.done(): + self._return_result(done) + else: + self._finished.append(done) + + def _return_result(self, done): + """Called to set the returned future's state to that of the future + we yielded, and set the current future for the iterator. + """ + chain_future(done, self._running_future) + + self.current_future = done + self.current_index = self._unfinished.pop(done) + + class YieldPoint(object): """Base class for objects that may be yielded from the generator. @@ -355,11 +503,13 @@ def Task(func, *args, **kwargs): yielded. """ future = Future() + def handle_exception(typ, value, tb): if future.done(): return False future.set_exc_info((typ, value, tb)) return True + def set_result(result): if future.done(): return @@ -371,6 +521,11 @@ def Task(func, *args, **kwargs): class YieldFuture(YieldPoint): def __init__(self, future, io_loop=None): + """Adapts a `.Future` to the `YieldPoint` interface. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated.
+ """ self.future = future self.io_loop = io_loop or IOLoop.current() @@ -382,7 +537,7 @@ class YieldFuture(YieldPoint): self.io_loop.add_future(self.future, runner.result_callback(self.key)) else: self.runner = None - self.result = self.future.result() + self.result_fn = self.future.result def is_ready(self): if self.runner is not None: @@ -394,7 +549,7 @@ class YieldFuture(YieldPoint): if self.runner is not None: return self.runner.pop_result(self.key).result() else: - return self.result + return self.result_fn() class Multi(YieldPoint): @@ -408,8 +563,18 @@ class Multi(YieldPoint): Instead of a list, the argument may also be a dictionary whose values are Futures, in which case a parallel dictionary is returned mapping the same keys to their results. + + It is not normally necessary to call this class directly, as it + will be created automatically as needed. However, calling it directly + allows you to use the ``quiet_exceptions`` argument to control + the logging of multiple exceptions. + + .. versionchanged:: 4.2 + If multiple ``YieldPoints`` fail, any exceptions after the first + (which is raised) will be logged. Added the ``quiet_exceptions`` + argument to suppress this logging for selected exception types. 
""" - def __init__(self, children): + def __init__(self, children, quiet_exceptions=()): self.keys = None if isinstance(children, dict): self.keys = list(children.keys()) @@ -421,6 +586,7 @@ class Multi(YieldPoint): self.children.append(i) assert all(isinstance(i, YieldPoint) for i in self.children) self.unfinished_children = set(self.children) + self.quiet_exceptions = quiet_exceptions def start(self, runner): for i in self.children: @@ -433,14 +599,27 @@ class Multi(YieldPoint): return not self.unfinished_children def get_result(self): - result = (i.get_result() for i in self.children) + result_list = [] + exc_info = None + for f in self.children: + try: + result_list.append(f.get_result()) + except Exception as e: + if exc_info is None: + exc_info = sys.exc_info() + else: + if not isinstance(e, self.quiet_exceptions): + app_log.error("Multiple exceptions in yield list", + exc_info=True) + if exc_info is not None: + raise_exc_info(exc_info) if self.keys is not None: - return dict(zip(self.keys, result)) + return dict(zip(self.keys, result_list)) else: - return list(result) + return list(result_list) -def multi_future(children): +def multi_future(children, quiet_exceptions=()): """Wait for multiple asynchronous futures in parallel. Takes a list of ``Futures`` (but *not* other ``YieldPoints``) and returns @@ -453,12 +632,21 @@ def multi_future(children): Futures, in which case a parallel dictionary is returned mapping the same keys to their results. - It is not necessary to call `multi_future` explcitly, since the engine will - do so automatically when the generator yields a list of `Futures`. - This function is faster than the `Multi` `YieldPoint` because it does not - require the creation of a stack context. + It is not normally necessary to call `multi_future` explcitly, + since the engine will do so automatically when the generator + yields a list of ``Futures``. 
However, calling it directly + allows you to use the ``quiet_exceptions`` argument to control + the logging of multiple exceptions. + + This function is faster than the `Multi` `YieldPoint` because it + does not require the creation of a stack context. .. versionadded:: 4.0 + + .. versionchanged:: 4.2 + If multiple ``Futures`` fail, any exceptions after the first (which is + raised) will be logged. Added the ``quiet_exceptions`` + argument to suppress this logging for selected exception types. """ if isinstance(children, dict): keys = list(children.keys()) @@ -471,20 +659,32 @@ def multi_future(children): future = Future() if not children: future.set_result({} if keys is not None else []) + def callback(f): unfinished_children.remove(f) if not unfinished_children: - try: - result_list = [i.result() for i in children] - except Exception: - future.set_exc_info(sys.exc_info()) - else: + result_list = [] + for f in children: + try: + result_list.append(f.result()) + except Exception as e: + if future.done(): + if not isinstance(e, quiet_exceptions): + app_log.error("Multiple exceptions in yield list", + exc_info=True) + else: + future.set_exc_info(sys.exc_info()) + if not future.done(): if keys is not None: future.set_result(dict(zip(keys, result_list))) else: future.set_result(result_list) + + listening = set() for f in children: - f.add_done_callback(callback) + if f not in listening: + listening.add(f) + f.add_done_callback(callback) return future @@ -504,7 +704,7 @@ def maybe_future(x): return fut -def with_timeout(timeout, future, io_loop=None): +def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()): """Wraps a `.Future` in a timeout. Raises `TimeoutError` if the input future does not complete before @@ -512,9 +712,17 @@ def with_timeout(timeout, future, io_loop=None): `.IOLoop.add_timeout` (i.e. 
a `datetime.timedelta` or an absolute time relative to `.IOLoop.time`) + If the wrapped `.Future` fails after it has timed out, the exception + will be logged unless it is of a type contained in ``quiet_exceptions`` + (which may be an exception type or a sequence of types). + Currently only supports Futures, not other `YieldPoint` classes. .. versionadded:: 4.0 + + .. versionchanged:: 4.1 + Added the ``quiet_exceptions`` argument and the logging of unhandled + exceptions. """ # TODO: allow yield points in addition to futures? # Tricky to do with stack_context semantics. @@ -528,9 +736,21 @@ def with_timeout(timeout, future, io_loop=None): chain_future(future, result) if io_loop is None: io_loop = IOLoop.current() + + def error_callback(future): + try: + future.result() + except Exception as e: + if not isinstance(e, quiet_exceptions): + app_log.error("Exception in Future %r after timeout", + future, exc_info=True) + + def timeout_callback(): + result.set_exception(TimeoutError("Timeout")) + # In case the wrapped future goes on to fail, log it. + future.add_done_callback(error_callback) timeout_handle = io_loop.add_timeout( - timeout, - lambda: result.set_exception(TimeoutError("Timeout"))) + timeout, timeout_callback) if isinstance(future, Future): # We know this future will resolve on the IOLoop, so we don't # need the extra thread-safety of IOLoop.add_future (and we also @@ -545,6 +765,25 @@ def with_timeout(timeout, future, io_loop=None): return result +def sleep(duration): + """Return a `.Future` that resolves after the given number of seconds. + + When used with ``yield`` in a coroutine, this is a non-blocking + analogue to `time.sleep` (which should not be used in coroutines + because it is blocking):: + + yield gen.sleep(0.5) + + Note that calling this function on its own does nothing; you must + wait on the `.Future` it returns (usually by yielding it). + + .. 
versionadded:: 4.1 + """ + f = Future() + IOLoop.current().call_later(duration, lambda: f.set_result(None)) + return f + + _null_future = Future() _null_future.set_result(None) @@ -638,13 +877,20 @@ class Runner(object): self.future = None try: orig_stack_contexts = stack_context._state.contexts + exc_info = None + try: value = future.result() except Exception: self.had_exception = True - yielded = self.gen.throw(*sys.exc_info()) + exc_info = sys.exc_info() + + if exc_info is not None: + yielded = self.gen.throw(*exc_info) + exc_info = None else: yielded = self.gen.send(value) + if stack_context._state.contexts is not orig_stack_contexts: self.gen.throw( stack_context.StackContextInconsistentError( @@ -678,19 +924,20 @@ class Runner(object): self.running = False def handle_yield(self, yielded): - if isinstance(yielded, list): - if all(is_future(f) for f in yielded): - yielded = multi_future(yielded) - else: - yielded = Multi(yielded) - elif isinstance(yielded, dict): - if all(is_future(f) for f in yielded.values()): - yielded = multi_future(yielded) - else: - yielded = Multi(yielded) + # Lists containing YieldPoints require stack contexts; + # other lists are handled via multi_future in convert_yielded. + if (isinstance(yielded, list) and + any(isinstance(f, YieldPoint) for f in yielded)): + yielded = Multi(yielded) + elif (isinstance(yielded, dict) and + any(isinstance(f, YieldPoint) for f in yielded.values())): + yielded = Multi(yielded) if isinstance(yielded, YieldPoint): + # YieldPoints are too closely coupled to the Runner to go + # through the generic convert_yielded mechanism. self.future = TracebackFuture() + def start_yield_point(): try: yielded.start(self) @@ -702,12 +949,14 @@ class Runner(object): except Exception: self.future = TracebackFuture() self.future.set_exc_info(sys.exc_info()) + if self.stack_context_deactivate is None: # Start a stack context if this is the first # YieldPoint we've seen. 
with stack_context.ExceptionStackContext( self.handle_exception) as deactivate: self.stack_context_deactivate = deactivate + def cb(): start_yield_point() self.run() @@ -715,16 +964,17 @@ class Runner(object): return False else: start_yield_point() - elif is_future(yielded): - self.future = yielded - if not self.future.done() or self.future is moment: - self.io_loop.add_future( - self.future, lambda f: self.run()) - return False else: - self.future = TracebackFuture() - self.future.set_exception(BadYieldError( - "yielded unknown object %r" % (yielded,))) + try: + self.future = convert_yielded(yielded) + except BadYieldError: + self.future = TracebackFuture() + self.future.set_exc_info(sys.exc_info()) + + if not self.future.done() or self.future is moment: + self.io_loop.add_future( + self.future, lambda f: self.run()) + return False return True def result_callback(self, key): @@ -763,3 +1013,30 @@ def _argument_adapter(callback): else: callback(None) return wrapper + + +def convert_yielded(yielded): + """Convert a yielded object into a `.Future`. + + The default implementation accepts lists, dictionaries, and Futures. + + If the `~functools.singledispatch` library is available, this function + may be extended to support additional types. For example:: + + @convert_yielded.register(asyncio.Future) + def _(asyncio_future): + return tornado.platform.asyncio.to_tornado_future(asyncio_future) + + .. versionadded:: 4.1 + """ + # Lists and dicts containing YieldPoints were handled separately + # via Multi(). 
+ if isinstance(yielded, (list, dict)): + return multi_future(yielded) + elif is_future(yielded): + return yielded + else: + raise BadYieldError("yielded unknown object %r" % (yielded,)) + +if singledispatch is not None: + convert_yielded = singledispatch(convert_yielded) diff --git a/tornado/http1connection.py b/tornado/http1connection.py index 8a5f46c4..6226ef7a 100644 --- a/tornado/http1connection.py +++ b/tornado/http1connection.py @@ -37,6 +37,7 @@ class _QuietException(Exception): def __init__(self): pass + class _ExceptionLoggingContext(object): """Used with the ``with`` statement when calling delegate methods to log any exceptions with the given logger. Any exceptions caught are @@ -53,6 +54,7 @@ class _ExceptionLoggingContext(object): self.logger.error("Uncaught exception", exc_info=(typ, value, tb)) raise _QuietException + class HTTP1ConnectionParameters(object): """Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`. """ @@ -162,7 +164,8 @@ class HTTP1Connection(httputil.HTTPConnection): header_data = yield gen.with_timeout( self.stream.io_loop.time() + self.params.header_timeout, header_future, - io_loop=self.stream.io_loop) + io_loop=self.stream.io_loop, + quiet_exceptions=iostream.StreamClosedError) except gen.TimeoutError: self.close() raise gen.Return(False) @@ -201,7 +204,7 @@ class HTTP1Connection(httputil.HTTPConnection): # 1xx responses should never indicate the presence of # a body. 
if ('Content-Length' in headers or - 'Transfer-Encoding' in headers): + 'Transfer-Encoding' in headers): raise httputil.HTTPInputError( "Response code %d cannot have body" % code) # TODO: client delegates will get headers_received twice @@ -221,7 +224,8 @@ class HTTP1Connection(httputil.HTTPConnection): try: yield gen.with_timeout( self.stream.io_loop.time() + self._body_timeout, - body_future, self.stream.io_loop) + body_future, self.stream.io_loop, + quiet_exceptions=iostream.StreamClosedError) except gen.TimeoutError: gen_log.info("Timeout reading body from %s", self.context) @@ -326,8 +330,10 @@ class HTTP1Connection(httputil.HTTPConnection): def write_headers(self, start_line, headers, chunk=None, callback=None): """Implements `.HTTPConnection.write_headers`.""" + lines = [] if self.is_client: self._request_start_line = start_line + lines.append(utf8('%s %s HTTP/1.1' % (start_line[0], start_line[1]))) # Client requests with a non-empty body must have either a # Content-Length or a Transfer-Encoding. 
self._chunking_output = ( @@ -336,6 +342,7 @@ class HTTP1Connection(httputil.HTTPConnection): 'Transfer-Encoding' not in headers) else: self._response_start_line = start_line + lines.append(utf8('HTTP/1.1 %s %s' % (start_line[1], start_line[2]))) self._chunking_output = ( # TODO: should this use # self._request_start_line.version or @@ -365,7 +372,6 @@ class HTTP1Connection(httputil.HTTPConnection): self._expected_content_remaining = int(headers['Content-Length']) else: self._expected_content_remaining = None - lines = [utf8("%s %s %s" % start_line)] lines.extend([utf8(n) + b": " + utf8(v) for n, v in headers.get_all()]) for line in lines: if b'\n' in line: @@ -374,6 +380,7 @@ class HTTP1Connection(httputil.HTTPConnection): if self.stream.closed(): future = self._write_future = Future() future.set_exception(iostream.StreamClosedError()) + future.exception() else: if callback is not None: self._write_callback = stack_context.wrap(callback) @@ -412,6 +419,7 @@ class HTTP1Connection(httputil.HTTPConnection): if self.stream.closed(): future = self._write_future = Future() self._write_future.set_exception(iostream.StreamClosedError()) + self._write_future.exception() else: if callback is not None: self._write_callback = stack_context.wrap(callback) @@ -451,6 +459,9 @@ class HTTP1Connection(httputil.HTTPConnection): self._pending_write.add_done_callback(self._finish_request) def _on_write_complete(self, future): + exc = future.exception() + if exc is not None and not isinstance(exc, iostream.StreamClosedError): + future.result() if self._write_callback is not None: callback = self._write_callback self._write_callback = None @@ -491,8 +502,9 @@ class HTTP1Connection(httputil.HTTPConnection): # we SHOULD ignore at least one empty line before the request. # http://tools.ietf.org/html/rfc7230#section-3.5 data = native_str(data.decode('latin1')).lstrip("\r\n") - eol = data.find("\r\n") - start_line = data[:eol] + # RFC 7230 section 3.5 allows for both CRLF and bare LF.
+ eol = data.find("\n") + start_line = data[:eol].rstrip("\r") try: headers = httputil.HTTPHeaders.parse(data[eol:]) except ValueError: @@ -686,9 +698,8 @@ class HTTP1ServerConnection(object): # This exception was already logged. conn.close() return - except Exception as e: - if 1 != e.errno: - gen_log.error("Uncaught exception", exc_info=True) + except Exception: + gen_log.error("Uncaught exception", exc_info=True) conn.close() return if not ret: diff --git a/tornado/httpclient.py b/tornado/httpclient.py index df429517..c2e68623 100644 --- a/tornado/httpclient.py +++ b/tornado/httpclient.py @@ -72,7 +72,7 @@ class HTTPClient(object): http_client.close() """ def __init__(self, async_client_class=None, **kwargs): - self._io_loop = IOLoop() + self._io_loop = IOLoop(make_current=False) if async_client_class is None: async_client_class = AsyncHTTPClient self._async_client = async_client_class(self._io_loop, **kwargs) @@ -95,11 +95,11 @@ class HTTPClient(object): If it is a string, we construct an `HTTPRequest` using any additional kwargs: ``HTTPRequest(request, **kwargs)`` - If an error occurs during the fetch, we raise an `HTTPError`. + If an error occurs during the fetch, we raise an `HTTPError` unless + the ``raise_error`` keyword argument is set to False. """ response = self._io_loop.run_sync(functools.partial( self._async_client.fetch, request, **kwargs)) - response.rethrow() return response @@ -136,6 +136,9 @@ class AsyncHTTPClient(Configurable): # or with force_instance: client = AsyncHTTPClient(force_instance=True, defaults=dict(user_agent="MyUserAgent")) + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. 
""" @classmethod def configurable_base(cls): @@ -200,7 +203,7 @@ class AsyncHTTPClient(Configurable): raise RuntimeError("inconsistent AsyncHTTPClient cache") del self._instance_cache[self.io_loop] - def fetch(self, request, callback=None, **kwargs): + def fetch(self, request, callback=None, raise_error=True, **kwargs): """Executes a request, asynchronously returning an `HTTPResponse`. The request may be either a string URL or an `HTTPRequest` object. @@ -208,8 +211,10 @@ class AsyncHTTPClient(Configurable): kwargs: ``HTTPRequest(request, **kwargs)`` This method returns a `.Future` whose result is an - `HTTPResponse`. The ``Future`` will raise an `HTTPError` if - the request returned a non-200 response code. + `HTTPResponse`. By default, the ``Future`` will raise an `HTTPError` + if the request returned a non-200 response code. Instead, if + ``raise_error`` is set to False, the response will always be + returned regardless of the response code. If a ``callback`` is given, it will be invoked with the `HTTPResponse`. In the callback interface, `HTTPError` is not automatically raised. @@ -243,7 +248,7 @@ class AsyncHTTPClient(Configurable): future.add_done_callback(handle_future) def handle_response(response): - if response.error: + if raise_error and response.error: future.set_exception(response.error) else: future.set_result(response) @@ -304,7 +309,8 @@ class HTTPRequest(object): validate_cert=None, ca_certs=None, allow_ipv6=None, client_key=None, client_cert=None, body_producer=None, - expect_100_continue=False, decompress_response=None): + expect_100_continue=False, decompress_response=None, + ssl_options=None): r"""All parameters except ``url`` are optional. :arg string url: URL to fetch @@ -374,12 +380,15 @@ class HTTPRequest(object): :arg string ca_certs: filename of CA certificates in PEM format, or None to use defaults. See note below when used with ``curl_httpclient``. - :arg bool allow_ipv6: Use IPv6 when available? 
Default is false in - ``simple_httpclient`` and true in ``curl_httpclient`` :arg string client_key: Filename for client SSL key, if any. See note below when used with ``curl_httpclient``. :arg string client_cert: Filename for client SSL certificate, if any. See note below when used with ``curl_httpclient``. + :arg ssl.SSLContext ssl_options: `ssl.SSLContext` object for use in + ``simple_httpclient`` (unsupported by ``curl_httpclient``). + Overrides ``validate_cert``, ``ca_certs``, ``client_key``, + and ``client_cert``. + :arg bool allow_ipv6: Use IPv6 when available? Default is true. :arg bool expect_100_continue: If true, send the ``Expect: 100-continue`` header and wait for a continue response before sending the request body. Only supported with @@ -402,6 +411,9 @@ class HTTPRequest(object): .. versionadded:: 4.0 The ``body_producer`` and ``expect_100_continue`` arguments. + + .. versionadded:: 4.2 + The ``ssl_options`` argument. """ # Note that some of these attributes go through property setters # defined below. @@ -439,6 +451,7 @@ class HTTPRequest(object): self.allow_ipv6 = allow_ipv6 self.client_key = client_key self.client_cert = client_cert + self.ssl_options = ssl_options self.expect_100_continue = expect_100_continue self.start_time = time.time() diff --git a/tornado/httpserver.py b/tornado/httpserver.py index 05d0e186..2dd04dd7 100644 --- a/tornado/httpserver.py +++ b/tornado/httpserver.py @@ -37,35 +37,17 @@ from tornado import httputil from tornado import iostream from tornado import netutil from tornado.tcpserver import TCPServer +from tornado.util import Configurable -class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): +class HTTPServer(TCPServer, Configurable, + httputil.HTTPServerConnectionDelegate): r"""A non-blocking, single-threaded HTTP server. - A server is defined by either a request callback that takes a - `.HTTPServerRequest` as an argument or a `.HTTPServerConnectionDelegate` - instance. 
- - A simple example server that echoes back the URI you requested:: - - import tornado.httpserver - import tornado.ioloop - from tornado import httputil - - def handle_request(request): - message = "You requested %s\n" % request.uri - request.connection.write_headers( - httputil.ResponseStartLine('HTTP/1.1', 200, 'OK'), - {"Content-Length": str(len(message))}) - request.connection.write(message) - request.connection.finish() - - http_server = tornado.httpserver.HTTPServer(handle_request) - http_server.listen(8888) - tornado.ioloop.IOLoop.instance().start() - - Applications should use the methods of `.HTTPConnection` to write - their response. + A server is defined by a subclass of `.HTTPServerConnectionDelegate`, + or, for backwards compatibility, a callback that takes an + `.HTTPServerRequest` as an argument. The delegate is usually a + `tornado.web.Application`. `HTTPServer` supports keep-alive connections by default (automatically for HTTP/1.1, or for HTTP/1.0 when the client @@ -80,15 +62,15 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): if Tornado is run behind an SSL-decoding proxy that does not set one of the supported ``xheaders``. - To make this server serve SSL traffic, send the ``ssl_options`` dictionary - argument with the arguments required for the `ssl.wrap_socket` method, - including ``certfile`` and ``keyfile``. (In Python 3.2+ you can pass - an `ssl.SSLContext` object instead of a dict):: + To make this server serve SSL traffic, send the ``ssl_options`` keyword + argument with an `ssl.SSLContext` object. 
For compatibility with older + versions of Python ``ssl_options`` may also be a dictionary of keyword + arguments for the `ssl.wrap_socket` method:: - HTTPServer(applicaton, ssl_options={ - "certfile": os.path.join(data_dir, "mydomain.crt"), - "keyfile": os.path.join(data_dir, "mydomain.key"), - }) + ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"), + os.path.join(data_dir, "mydomain.key")) + HTTPServer(application, ssl_options=ssl_ctx) `HTTPServer` initialization follows one of three patterns (the initialization methods are defined on `tornado.tcpserver.TCPServer`): @@ -97,7 +79,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): server = HTTPServer(app) server.listen(8888) - IOLoop.instance().start() + IOLoop.current().start() In many cases, `tornado.web.Application.listen` can be used to avoid the need to explicitly create the `HTTPServer`. @@ -108,7 +90,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): server = HTTPServer(app) server.bind(8888) server.start(0) # Forks multiple sub-processes - IOLoop.instance().start() + IOLoop.current().start() When using this interface, an `.IOLoop` must *not* be passed to the `HTTPServer` constructor. `~.TCPServer.start` will always start @@ -120,7 +102,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): tornado.process.fork_processes(0) server = HTTPServer(app) server.add_sockets(sockets) - IOLoop.instance().start() + IOLoop.current().start() The `~.TCPServer.add_sockets` interface is more complicated, but it can be used with `tornado.process.fork_processes` to @@ -134,13 +116,29 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): ``idle_connection_timeout``, ``body_timeout``, ``max_body_size`` arguments. Added support for `.HTTPServerConnectionDelegate` instances as ``request_callback``. + + ..
versionchanged:: 4.1 + `.HTTPServerConnectionDelegate.start_request` is now called with + two arguments ``(server_conn, request_conn)`` (in accordance with the + documentation) instead of one ``(request_conn)``. + + .. versionchanged:: 4.2 + `HTTPServer` is now a subclass of `tornado.util.Configurable`. """ - def __init__(self, request_callback, no_keep_alive=False, io_loop=None, - xheaders=False, ssl_options=None, protocol=None, - decompress_request=False, - chunk_size=None, max_header_size=None, - idle_connection_timeout=None, body_timeout=None, - max_body_size=None, max_buffer_size=None): + def __init__(self, *args, **kwargs): + # Ignore args to __init__; real initialization belongs in + # initialize since we're Configurable. (there's something + # weird in initialization order between this class, + # Configurable, and TCPServer so we can't leave __init__ out + # completely) + pass + + def initialize(self, request_callback, no_keep_alive=False, io_loop=None, + xheaders=False, ssl_options=None, protocol=None, + decompress_request=False, + chunk_size=None, max_header_size=None, + idle_connection_timeout=None, body_timeout=None, + max_body_size=None, max_buffer_size=None): self.request_callback = request_callback self.no_keep_alive = no_keep_alive self.xheaders = xheaders @@ -157,6 +155,14 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): read_chunk_size=chunk_size) self._connections = set() + @classmethod + def configurable_base(cls): + return HTTPServer + + @classmethod + def configurable_default(cls): + return HTTPServer + @gen.coroutine def close_all_connections(self): while self._connections: @@ -173,7 +179,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): conn.start_serving(self) def start_request(self, server_conn, request_conn): - return _ServerRequestAdapter(self, request_conn) + return _ServerRequestAdapter(self, server_conn, request_conn) def on_close(self, server_conn): self._connections.remove(server_conn) @@ 
-246,13 +252,14 @@ class _ServerRequestAdapter(httputil.HTTPMessageDelegate): """Adapts the `HTTPMessageDelegate` interface to the interface expected by our clients. """ - def __init__(self, server, connection): + def __init__(self, server, server_conn, request_conn): self.server = server - self.connection = connection + self.connection = request_conn self.request = None if isinstance(server.request_callback, httputil.HTTPServerConnectionDelegate): - self.delegate = server.request_callback.start_request(connection) + self.delegate = server.request_callback.start_request( + server_conn, request_conn) self._chunks = None else: self.delegate = None diff --git a/tornado/httputil.py b/tornado/httputil.py index f5c9c04f..f5fea213 100644 --- a/tornado/httputil.py +++ b/tornado/httputil.py @@ -62,6 +62,11 @@ except ImportError: pass +# RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line +# terminator and ignore any preceding CR. +_CRLF_RE = re.compile(r'\r?\n') + + class _NormalizedHeaderCache(dict): """Dynamic cached mapping of header names to Http-Header-Case. @@ -193,7 +198,7 @@ class HTTPHeaders(dict): [('Content-Length', '42'), ('Content-Type', 'text/html')] """ h = cls() - for line in headers.splitlines(): + for line in _CRLF_RE.split(headers): if line: h.parse_line(line) return h @@ -229,6 +234,14 @@ class HTTPHeaders(dict): # default implementation returns dict(self), not the subclass return HTTPHeaders(self) + # Use our overridden copy method for the copy.copy module. + __copy__ = copy + + def __deepcopy__(self, memo_dict): + # Our values are immutable strings, so our standard copy is + # effectively a deep copy. + return self.copy() + class HTTPServerRequest(object): """A single HTTP request. 
@@ -331,7 +344,7 @@ class HTTPServerRequest(object): self.uri = uri self.version = version self.headers = headers or HTTPHeaders() - self.body = body or "" + self.body = body or b"" # set remote IP and protocol context = getattr(connection, 'context', None) @@ -380,6 +393,8 @@ class HTTPServerRequest(object): to write the response. """ assert isinstance(chunk, bytes) + assert self.version.startswith("HTTP/1."), \ + "deprecated interface ony supported in HTTP/1.x" self.connection.write(chunk, callback=callback) def finish(self): @@ -406,15 +421,14 @@ class HTTPServerRequest(object): def get_ssl_certificate(self, binary_form=False): """Returns the client's SSL certificate, if any. - To use client certificates, the HTTPServer must have been constructed - with cert_reqs set in ssl_options, e.g.:: + To use client certificates, the HTTPServer's + `ssl.SSLContext.verify_mode` field must be set, e.g.:: - server = HTTPServer(app, - ssl_options=dict( - certfile="foo.crt", - keyfile="foo.key", - cert_reqs=ssl.CERT_REQUIRED, - ca_certs="cacert.crt")) + ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_ctx.load_cert_chain("foo.crt", "foo.key") + ssl_ctx.load_verify_locations("cacerts.pem") + ssl_ctx.verify_mode = ssl.CERT_REQUIRED + server = HTTPServer(app, ssl_options=ssl_ctx) By default, the return value is a dictionary (or None, if no client certificate is present). If ``binary_form`` is true, a @@ -543,6 +557,8 @@ class HTTPConnection(object): headers. :arg callback: a callback to be run when the write is complete. + The ``version`` field of ``start_line`` is ignored. + Returns a `.Future` if no callback is given. 
""" raise NotImplementedError() @@ -689,14 +705,17 @@ def parse_body_arguments(content_type, body, arguments, files, headers=None): if values: arguments.setdefault(name, []).extend(values) elif content_type.startswith("multipart/form-data"): - fields = content_type.split(";") - for field in fields: - k, sep, v = field.strip().partition("=") - if k == "boundary" and v: - parse_multipart_form_data(utf8(v), body, arguments, files) - break - else: - gen_log.warning("Invalid multipart/form-data") + try: + fields = content_type.split(";") + for field in fields: + k, sep, v = field.strip().partition("=") + if k == "boundary" and v: + parse_multipart_form_data(utf8(v), body, arguments, files) + break + else: + raise ValueError("multipart boundary not found") + except Exception as e: + gen_log.warning("Invalid multipart/form-data: %s", e) def parse_multipart_form_data(boundary, data, arguments, files): @@ -782,7 +801,7 @@ def parse_request_start_line(line): method, path, version = line.split(" ") except ValueError: raise HTTPInputError("Malformed HTTP request line") - if not version.startswith("HTTP/"): + if not re.match(r"^HTTP/1\.[0-9]$", version): raise HTTPInputError( "Malformed HTTP version in HTTP Request-Line: %r" % version) return RequestStartLine(method, path, version) @@ -801,7 +820,7 @@ def parse_response_start_line(line): ResponseStartLine(version='HTTP/1.1', code=200, reason='OK') """ line = native_str(line) - match = re.match("(HTTP/1.[01]) ([0-9]+) ([^\r]*)", line) + match = re.match("(HTTP/1.[0-9]) ([0-9]+) ([^\r]*)", line) if not match: raise HTTPInputError("Error parsing response start line") return ResponseStartLine(match.group(1), int(match.group(2)), @@ -873,3 +892,20 @@ def _encode_header(key, pdict): def doctests(): import doctest return doctest.DocTestSuite() + + +def split_host_and_port(netloc): + """Returns ``(host, port)`` tuple from ``netloc``. + + Returned ``port`` will be ``None`` if not present. + + .. 
versionadded:: 4.1 + """ + match = re.match(r'^(.+):(\d+)$', netloc) + if match: + host = match.group(1) + port = int(match.group(2)) + else: + host = netloc + port = None + return (host, port) diff --git a/tornado/ioloop.py b/tornado/ioloop.py index 03193865..67e33b52 100644 --- a/tornado/ioloop.py +++ b/tornado/ioloop.py @@ -41,6 +41,7 @@ import sys import threading import time import traceback +import math from tornado.concurrent import TracebackFuture, is_future from tornado.log import app_log, gen_log @@ -76,35 +77,52 @@ class IOLoop(Configurable): simultaneous connections, you should use a system that supports either ``epoll`` or ``kqueue``. - Example usage for a simple TCP server:: + Example usage for a simple TCP server: + + .. testcode:: import errno import functools - import ioloop + import tornado.ioloop import socket def connection_ready(sock, fd, events): while True: try: connection, address = sock.accept() - except socket.error, e: + except socket.error as e: if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN): raise return connection.setblocking(0) handle_connection(connection, address) - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.setblocking(0) - sock.bind(("", port)) - sock.listen(128) + if __name__ == '__main__': + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(0) + sock.bind(("", port)) + sock.listen(128) - io_loop = ioloop.IOLoop.instance() - callback = functools.partial(connection_ready, sock) - io_loop.add_handler(sock.fileno(), callback, io_loop.READ) - io_loop.start() + io_loop = tornado.ioloop.IOLoop.current() + callback = functools.partial(connection_ready, sock) + io_loop.add_handler(sock.fileno(), callback, io_loop.READ) + io_loop.start() + .. 
testoutput:: + :hide: + + By default, a newly-constructed `IOLoop` becomes the thread's current + `IOLoop`, unless there already is a current `IOLoop`. This behavior + can be controlled with the ``make_current`` argument to the `IOLoop` + constructor: if ``make_current=True``, the new `IOLoop` will always + try to become current and it raises an error if there is already a + current instance. If ``make_current=False``, the new `IOLoop` will + not try to become current. + + .. versionchanged:: 4.2 + Added the ``make_current`` keyword argument to the `IOLoop` + constructor. """ # Constants from the epoll module _EPOLLIN = 0x001 @@ -133,7 +151,8 @@ class IOLoop(Configurable): Most applications have a single, global `IOLoop` running on the main thread. Use this method to get this instance from - another thread. To get the current thread's `IOLoop`, use `current()`. + another thread. In most other cases, it is better to use `current()` + to get the current thread's `IOLoop`. """ if not hasattr(IOLoop, "_instance"): with IOLoop._instance_lock: @@ -167,28 +186,26 @@ class IOLoop(Configurable): del IOLoop._instance @staticmethod - def current(): + def current(instance=True): """Returns the current thread's `IOLoop`. - If an `IOLoop` is currently running or has been marked as current - by `make_current`, returns that instance. Otherwise returns - `IOLoop.instance()`, i.e. the main thread's `IOLoop`. - - A common pattern for classes that depend on ``IOLoops`` is to use - a default argument to enable programs with multiple ``IOLoops`` - but not require the argument for simpler applications:: - - class MyClass(object): - def __init__(self, io_loop=None): - self.io_loop = io_loop or IOLoop.current() + If an `IOLoop` is currently running or has been marked as + current by `make_current`, returns that instance. If there is + no current `IOLoop`, returns `IOLoop.instance()` (i.e. the + main thread's `IOLoop`, creating one if necessary) if ``instance`` + is true. 
In general you should use `IOLoop.current` as the default when constructing an asynchronous object, and use `IOLoop.instance` when you mean to communicate to the main thread from a different one. + + .. versionchanged:: 4.1 + Added ``instance`` argument to control the fallback to + `IOLoop.instance()`. """ current = getattr(IOLoop._current, "instance", None) - if current is None: + if current is None and instance: return IOLoop.instance() return current @@ -200,6 +217,10 @@ class IOLoop(Configurable): `make_current` explicitly before starting the `IOLoop`, so that code run at startup time can find the right instance. + + .. versionchanged:: 4.1 + An `IOLoop` created while there is no current `IOLoop` + will automatically become current. """ IOLoop._current.instance = self @@ -223,8 +244,14 @@ class IOLoop(Configurable): from tornado.platform.select import SelectIOLoop return SelectIOLoop - def initialize(self): - pass + def initialize(self, make_current=None): + if make_current is None: + if IOLoop.current(instance=False) is None: + self.make_current() + elif make_current: + if IOLoop.current(instance=False) is None: + raise RuntimeError("current IOLoop already exists") + self.make_current() def close(self, all_fds=False): """Closes the `IOLoop`, freeing any resources used. @@ -390,7 +417,7 @@ class IOLoop(Configurable): # do stuff... if __name__ == '__main__': - IOLoop.instance().run_sync(main) + IOLoop.current().run_sync(main) """ future_cell = [None] @@ -633,8 +660,8 @@ class PollIOLoop(IOLoop): (Linux), `tornado.platform.kqueue.KQueueIOLoop` (BSD and Mac), or `tornado.platform.select.SelectIOLoop` (all platforms). 
""" - def initialize(self, impl, time_func=None): - super(PollIOLoop, self).initialize() + def initialize(self, impl, time_func=None, **kwargs): + super(PollIOLoop, self).initialize(**kwargs) self._impl = impl if hasattr(self._impl, 'fileno'): set_close_exec(self._impl.fileno()) @@ -739,8 +766,10 @@ class PollIOLoop(IOLoop): # IOLoop is just started once at the beginning. signal.set_wakeup_fd(old_wakeup_fd) old_wakeup_fd = None - except ValueError: # non-main thread - pass + except ValueError: + # Non-main thread, or the previous value of wakeup_fd + # is no longer valid. + old_wakeup_fd = None try: while True: @@ -944,8 +973,16 @@ class PeriodicCallback(object): """Schedules the given callback to be called periodically. The callback is called every ``callback_time`` milliseconds. + Note that the timeout is given in milliseconds, while most other + time-related functions in Tornado use seconds. + + If the callback runs for longer than ``callback_time`` milliseconds, + subsequent invocations will be skipped to get back on schedule. `start` must be called after the `PeriodicCallback` is created. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. """ def __init__(self, callback, callback_time, io_loop=None): self.callback = callback @@ -969,6 +1006,13 @@ class PeriodicCallback(object): self.io_loop.remove_timeout(self._timeout) self._timeout = None + def is_running(self): + """Return True if this `.PeriodicCallback` has been started. + + .. 
versionadded:: 4.1 + """ + return self._running + def _run(self): if not self._running: return @@ -982,6 +1026,9 @@ class PeriodicCallback(object): def _schedule_next(self): if self._running: current_time = self.io_loop.time() - while self._next_timeout <= current_time: - self._next_timeout += self.callback_time / 1000.0 + + if self._next_timeout <= current_time: + callback_time_sec = self.callback_time / 1000.0 + self._next_timeout += (math.floor((current_time - self._next_timeout) / callback_time_sec) + 1) * callback_time_sec + self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run) diff --git a/tornado/iostream.py b/tornado/iostream.py index 69f43957..3a175a67 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -37,7 +37,7 @@ import re from tornado.concurrent import TracebackFuture from tornado import ioloop from tornado.log import gen_log, app_log -from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError +from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError, _client_ssl_defaults, _server_ssl_defaults from tornado import stack_context from tornado.util import errno_from_exception @@ -68,13 +68,21 @@ _ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE, if hasattr(errno, "WSAECONNRESET"): _ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT) +if sys.platform == 'darwin': + # OSX appears to have a race condition that causes send(2) to return + # EPROTOTYPE if called while a socket is being torn down: + # http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/ + # Since the socket is being closed anyway, treat this as an ECONNRESET + # instead of an unexpected error. 
+ _ERRNO_CONNRESET += (errno.EPROTOTYPE,) + # More non-portable errnos: _ERRNO_INPROGRESS = (errno.EINPROGRESS,) if hasattr(errno, "WSAEINPROGRESS"): _ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,) -####################################################### + class StreamClosedError(IOError): """Exception raised by `IOStream` methods when the stream is closed. @@ -122,6 +130,7 @@ class BaseIOStream(object): """`BaseIOStream` constructor. :arg io_loop: The `.IOLoop` to use; defaults to `.IOLoop.current`. + Deprecated since Tornado 4.1. :arg max_buffer_size: Maximum amount of incoming data to buffer; defaults to 100MB. :arg read_chunk_size: Amount of data to read at one time from the @@ -160,6 +169,11 @@ class BaseIOStream(object): self._close_callback = None self._connect_callback = None self._connect_future = None + # _ssl_connect_future should be defined in SSLIOStream + # but it's here so we can clean it up in maybe_run_close_callback. + # TODO: refactor that so subclasses can add additional futures + # to be cancelled. + self._ssl_connect_future = None self._connecting = False self._state = None self._pending_callbacks = 0 @@ -230,6 +244,12 @@ class BaseIOStream(object): gen_log.info("Unsatisfiable read, closing connection: %s" % e) self.close(exc_info=True) return future + except: + if future is not None: + # Ensure that the future doesn't log an error because its + # failure was never examined. 
+ future.add_done_callback(lambda f: f.exception()) + raise return future def read_until(self, delimiter, callback=None, max_bytes=None): @@ -257,6 +277,10 @@ class BaseIOStream(object): gen_log.info("Unsatisfiable read, closing connection: %s" % e) self.close(exc_info=True) return future + except: + if future is not None: + future.add_done_callback(lambda f: f.exception()) + raise return future def read_bytes(self, num_bytes, callback=None, streaming_callback=None, @@ -281,7 +305,12 @@ class BaseIOStream(object): self._read_bytes = num_bytes self._read_partial = partial self._streaming_callback = stack_context.wrap(streaming_callback) - self._try_inline_read() + try: + self._try_inline_read() + except: + if future is not None: + future.add_done_callback(lambda f: f.exception()) + raise return future def read_until_close(self, callback=None, streaming_callback=None): @@ -293,9 +322,16 @@ class BaseIOStream(object): If a callback is given, it will be run with the data as an argument; if not, this method returns a `.Future`. + Note that if a ``streaming_callback`` is used, data will be + read from the socket as quickly as it becomes available; there + is no way to apply backpressure or cancel the reads. If flow + control or cancellation are desired, use a loop with + `read_bytes(partial=True) <.read_bytes>` instead. + .. versionchanged:: 4.0 The callback argument is now optional and a `.Future` will be returned if it is omitted. 
+ """ future = self._set_read_callback(callback) self._streaming_callback = stack_context.wrap(streaming_callback) @@ -305,7 +341,11 @@ class BaseIOStream(object): self._run_read_callback(self._read_buffer_size, False) return future self._read_until_close = True - self._try_inline_read() + try: + self._try_inline_read() + except: + future.add_done_callback(lambda f: f.exception()) + raise return future def write(self, data, callback=None): @@ -331,7 +371,7 @@ class BaseIOStream(object): if data: if (self.max_write_buffer_size is not None and self._write_buffer_size + len(data) > self.max_write_buffer_size): - raise StreamBufferFullError("Reached maximum read buffer size") + raise StreamBufferFullError("Reached maximum write buffer size") # Break up large contiguous strings before inserting them in the # write buffer, so we don't have to recopy the entire thing # as we slice off pieces to send to the socket. @@ -344,6 +384,7 @@ class BaseIOStream(object): future = None else: future = self._write_future = TracebackFuture() + future.add_done_callback(lambda f: f.exception()) if not self._connecting: self._handle_write() if self._write_buffer: @@ -401,9 +442,11 @@ class BaseIOStream(object): if self._connect_future is not None: futures.append(self._connect_future) self._connect_future = None + if self._ssl_connect_future is not None: + futures.append(self._ssl_connect_future) + self._ssl_connect_future = None for future in futures: - if (isinstance(self.error, (socket.error, IOError)) and - errno_from_exception(self.error) in _ERRNO_CONNRESET): + if self._is_connreset(self.error): # Treat connection resets as closed connections so # clients only have to catch one kind of exception # to avoid logging. 
@@ -601,9 +644,8 @@ class BaseIOStream(object): pos = self._read_to_buffer_loop() except UnsatisfiableReadError: raise - except Exception as e: - if 1 != e.errno: - gen_log.warning("error on read", exc_info=True) + except Exception: + gen_log.warning("error on read", exc_info=True) self.close(exc_info=True) return if pos is not None: @@ -627,13 +669,13 @@ class BaseIOStream(object): else: callback = self._read_callback self._read_callback = self._streaming_callback = None - if self._read_future is not None: - assert callback is None - future = self._read_future - self._read_future = None - future.set_result(self._consume(size)) + if self._read_future is not None: + assert callback is None + future = self._read_future + self._read_future = None + future.set_result(self._consume(size)) if callback is not None: - assert self._read_future is None + assert (self._read_future is None) or streaming self._run_callback(callback, self._consume(size)) else: # If we scheduled a callback, we will add the error listener @@ -684,7 +726,7 @@ class BaseIOStream(object): chunk = self.read_from_fd() except (socket.error, IOError, OSError) as e: # ssl.SSLError is a subclass of socket.error - if e.args[0] in _ERRNO_CONNRESET: + if self._is_connreset(e): # Treat ECONNRESET as a connection close rather than # an error to minimize log spam (the exception will # be available on self.error for apps that care). @@ -806,7 +848,7 @@ class BaseIOStream(object): self._write_buffer_frozen = True break else: - if e.args[0] not in _ERRNO_CONNRESET: + if not self._is_connreset(e): # Broken pipe errors are usually caused by connection # reset, and its better to not log EPIPE errors to # minimize log spam @@ -884,6 +926,14 @@ class BaseIOStream(object): self._state = self._state | state self.io_loop.update_handler(self.fileno(), self._state) + def _is_connreset(self, exc): + """Return true if exc is ECONNRESET or equivalent. + + May be overridden in subclasses. 
+ """ + return (isinstance(exc, (socket.error, IOError)) and + errno_from_exception(exc) in _ERRNO_CONNRESET) + class IOStream(BaseIOStream): r"""Socket-based `IOStream` implementation. @@ -898,7 +948,9 @@ class IOStream(BaseIOStream): connected before passing it to the `IOStream` or connected with `IOStream.connect`. - A very simple (and broken) HTTP client using this class:: + A very simple (and broken) HTTP client using this class: + + .. testcode:: import tornado.ioloop import tornado.iostream @@ -917,14 +969,19 @@ class IOStream(BaseIOStream): stream.read_bytes(int(headers[b"Content-Length"]), on_body) def on_body(data): - print data + print(data) stream.close() - tornado.ioloop.IOLoop.instance().stop() + tornado.ioloop.IOLoop.current().stop() + + if __name__ == '__main__': + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + stream = tornado.iostream.IOStream(s) + stream.connect(("friendfeed.com", 80), send_request) + tornado.ioloop.IOLoop.current().start() + + .. testoutput:: + :hide: - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) - stream = tornado.iostream.IOStream(s) - stream.connect(("friendfeed.com", 80), send_request) - tornado.ioloop.IOLoop.instance().start() """ def __init__(self, socket, *args, **kwargs): self.socket = socket @@ -978,10 +1035,10 @@ class IOStream(BaseIOStream): returns a `.Future` (whose result after a successful connection will be the stream itself). - If specified, the ``server_hostname`` parameter will be used - in SSL connections for certificate validation (if requested in - the ``ssl_options``) and SNI (if supported; requires - Python 3.2+). + In SSL mode, the ``server_hostname`` parameter will be used + for certificate validation (unless disabled in the + ``ssl_options``) and SNI (if supported; requires Python + 2.7.9+). Note that it is safe to call `IOStream.write ` while the connection is pending, in @@ -992,6 +1049,11 @@ class IOStream(BaseIOStream): .. 
versionchanged:: 4.0 If no callback is given, returns a `.Future`. + .. versionchanged:: 4.2 + SSL certificates are validated by default; pass + ``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a + suitably-configured `ssl.SSLContext` to the + `SSLIOStream` constructor to disable. """ self._connecting = True if callback is not None: @@ -1011,8 +1073,9 @@ class IOStream(BaseIOStream): # reported later in _handle_connect. if (errno_from_exception(e) not in _ERRNO_INPROGRESS and errno_from_exception(e) not in _ERRNO_WOULDBLOCK): - gen_log.warning("Connect error on fd %s: %s", - self.socket.fileno(), e) + if future is None: + gen_log.warning("Connect error on fd %s: %s", + self.socket.fileno(), e) self.close(exc_info=True) return future self._add_io_state(self.io_loop.WRITE) @@ -1033,10 +1096,11 @@ class IOStream(BaseIOStream): data. It can also be used immediately after connecting, before any reads or writes. - The ``ssl_options`` argument may be either a dictionary - of options or an `ssl.SSLContext`. If a ``server_hostname`` - is given, it will be used for certificate verification - (as configured in the ``ssl_options``). + The ``ssl_options`` argument may be either an `ssl.SSLContext` + object or a dictionary of keyword arguments for the + `ssl.wrap_socket` function. The ``server_hostname`` argument + will be used for certificate validation unless disabled + in the ``ssl_options``. This method returns a `.Future` whose result is the new `SSLIOStream`. After this method has been called, @@ -1046,6 +1110,11 @@ class IOStream(BaseIOStream): transferred to the new stream. .. versionadded:: 4.0 + + .. versionchanged:: 4.2 + SSL certificates are validated by default; pass + ``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a + suitably-configured `ssl.SSLContext` to disable. 
""" if (self._read_callback or self._read_future or self._write_callback or self._write_future or @@ -1054,12 +1123,17 @@ class IOStream(BaseIOStream): self._read_buffer or self._write_buffer): raise ValueError("IOStream is not idle; cannot convert to SSL") if ssl_options is None: - ssl_options = {} + if server_side: + ssl_options = _server_ssl_defaults + else: + ssl_options = _client_ssl_defaults socket = self.socket self.io_loop.remove_handler(socket) self.socket = None - socket = ssl_wrap_socket(socket, ssl_options, server_side=server_side, + socket = ssl_wrap_socket(socket, ssl_options, + server_hostname=server_hostname, + server_side=server_side, do_handshake_on_connect=False) orig_close_callback = self._close_callback self._close_callback = None @@ -1071,6 +1145,7 @@ class IOStream(BaseIOStream): # If we had an "unwrap" counterpart to this method we would need # to restore the original callback after our Future resolves # so that repeated wrap/unwrap calls don't build up layers. + def close_callback(): if not future.done(): future.set_exception(ssl_stream.error or StreamClosedError()) @@ -1115,7 +1190,7 @@ class IOStream(BaseIOStream): # Sometimes setsockopt will fail if the socket is closed # at the wrong time. This can happen with HTTPServer # resetting the value to false between requests. - if e.errno not in (errno.EINVAL, errno.ECONNRESET): + if e.errno != errno.EINVAL and not self._is_connreset(e): raise @@ -1131,11 +1206,11 @@ class SSLIOStream(IOStream): wrapped when `IOStream.connect` is finished. """ def __init__(self, *args, **kwargs): - """The ``ssl_options`` keyword argument may either be a dictionary - of keywords arguments for `ssl.wrap_socket`, or an `ssl.SSLContext` - object. 
+ """The ``ssl_options`` keyword argument may either be an + `ssl.SSLContext` object or a dictionary of keyword arguments + for `ssl.wrap_socket`. + """ - self._ssl_options = kwargs.pop('ssl_options', {}) + self._ssl_options = kwargs.pop('ssl_options', _client_ssl_defaults) super(SSLIOStream, self).__init__(*args, **kwargs) self._ssl_accepting = True self._handshake_reading = False @@ -1190,8 +1265,7 @@ class SSLIOStream(IOStream): # to cause do_handshake to raise EBADF, so make that error # quiet as well. # https://groups.google.com/forum/?fromgroups#!topic/python-tornado/ApucKJat1_0 - if (err.args[0] in _ERRNO_CONNRESET or - err.args[0] == errno.EBADF): + if self._is_connreset(err) or err.args[0] == errno.EBADF: return self.close(exc_info=True) raise except AttributeError: @@ -1204,10 +1278,17 @@ class SSLIOStream(IOStream): if not self._verify_cert(self.socket.getpeercert()): self.close() return - if self._ssl_connect_callback is not None: - callback = self._ssl_connect_callback - self._ssl_connect_callback = None - self._run_callback(callback) + self._run_ssl_connect_callback() + + def _run_ssl_connect_callback(self): + if self._ssl_connect_callback is not None: + callback = self._ssl_connect_callback + self._ssl_connect_callback = None + self._run_callback(callback) + if self._ssl_connect_future is not None: + future = self._ssl_connect_future + self._ssl_connect_future = None + future.set_result(self) def _verify_cert(self, peercert): """Returns True if peercert is valid according to the configured @@ -1249,14 +1330,11 @@ class SSLIOStream(IOStream): super(SSLIOStream, self)._handle_write() def connect(self, address, callback=None, server_hostname=None): - # Save the user's callback and run it after the ssl handshake - # has completed. - self._ssl_connect_callback = stack_context.wrap(callback) self._server_hostname = server_hostname - # Note: Since we don't pass our callback argument along to - # super.connect(), this will always return a Future.
- # This is harmless, but a bit less efficient than it could be. - return super(SSLIOStream, self).connect(address, callback=None) + # Pass a dummy callback to super.connect(), which is slightly + # more efficient than letting it return a Future we ignore. + super(SSLIOStream, self).connect(address, callback=lambda: None) + return self.wait_for_handshake(callback) def _handle_connect(self): # Call the superclass method to check for errors. @@ -1281,6 +1359,51 @@ class SSLIOStream(IOStream): do_handshake_on_connect=False) self._add_io_state(old_state) + def wait_for_handshake(self, callback=None): + """Wait for the initial SSL handshake to complete. + + If a ``callback`` is given, it will be called with no + arguments once the handshake is complete; otherwise this + method returns a `.Future` which will resolve to the + stream itself after the handshake is complete. + + Once the handshake is complete, information such as + the peer's certificate and NPN/ALPN selections may be + accessed on ``self.socket``. + + This method is intended for use on server-side streams + or after using `IOStream.start_tls`; it should not be used + with `IOStream.connect` (which already waits for the + handshake to complete). It may only be called once per stream. + + .. 
versionadded:: 4.2 + """ + if (self._ssl_connect_callback is not None or + self._ssl_connect_future is not None): + raise RuntimeError("Already waiting") + if callback is not None: + self._ssl_connect_callback = stack_context.wrap(callback) + future = None + else: + future = self._ssl_connect_future = TracebackFuture() + if not self._ssl_accepting: + self._run_ssl_connect_callback() + return future + + def write_to_fd(self, data): + try: + return self.socket.send(data) + except ssl.SSLError as e: + if e.args[0] == ssl.SSL_ERROR_WANT_WRITE: + # In Python 3.5+, SSLSocket.send raises a WANT_WRITE error if + # the socket is not writeable; we need to transform this into + # an EWOULDBLOCK socket.error or a zero return value, + # either of which will be recognized by the caller of this + # method. Prior to Python 3.5, an unwriteable socket would + # simply return 0 bytes written. + return 0 + raise + def read_from_fd(self): if self._ssl_accepting: # If the handshake hasn't finished yet, there can't be anything @@ -1311,6 +1434,11 @@ class SSLIOStream(IOStream): return None return chunk + def _is_connreset(self, e): + if isinstance(e, ssl.SSLError) and e.args[0] == ssl.SSL_ERROR_EOF: + return True + return super(SSLIOStream, self)._is_connreset(e) + class PipeIOStream(BaseIOStream): """Pipe-based `IOStream` implementation. diff --git a/tornado/locale.py b/tornado/locale.py index 07c6d582..a668765b 100644 --- a/tornado/locale.py +++ b/tornado/locale.py @@ -55,6 +55,7 @@ _default_locale = "en_US" _translations = {} _supported_locales = frozenset([_default_locale]) _use_gettext = False +CONTEXT_SEPARATOR = "\x04" def get(*locale_codes): @@ -273,6 +274,9 @@ class Locale(object): """ raise NotImplementedError() + def pgettext(self, context, message, plural_message=None, count=None): + raise NotImplementedError() + def format_date(self, date, gmt_offset=0, relative=True, shorter=False, full_format=False): """Formats the given date (which should be GMT). 
@@ -422,6 +426,11 @@ class CSVLocale(Locale): message_dict = self.translations.get("unknown", {}) return message_dict.get(message, message) + def pgettext(self, context, message, plural_message=None, count=None): + if self.translations: + gen_log.warning('pgettext is not supported by CSVLocale') + return self.translate(message, plural_message, count) + class GettextLocale(Locale): """Locale implementation using the `gettext` module.""" @@ -445,6 +454,44 @@ class GettextLocale(Locale): else: return self.gettext(message) + def pgettext(self, context, message, plural_message=None, count=None): + """Allows setting a context for translation; accepts plural forms. + + Usage example:: + + pgettext("law", "right") + pgettext("good", "right") + + Plural message example:: + + pgettext("organization", "club", "clubs", len(clubs)) + pgettext("stick", "club", "clubs", len(clubs)) + + To generate a POT file with context, add the following options to step 1 + of `load_gettext_translations` sequence:: + + xgettext [basic options] --keyword=pgettext:1c,2 --keyword=pgettext:1c,2,3 + + ..
versionadded:: 4.2 + """ + if plural_message is not None: + assert count is not None + msgs_with_ctxt = ("%s%s%s" % (context, CONTEXT_SEPARATOR, message), + "%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message), + count) + result = self.ngettext(*msgs_with_ctxt) + if CONTEXT_SEPARATOR in result: + # Translation not found + result = self.ngettext(message, plural_message, count) + return result + else: + msg_with_ctxt = "%s%s%s" % (context, CONTEXT_SEPARATOR, message) + result = self.gettext(msg_with_ctxt) + if CONTEXT_SEPARATOR in result: + # Translation not found + result = message + return result + LOCALE_NAMES = { "af_ZA": {"name_en": u("Afrikaans"), "name": u("Afrikaans")}, "am_ET": {"name_en": u("Amharic"), "name": u('\u12a0\u121b\u122d\u129b')}, diff --git a/tornado/locks.py b/tornado/locks.py new file mode 100644 index 00000000..4b0bdb38 --- /dev/null +++ b/tornado/locks.py @@ -0,0 +1,460 @@ +# Copyright 2015 The Tornado Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +.. testsetup:: * + + from tornado import ioloop, gen, locks + io_loop = ioloop.IOLoop.current() +""" + +from __future__ import absolute_import, division, print_function, with_statement + +__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock'] + +import collections + +from tornado import gen, ioloop +from tornado.concurrent import Future + + +class _TimeoutGarbageCollector(object): + """Base class for objects that periodically clean up timed-out waiters. 
+ + Avoids memory leak in a common pattern like: + + while True: + yield condition.wait(short_timeout) + print('looping....') + """ + def __init__(self): + self._waiters = collections.deque() # Futures. + self._timeouts = 0 + + def _garbage_collect(self): + # Occasionally clear timed-out waiters. + self._timeouts += 1 + if self._timeouts > 100: + self._timeouts = 0 + self._waiters = collections.deque( + w for w in self._waiters if not w.done()) + + +class Condition(_TimeoutGarbageCollector): + """A condition allows one or more coroutines to wait until notified. + + Like a standard `threading.Condition`, but does not need an underlying lock + that is acquired and released. + + With a `Condition`, coroutines can wait to be notified by other coroutines: + + .. testcode:: + + condition = locks.Condition() + + @gen.coroutine + def waiter(): + print("I'll wait right here") + yield condition.wait() # Yield a Future. + print("I'm done waiting") + + @gen.coroutine + def notifier(): + print("About to notify") + condition.notify() + print("Done notifying") + + @gen.coroutine + def runner(): + # Yield two Futures; wait for waiter() and notifier() to finish. + yield [waiter(), notifier()] + + io_loop.run_sync(runner) + + .. testoutput:: + + I'll wait right here + About to notify + Done notifying + I'm done waiting + + `wait` takes an optional ``timeout`` argument, which is either an absolute + timestamp:: + + io_loop = ioloop.IOLoop.current() + + # Wait up to 1 second for a notification. + yield condition.wait(timeout=io_loop.time() + 1) + + ...or a `datetime.timedelta` for a timeout relative to the current time:: + + # Wait up to 1 second. + yield condition.wait(timeout=datetime.timedelta(seconds=1)) + + The method returns ``False`` if there's no notification + before the deadline.
+ """ + + def __init__(self): + super(Condition, self).__init__() + self.io_loop = ioloop.IOLoop.current() + + def __repr__(self): + result = '<%s' % (self.__class__.__name__, ) + if self._waiters: + result += ' waiters[%s]' % len(self._waiters) + return result + '>' + + def wait(self, timeout=None): + """Wait for `.notify`. + + Returns a `.Future` that resolves ``True`` if the condition is notified, + or ``False`` after a timeout. + """ + waiter = Future() + self._waiters.append(waiter) + if timeout: + def on_timeout(): + waiter.set_result(False) + self._garbage_collect() + io_loop = ioloop.IOLoop.current() + timeout_handle = io_loop.add_timeout(timeout, on_timeout) + waiter.add_done_callback( + lambda _: io_loop.remove_timeout(timeout_handle)) + return waiter + + def notify(self, n=1): + """Wake ``n`` waiters.""" + waiters = [] # Waiters we plan to run right now. + while n and self._waiters: + waiter = self._waiters.popleft() + if not waiter.done(): # Might have timed out. + n -= 1 + waiters.append(waiter) + + for waiter in waiters: + waiter.set_result(True) + + def notify_all(self): + """Wake all waiters.""" + self.notify(len(self._waiters)) + + +class Event(object): + """An event blocks coroutines until its internal flag is set to True. + + Similar to `threading.Event`. + + A coroutine can wait for an event to be set. Once it is set, calls to + ``yield event.wait()`` will not block unless the event has been cleared: + + .. testcode:: + + event = locks.Event() + + @gen.coroutine + def waiter(): + print("Waiting for event") + yield event.wait() + print("Not waiting this time") + yield event.wait() + print("Done") + + @gen.coroutine + def setter(): + print("About to set the event") + event.set() + + @gen.coroutine + def runner(): + yield [waiter(), setter()] + + io_loop.run_sync(runner) + + .. 
testoutput:: + + Waiting for event + About to set the event + Not waiting this time + Done + """ + def __init__(self): + self._future = Future() + + def __repr__(self): + return '<%s %s>' % ( + self.__class__.__name__, 'set' if self.is_set() else 'clear') + + def is_set(self): + """Return ``True`` if the internal flag is true.""" + return self._future.done() + + def set(self): + """Set the internal flag to ``True``. All waiters are awakened. + + Calling `.wait` once the flag is set will not block. + """ + if not self._future.done(): + self._future.set_result(None) + + def clear(self): + """Reset the internal flag to ``False``. + + Calls to `.wait` will block until `.set` is called. + """ + if self._future.done(): + self._future = Future() + + def wait(self, timeout=None): + """Block until the internal flag is true. + + Returns a Future, which raises `tornado.gen.TimeoutError` after a + timeout. + """ + if timeout is None: + return self._future + else: + return gen.with_timeout(timeout, self._future) + + +class _ReleasingContextManager(object): + """Releases a Lock or Semaphore at the end of a "with" statement. + + with (yield semaphore.acquire()): + pass + + # Now semaphore.release() has been called. + """ + def __init__(self, obj): + self._obj = obj + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + self._obj.release() + + +class Semaphore(_TimeoutGarbageCollector): + """A lock that can be acquired a fixed number of times before blocking. + + A Semaphore manages a counter representing the number of `.release` calls + minus the number of `.acquire` calls, plus an initial value. The `.acquire` + method blocks if necessary until it can return without making the counter + negative. + + Semaphores limit access to a shared resource. To allow access for two + workers at a time: + + .. 
testsetup:: semaphore + + from collections import deque + + from tornado import gen, ioloop + from tornado.concurrent import Future + + # Ensure reliable doctest output: resolve Futures one at a time. + futures_q = deque([Future() for _ in range(3)]) + + @gen.coroutine + def simulator(futures): + for f in futures: + yield gen.moment + f.set_result(None) + + ioloop.IOLoop.current().add_callback(simulator, list(futures_q)) + + def use_some_resource(): + return futures_q.popleft() + + .. testcode:: semaphore + + sem = locks.Semaphore(2) + + @gen.coroutine + def worker(worker_id): + yield sem.acquire() + try: + print("Worker %d is working" % worker_id) + yield use_some_resource() + finally: + print("Worker %d is done" % worker_id) + sem.release() + + @gen.coroutine + def runner(): + # Join all workers. + yield [worker(i) for i in range(3)] + + io_loop.run_sync(runner) + + .. testoutput:: semaphore + + Worker 0 is working + Worker 1 is working + Worker 0 is done + Worker 2 is working + Worker 1 is done + Worker 2 is done + + Workers 0 and 1 are allowed to run concurrently, but worker 2 waits until + the semaphore has been released once, by worker 0. + + `.acquire` is a context manager, so ``worker`` could be written as:: + + @gen.coroutine + def worker(worker_id): + with (yield sem.acquire()): + print("Worker %d is working" % worker_id) + yield use_some_resource() + + # Now the semaphore has been released. 
+ print("Worker %d is done" % worker_id) + """ + def __init__(self, value=1): + super(Semaphore, self).__init__() + if value < 0: + raise ValueError('semaphore initial value must be >= 0') + + self._value = value + + def __repr__(self): + res = super(Semaphore, self).__repr__() + extra = 'locked' if self._value == 0 else 'unlocked,value:{0}'.format( + self._value) + if self._waiters: + extra = '{0},waiters:{1}'.format(extra, len(self._waiters)) + return '<{0} [{1}]>'.format(res[1:-1], extra) + + def release(self): + """Increment the counter and wake one waiter.""" + self._value += 1 + while self._waiters: + waiter = self._waiters.popleft() + if not waiter.done(): + self._value -= 1 + + # If the waiter is a coroutine paused at + # + # with (yield semaphore.acquire()): + # + # then the context manager's __exit__ calls release() at the end + # of the "with" block. + waiter.set_result(_ReleasingContextManager(self)) + break + + def acquire(self, timeout=None): + """Decrement the counter. Returns a Future. + + Block if the counter is zero and wait for a `.release`. The Future + raises `.TimeoutError` after the deadline. + """ + waiter = Future() + if self._value > 0: + self._value -= 1 + waiter.set_result(_ReleasingContextManager(self)) + else: + self._waiters.append(waiter) + if timeout: + def on_timeout(): + waiter.set_exception(gen.TimeoutError()) + self._garbage_collect() + io_loop = ioloop.IOLoop.current() + timeout_handle = io_loop.add_timeout(timeout, on_timeout) + waiter.add_done_callback( + lambda _: io_loop.remove_timeout(timeout_handle)) + return waiter + + def __enter__(self): + raise RuntimeError( + "Use Semaphore like 'with (yield semaphore.acquire())', not like" + " 'with semaphore'") + + __exit__ = __enter__ + + +class BoundedSemaphore(Semaphore): + """A semaphore that prevents release() being called too many times. + + If `.release` would increment the semaphore's value past the initial + value, it raises `ValueError`. 
Semaphores are mostly used to guard + resources with limited capacity, so a semaphore released too many times + is a sign of a bug. + """ + def __init__(self, value=1): + super(BoundedSemaphore, self).__init__(value=value) + self._initial_value = value + + def release(self): + """Increment the counter and wake one waiter.""" + if self._value >= self._initial_value: + raise ValueError("Semaphore released too many times") + super(BoundedSemaphore, self).release() + + +class Lock(object): + """A lock for coroutines. + + A Lock begins unlocked, and `acquire` locks it immediately. While it is + locked, a coroutine that yields `acquire` waits until another coroutine + calls `release`. + + Releasing an unlocked lock raises `RuntimeError`. + + `acquire` supports the context manager protocol: + + >>> from tornado import gen, locks + >>> lock = locks.Lock() + >>> + >>> @gen.coroutine + ... def f(): + ... with (yield lock.acquire()): + ... # Do something holding the lock. + ... pass + ... + ... # Now the lock is released. + """ + def __init__(self): + self._block = BoundedSemaphore(value=1) + + def __repr__(self): + return "<%s _block=%s>" % ( + self.__class__.__name__, + self._block) + + def acquire(self, timeout=None): + """Attempt to lock. Returns a Future. + + Returns a Future, which raises `tornado.gen.TimeoutError` after a + timeout. + """ + return self._block.acquire(timeout) + + def release(self): + """Unlock. + + The first coroutine in line waiting for `acquire` gets the lock. + + If not locked, raise a `RuntimeError`. 
+ """ + try: + self._block.release() + except ValueError: + raise RuntimeError('release unlocked lock') + + def __enter__(self): + raise RuntimeError( + "Use Lock like 'with (yield lock)', not like 'with lock'") + + __exit__ = __enter__ diff --git a/tornado/log.py b/tornado/log.py index 374071d4..c68dec46 100644 --- a/tornado/log.py +++ b/tornado/log.py @@ -206,6 +206,14 @@ def enable_pretty_logging(options=None, logger=None): def define_logging_options(options=None): + """Add logging-related flags to ``options``. + + These options are present automatically on the default options instance; + this method is only necessary if you have created your own `.OptionParser`. + + .. versionadded:: 4.2 + This function existed in prior versions but was broken and undocumented until 4.2. + """ if options is None: # late import to prevent cycle from tornado.options import options @@ -227,4 +235,4 @@ def define_logging_options(options=None): options.define("log_file_num_backups", type=int, default=10, help="number of log files to keep") - options.add_parse_callback(enable_pretty_logging) + options.add_parse_callback(lambda: enable_pretty_logging(options)) diff --git a/tornado/netutil.py b/tornado/netutil.py index f147c974..9aa292c4 100644 --- a/tornado/netutil.py +++ b/tornado/netutil.py @@ -20,7 +20,7 @@ from __future__ import absolute_import, division, print_function, with_statement import errno import os -import platform +import sys import socket import stat @@ -35,6 +35,15 @@ except ImportError: # ssl is not available on Google App Engine ssl = None +try: + import certifi +except ImportError: + # certifi is optional as long as we have ssl.create_default_context. 
+ if ssl is None or hasattr(ssl, 'create_default_context'): + certifi = None + else: + raise + try: xrange # py2 except NameError: @@ -50,6 +59,38 @@ else: ssl_match_hostname = backports.ssl_match_hostname.match_hostname SSLCertificateError = backports.ssl_match_hostname.CertificateError +if hasattr(ssl, 'SSLContext'): + if hasattr(ssl, 'create_default_context'): + # Python 2.7.9+, 3.4+ + # Note that the naming of ssl.Purpose is confusing; the purpose + # of a context is to authenticate the opposite side of the connection. + _client_ssl_defaults = ssl.create_default_context( + ssl.Purpose.SERVER_AUTH) + _server_ssl_defaults = ssl.create_default_context( + ssl.Purpose.CLIENT_AUTH) + else: + # Python 3.2-3.3 + _client_ssl_defaults = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + _client_ssl_defaults.verify_mode = ssl.CERT_REQUIRED + _client_ssl_defaults.load_verify_locations(certifi.where()) + _server_ssl_defaults = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + if hasattr(ssl, 'OP_NO_COMPRESSION'): + # Disable TLS compression to avoid CRIME and related attacks. + # This constant wasn't added until python 3.3. + _client_ssl_defaults.options |= ssl.OP_NO_COMPRESSION + _server_ssl_defaults.options |= ssl.OP_NO_COMPRESSION + +elif ssl: + # Python 2.6-2.7.8 + _client_ssl_defaults = dict(cert_reqs=ssl.CERT_REQUIRED, + ca_certs=certifi.where()) + _server_ssl_defaults = {} +else: + # Google App Engine + _client_ssl_defaults = dict(cert_reqs=None, + ca_certs=None) + _server_ssl_defaults = {} + # ThreadedResolver runs getaddrinfo on a thread. If the hostname is unicode, # getaddrinfo attempts to import encodings.idna.
If this is done at # module-import time, the import lock is already held by the main thread, @@ -68,6 +109,7 @@ if hasattr(errno, "WSAEWOULDBLOCK"): # Default backlog used when calling sock.listen() _DEFAULT_BACKLOG = 128 + def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=_DEFAULT_BACKLOG, flags=None): """Creates listening sockets bound to the given port and address. @@ -105,7 +147,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM, 0, flags)): af, socktype, proto, canonname, sockaddr = res - if (platform.system() == 'Darwin' and address == 'localhost' and + if (sys.platform == 'darwin' and address == 'localhost' and af == socket.AF_INET6 and sockaddr[3] != 0): # Mac OS X includes a link-local address fe80::1%lo0 in the # getaddrinfo results for 'localhost'. However, the firewall @@ -187,6 +229,9 @@ def add_accept_handler(sock, callback, io_loop=None): address of the other end of the connection). Note that this signature is different from the ``callback(fd, events)`` signature used for `.IOLoop` handlers. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. """ if io_loop is None: io_loop = IOLoop.current() @@ -301,6 +346,9 @@ class ExecutorResolver(Resolver): The executor will be shut down when the resolver is closed unless ``close_resolver=False``; use this if you want to reuse the same executor elsewhere. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. """ def initialize(self, io_loop=None, executor=None, close_executor=True): self.io_loop = io_loop or IOLoop.current() @@ -413,7 +461,7 @@ def ssl_options_to_context(ssl_options): `~ssl.SSLContext` object. The ``ssl_options`` dictionary contains keywords to be passed to - `ssl.wrap_socket`. In Python 3.2+, `ssl.SSLContext` objects can + `ssl.wrap_socket`. In Python 2.7.9+, `ssl.SSLContext` objects can be used instead. 
This function converts the dict form to its `~ssl.SSLContext` equivalent, and may be used when a component which accepts both forms needs to upgrade to the `~ssl.SSLContext` version @@ -444,11 +492,11 @@ def ssl_options_to_context(ssl_options): def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs): """Returns an ``ssl.SSLSocket`` wrapping the given socket. - ``ssl_options`` may be either a dictionary (as accepted by - `ssl_options_to_context`) or an `ssl.SSLContext` object. - Additional keyword arguments are passed to ``wrap_socket`` - (either the `~ssl.SSLContext` method or the `ssl` module function - as appropriate). + ``ssl_options`` may be either an `ssl.SSLContext` object or a + dictionary (as accepted by `ssl_options_to_context`). Additional + keyword arguments are passed to ``wrap_socket`` (either the + `~ssl.SSLContext` method or the `ssl` module function as + appropriate). """ context = ssl_options_to_context(ssl_options) if hasattr(ssl, 'SSLContext') and isinstance(context, ssl.SSLContext): diff --git a/tornado/options.py b/tornado/options.py index 5e23e291..89a9e432 100644 --- a/tornado/options.py +++ b/tornado/options.py @@ -204,6 +204,13 @@ class OptionParser(object): (name, self._options[name].file_name)) frame = sys._getframe(0) options_file = frame.f_code.co_filename + + # Can be called directly, or through top level define() fn, in which + # case, step up above that frame to look for real caller. 
+ if (frame.f_back.f_code.co_filename == options_file and + frame.f_back.f_code.co_name == 'define'): + frame = frame.f_back + file_name = frame.f_back.f_code.co_filename if file_name == options_file: file_name = "" @@ -249,7 +256,7 @@ class OptionParser(object): arg = args[i].lstrip("-") name, equals, value = arg.partition("=") name = name.replace('-', '_') - if not name in self._options: + if name not in self._options: self.print_help() raise Error('Unrecognized command line option: %r' % name) option = self._options[name] diff --git a/tornado/platform/asyncio.py b/tornado/platform/asyncio.py index dd6722a4..8f3dbff6 100644 --- a/tornado/platform/asyncio.py +++ b/tornado/platform/asyncio.py @@ -12,6 +12,8 @@ unfinished callbacks on the event loop that fail when it resumes) from __future__ import absolute_import, division, print_function, with_statement import functools +import tornado.concurrent +from tornado.gen import convert_yielded from tornado.ioloop import IOLoop from tornado import stack_context @@ -27,8 +29,10 @@ except ImportError as e: # Re-raise the original asyncio error, not the trollius one. 
raise e + class BaseAsyncIOLoop(IOLoop): - def initialize(self, asyncio_loop, close_loop=False): + def initialize(self, asyncio_loop, close_loop=False, **kwargs): + super(BaseAsyncIOLoop, self).initialize(**kwargs) self.asyncio_loop = asyncio_loop self.close_loop = close_loop self.asyncio_loop.call_soon(self.make_current) @@ -129,12 +133,29 @@ class BaseAsyncIOLoop(IOLoop): class AsyncIOMainLoop(BaseAsyncIOLoop): - def initialize(self): + def initialize(self, **kwargs): super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(), - close_loop=False) + close_loop=False, **kwargs) class AsyncIOLoop(BaseAsyncIOLoop): - def initialize(self): + def initialize(self, **kwargs): super(AsyncIOLoop, self).initialize(asyncio.new_event_loop(), - close_loop=True) + close_loop=True, **kwargs) + + +def to_tornado_future(asyncio_future): + """Convert an ``asyncio.Future`` to a `tornado.concurrent.Future`.""" + tf = tornado.concurrent.Future() + tornado.concurrent.chain_future(asyncio_future, tf) + return tf + + +def to_asyncio_future(tornado_future): + """Convert a `tornado.concurrent.Future` to an ``asyncio.Future``.""" + af = asyncio.Future() + tornado.concurrent.chain_future(tornado_future, af) + return af + +if hasattr(convert_yielded, 'register'): + convert_yielded.register(asyncio.Future, to_tornado_future) diff --git a/tornado/platform/auto.py b/tornado/platform/auto.py index ddfe06b4..fc40c9d9 100644 --- a/tornado/platform/auto.py +++ b/tornado/platform/auto.py @@ -27,13 +27,14 @@ from __future__ import absolute_import, division, print_function, with_statement import os -if os.name == 'nt': - from tornado.platform.common import Waker - from tornado.platform.windows import set_close_exec -elif 'APPENGINE_RUNTIME' in os.environ: +if 'APPENGINE_RUNTIME' in os.environ: from tornado.platform.common import Waker + def set_close_exec(fd): pass +elif os.name == 'nt': + from tornado.platform.common import Waker + from tornado.platform.windows import set_close_exec else: from 
tornado.platform.posix import set_close_exec, Waker @@ -41,9 +42,13 @@ try: # monotime monkey-patches the time module to have a monotonic function # in versions of python before 3.3. import monotime + # Silence pyflakes warning about this unused import + monotime except ImportError: pass try: from time import monotonic as monotonic_time except ImportError: monotonic_time = None + +__all__ = ['Waker', 'set_close_exec', 'monotonic_time'] diff --git a/tornado/platform/caresresolver.py b/tornado/platform/caresresolver.py index c4648c22..5559614f 100644 --- a/tornado/platform/caresresolver.py +++ b/tornado/platform/caresresolver.py @@ -18,6 +18,9 @@ class CaresResolver(Resolver): so it is only recommended for use in ``AF_INET`` (i.e. IPv4). This is the default for ``tornado.simple_httpclient``, but other libraries may default to ``AF_UNSPEC``. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. """ def initialize(self, io_loop=None): self.io_loop = io_loop or IOLoop.current() diff --git a/tornado/platform/kqueue.py b/tornado/platform/kqueue.py index de8c046d..f8f3e4a6 100644 --- a/tornado/platform/kqueue.py +++ b/tornado/platform/kqueue.py @@ -54,8 +54,7 @@ class _KQueue(object): if events & IOLoop.WRITE: kevents.append(select.kevent( fd, filter=select.KQ_FILTER_WRITE, flags=flags)) - if events & IOLoop.READ or not kevents: - # Always read when there is not a write + if events & IOLoop.READ: kevents.append(select.kevent( fd, filter=select.KQ_FILTER_READ, flags=flags)) # Even though control() takes a list, it seems to return EINVAL diff --git a/tornado/platform/select.py b/tornado/platform/select.py index 9a879562..db52ef91 100644 --- a/tornado/platform/select.py +++ b/tornado/platform/select.py @@ -47,7 +47,7 @@ class _Select(object): # Closed connections are reported as errors by epoll and kqueue, # but as zero-byte reads by select, so when errors are requested # we need to listen for both read and error. 
- self.read_fds.add(fd) + # self.read_fds.add(fd) def modify(self, fd, events): self.unregister(fd) diff --git a/tornado/platform/twisted.py b/tornado/platform/twisted.py index 27d991cd..7b3c8ca5 100644 --- a/tornado/platform/twisted.py +++ b/tornado/platform/twisted.py @@ -35,7 +35,7 @@ of the application:: tornado.platform.twisted.install() from twisted.internet import reactor -When the app is ready to start, call `IOLoop.instance().start()` +When the app is ready to start, call `IOLoop.current().start()` instead of `reactor.run()`. It is also possible to create a non-global reactor by calling @@ -70,8 +70,10 @@ import datetime import functools import numbers import socket +import sys import twisted.internet.abstract +from twisted.internet.defer import Deferred from twisted.internet.posixbase import PosixReactorBase from twisted.internet.interfaces import \ IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor @@ -84,6 +86,7 @@ import twisted.names.resolve from zope.interface import implementer +from tornado.concurrent import Future from tornado.escape import utf8 from tornado import gen import tornado.ioloop @@ -147,6 +150,9 @@ class TornadoReactor(PosixReactorBase): We override `mainLoop` instead of `doIteration` and must implement timed call functionality on top of `IOLoop.add_timeout` rather than using the implementation in `PosixReactorBase`. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. """ def __init__(self, io_loop=None): if not io_loop: @@ -356,7 +362,11 @@ class _TestReactor(TornadoReactor): def install(io_loop=None): - """Install this package as the default Twisted reactor.""" + """Install this package as the default Twisted reactor. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. 
+ """ if not io_loop: io_loop = tornado.ioloop.IOLoop.current() reactor = TornadoReactor(io_loop) @@ -406,7 +416,8 @@ class TwistedIOLoop(tornado.ioloop.IOLoop): because the ``SIGCHLD`` handlers used by Tornado and Twisted conflict with each other. """ - def initialize(self, reactor=None): + def initialize(self, reactor=None, **kwargs): + super(TwistedIOLoop, self).initialize(**kwargs) if reactor is None: import twisted.internet.reactor reactor = twisted.internet.reactor @@ -512,6 +523,9 @@ class TwistedResolver(Resolver): ``socket.AF_UNSPEC``. Requires Twisted 12.1 or newer. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. """ def initialize(self, io_loop=None): self.io_loop = io_loop or IOLoop.current() @@ -554,3 +568,18 @@ class TwistedResolver(Resolver): (resolved_family, (resolved, port)), ] raise gen.Return(result) + +if hasattr(gen.convert_yielded, 'register'): + @gen.convert_yielded.register(Deferred) + def _(d): + f = Future() + + def errback(failure): + try: + failure.raiseException() + # Should never happen, but just in case + raise Exception("errback called without error") + except: + f.set_exc_info(sys.exc_info()) + d.addCallbacks(f.set_result, errback) + return f diff --git a/tornado/process.py b/tornado/process.py index cea3dbd0..f580e192 100644 --- a/tornado/process.py +++ b/tornado/process.py @@ -29,6 +29,7 @@ import time from binascii import hexlify +from tornado.concurrent import Future from tornado import ioloop from tornado.iostream import PipeIOStream from tornado.log import gen_log @@ -48,6 +49,10 @@ except NameError: long = int # py3 +# Re-export this exception for convenience. +CalledProcessError = subprocess.CalledProcessError + + def cpu_count(): """Returns the number of processors on this machine.""" if multiprocessing is None: @@ -191,6 +196,9 @@ class Subprocess(object): ``tornado.process.Subprocess.STREAM``, which will make the corresponding attribute of the resulting Subprocess a `.PipeIOStream`. 
* A new keyword argument ``io_loop`` may be used to pass in an IOLoop. + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. """ STREAM = object() @@ -255,6 +263,33 @@ class Subprocess(object): Subprocess._waiting[self.pid] = self Subprocess._try_cleanup_process(self.pid) + def wait_for_exit(self, raise_error=True): + """Returns a `.Future` which resolves when the process exits. + + Usage:: + + ret = yield proc.wait_for_exit() + + This is a coroutine-friendly alternative to `set_exit_callback` + (and a replacement for the blocking `subprocess.Popen.wait`). + + By default, raises `subprocess.CalledProcessError` if the process + has a non-zero exit status. Use ``wait_for_exit(raise_error=False)`` + to suppress this behavior and return the exit status without raising. + + .. versionadded:: 4.2 + """ + future = Future() + + def callback(ret): + if ret != 0 and raise_error: + # Unfortunately we don't have the original args any more. + future.set_exception(CalledProcessError(ret, None)) + else: + future.set_result(ret) + self.set_exit_callback(callback) + return future + @classmethod def initialize(cls, io_loop=None): """Initializes the ``SIGCHLD`` handler. @@ -263,6 +298,9 @@ class Subprocess(object): Note that the `.IOLoop` used for signal handling need not be the same one used by individual Subprocess objects (as long as the ``IOLoops`` are each running in separate threads). + + .. versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. """ if cls._initialized: return diff --git a/tornado/queues.py b/tornado/queues.py new file mode 100644 index 00000000..55ab4834 --- /dev/null +++ b/tornado/queues.py @@ -0,0 +1,321 @@ +# Copyright 2015 The Tornado Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, with_statement + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty'] + +import collections +import heapq + +from tornado import gen, ioloop +from tornado.concurrent import Future +from tornado.locks import Event + + +class QueueEmpty(Exception): + """Raised by `.Queue.get_nowait` when the queue has no items.""" + pass + + +class QueueFull(Exception): + """Raised by `.Queue.put_nowait` when a queue is at its maximum size.""" + pass + + +def _set_timeout(future, timeout): + if timeout: + def on_timeout(): + future.set_exception(gen.TimeoutError()) + io_loop = ioloop.IOLoop.current() + timeout_handle = io_loop.add_timeout(timeout, on_timeout) + future.add_done_callback( + lambda _: io_loop.remove_timeout(timeout_handle)) + + +class Queue(object): + """Coordinate producer and consumer coroutines. + + If maxsize is 0 (the default) the queue size is unbounded. + + .. testcode:: + + q = queues.Queue(maxsize=2) + + @gen.coroutine + def consumer(): + while True: + item = yield q.get() + try: + print('Doing work on %s' % item) + yield gen.sleep(0.01) + finally: + q.task_done() + + @gen.coroutine + def producer(): + for item in range(5): + yield q.put(item) + print('Put %s' % item) + + @gen.coroutine + def main(): + consumer() # Start consumer. + yield producer() # Wait for producer to put all tasks. + yield q.join() # Wait for consumer to finish all tasks. + print('Done') + + io_loop.run_sync(main) + + .. 
testoutput:: + + Put 0 + Put 1 + Put 2 + Doing work on 0 + Doing work on 1 + Put 3 + Doing work on 2 + Put 4 + Doing work on 3 + Doing work on 4 + Done + """ + def __init__(self, maxsize=0): + if maxsize is None: + raise TypeError("maxsize can't be None") + + if maxsize < 0: + raise ValueError("maxsize can't be negative") + + self._maxsize = maxsize + self._init() + self._getters = collections.deque([]) # Futures. + self._putters = collections.deque([]) # Pairs of (item, Future). + self._unfinished_tasks = 0 + self._finished = Event() + self._finished.set() + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + def empty(self): + return not self._queue + + def full(self): + if self.maxsize == 0: + return False + else: + return self.qsize() >= self.maxsize + + def put(self, item, timeout=None): + """Put an item into the queue, perhaps waiting until there is room. + + Returns a Future, which raises `tornado.gen.TimeoutError` after a + timeout. + """ + try: + self.put_nowait(item) + except QueueFull: + future = Future() + self._putters.append((item, future)) + _set_timeout(future, timeout) + return future + else: + return gen._null_future + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise `QueueFull`. + """ + self._consume_expired() + if self._getters: + assert self.empty(), "queue non-empty, why are getters waiting?" + getter = self._getters.popleft() + self.__put_internal(item) + getter.set_result(self._get()) + elif self.full(): + raise QueueFull + else: + self.__put_internal(item) + + def get(self, timeout=None): + """Remove and return an item from the queue. + + Returns a Future which resolves once an item is available, or raises + `tornado.gen.TimeoutError` after a timeout. 
+ """ + future = Future() + try: + future.set_result(self.get_nowait()) + except QueueEmpty: + self._getters.append(future) + _set_timeout(future, timeout) + return future + + def get_nowait(self): + """Remove and return an item from the queue without blocking. + + Return an item if one is immediately available, else raise + `QueueEmpty`. + """ + self._consume_expired() + if self._putters: + assert self.full(), "queue not full, why are putters waiting?" + item, putter = self._putters.popleft() + self.__put_internal(item) + putter.set_result(None) + return self._get() + elif self.qsize(): + return self._get() + else: + raise QueueEmpty + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each `.get` used to fetch a task, a + subsequent call to `.task_done` tells the queue that the processing + on the task is complete. + + If a `.join` is blocking, it resumes when all items have been + processed; that is, when every `.put` is matched by a `.task_done`. + + Raises `ValueError` if called more times than `.put`. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + def join(self, timeout=None): + """Block until all items in the queue are processed. + + Returns a Future, which raises `tornado.gen.TimeoutError` after a + timeout. + """ + return self._finished.wait(timeout) + + # These three are overridable in subclasses. + def _init(self): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + # End of the overridable methods. + + def __put_internal(self, item): + self._unfinished_tasks += 1 + self._finished.clear() + self._put(item) + + def _consume_expired(self): + # Remove timed-out waiters. 
+ while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + while self._getters and self._getters[0].done(): + self._getters.popleft() + + def __repr__(self): + return '<%s at %s %s>' % ( + type(self).__name__, hex(id(self)), self._format()) + + def __str__(self): + return '<%s %s>' % (type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize=%r' % (self.maxsize, ) + if getattr(self, '_queue', None): + result += ' queue=%r' % self._queue + if self._getters: + result += ' getters[%s]' % len(self._getters) + if self._putters: + result += ' putters[%s]' % len(self._putters) + if self._unfinished_tasks: + result += ' tasks=%s' % self._unfinished_tasks + return result + + +class PriorityQueue(Queue): + """A `.Queue` that retrieves entries in priority order, lowest first. + + Entries are typically tuples like ``(priority number, data)``. + + .. testcode:: + + q = queues.PriorityQueue() + q.put((1, 'medium-priority item')) + q.put((0, 'high-priority item')) + q.put((10, 'low-priority item')) + + print(q.get_nowait()) + print(q.get_nowait()) + print(q.get_nowait()) + + .. testoutput:: + + (0, 'high-priority item') + (1, 'medium-priority item') + (10, 'low-priority item') + """ + def _init(self): + self._queue = [] + + def _put(self, item): + heapq.heappush(self._queue, item) + + def _get(self): + return heapq.heappop(self._queue) + + +class LifoQueue(Queue): + """A `.Queue` that retrieves the most recently put items first. + + .. testcode:: + + q = queues.LifoQueue() + q.put(3) + q.put(2) + q.put(1) + + print(q.get_nowait()) + print(q.get_nowait()) + print(q.get_nowait()) + + .. 
testoutput:: + + 1 + 2 + 3 + """ + def _init(self): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index cf30f072..cf58e162 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -7,7 +7,7 @@ from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _ from tornado import httputil from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters from tornado.iostream import StreamClosedError -from tornado.netutil import Resolver, OverrideResolver +from tornado.netutil import Resolver, OverrideResolver, _client_ssl_defaults from tornado.log import gen_log from tornado import stack_context from tornado.tcpclient import TCPClient @@ -34,7 +34,7 @@ except ImportError: ssl = None try: - import lib.certifi + import certifi except ImportError: certifi = None @@ -50,9 +50,6 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): """Non-blocking HTTP client with no external dependencies. This class implements an HTTP 1.1 client on top of Tornado's IOStreams. - It does not currently implement all applicable parts of the HTTP - specification, but it does enough to work with major web service APIs. - Some features found in the curl-based AsyncHTTPClient are not yet supported. In particular, proxies are not supported, connections are not reused, and callers cannot select the network interface to be @@ -60,25 +57,39 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): """ def initialize(self, io_loop, max_clients=10, hostname_mapping=None, max_buffer_size=104857600, - resolver=None, defaults=None, max_header_size=None): + resolver=None, defaults=None, max_header_size=None, + max_body_size=None): """Creates a AsyncHTTPClient. Only a single AsyncHTTPClient instance exists per IOLoop in order to provide limitations on the number of pending connections. 
- force_instance=True may be used to suppress this behavior. + ``force_instance=True`` may be used to suppress this behavior. - max_clients is the number of concurrent requests that can be - in progress. Note that this arguments are only used when the - client is first created, and will be ignored when an existing - client is reused. + Note that because of this implicit reuse, unless ``force_instance`` + is used, only the first call to the constructor actually uses + its arguments. It is recommended to use the ``configure`` method + instead of the constructor to ensure that arguments take effect. - hostname_mapping is a dictionary mapping hostnames to IP addresses. + ``max_clients`` is the number of concurrent requests that can be + in progress; when this limit is reached additional requests will be + queued. Note that time spent waiting in this queue still counts + against the ``request_timeout``. + + ``hostname_mapping`` is a dictionary mapping hostnames to IP addresses. It can be used to make local DNS changes when modifying system-wide - settings like /etc/hosts is not possible or desirable (e.g. in + settings like ``/etc/hosts`` is not possible or desirable (e.g. in unittests). - max_buffer_size is the number of bytes that can be read by IOStream. It - defaults to 100mb. + ``max_buffer_size`` (default 100MB) is the number of bytes + that can be read into memory at once. ``max_body_size`` + (defaults to ``max_buffer_size``) is the largest response body + that the client will accept. Without a + ``streaming_callback``, the smaller of these two limits + applies; with a ``streaming_callback`` only ``max_body_size`` + does. + + .. versionchanged:: 4.2 + Added the ``max_body_size`` argument. 
""" super(SimpleAsyncHTTPClient, self).initialize(io_loop, defaults=defaults) @@ -88,6 +99,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): self.waiting = {} self.max_buffer_size = max_buffer_size self.max_header_size = max_header_size + self.max_body_size = max_body_size # TCPClient could create a Resolver for us, but we have to do it # ourselves to support hostname_mapping. if resolver: @@ -135,10 +147,14 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): release_callback = functools.partial(self._release_fetch, key) self._handle_request(request, release_callback, callback) + def _connection_class(self): + return _HTTPConnection + def _handle_request(self, request, release_callback, final_callback): - _HTTPConnection(self.io_loop, self, request, release_callback, - final_callback, self.max_buffer_size, self.tcp_client, - self.max_header_size) + self._connection_class()( + self.io_loop, self, request, release_callback, + final_callback, self.max_buffer_size, self.tcp_client, + self.max_header_size, self.max_body_size) def _release_fetch(self, key): del self.active[key] @@ -166,7 +182,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): def __init__(self, io_loop, client, request, release_callback, final_callback, max_buffer_size, tcp_client, - max_header_size): + max_header_size, max_body_size): self.start_time = io_loop.time() self.io_loop = io_loop self.client = client @@ -176,6 +192,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): self.max_buffer_size = max_buffer_size self.tcp_client = tcp_client self.max_header_size = max_header_size + self.max_body_size = max_body_size self.code = None self.headers = None self.chunks = [] @@ -193,12 +210,8 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): netloc = self.parsed.netloc if "@" in netloc: userpass, _, netloc = netloc.rpartition("@") - match = re.match(r'^(.+):(\d+)$', netloc) - if match: - host = match.group(1) - port = int(match.group(2)) - else: - host = netloc + host, port = 
httputil.split_host_and_port(netloc) + if port is None: port = 443 if self.parsed.scheme == "https" else 80 if re.match(r'^\[.*\]$', host): # raw ipv6 addresses in urls are enclosed in brackets @@ -224,12 +237,24 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): def _get_ssl_options(self, scheme): if scheme == "https": + if self.request.ssl_options is not None: + return self.request.ssl_options + # If we are using the defaults, don't construct a + # new SSLContext. + if (self.request.validate_cert and + self.request.ca_certs is None and + self.request.client_cert is None and + self.request.client_key is None): + return _client_ssl_defaults ssl_options = {} if self.request.validate_cert: ssl_options["cert_reqs"] = ssl.CERT_REQUIRED if self.request.ca_certs is not None: ssl_options["ca_certs"] = self.request.ca_certs - else: + elif not hasattr(ssl, 'create_default_context'): + # When create_default_context is present, + # we can omit the "ca_certs" parameter entirely, + # which avoids the dependency on "certifi" for py34. ssl_options["ca_certs"] = _default_ca_certs() if self.request.client_key is not None: ssl_options["keyfile"] = self.request.client_key @@ -321,9 +346,9 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): body_present = (self.request.body is not None or self.request.body_producer is not None) if ((body_expected and not body_present) or - (body_present and not body_expected)): + (body_present and not body_expected)): raise ValueError( - 'Body must %sbe None for method %s (unelss ' + 'Body must %sbe None for method %s (unless ' 'allow_nonstandard_methods is true)' % ('not ' if body_expected else '', self.request.method)) if self.request.expect_100_continue: @@ -340,26 +365,30 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): self.request.headers["Accept-Encoding"] = "gzip" req_path = ((self.parsed.path or '/') + (('?' 
+ self.parsed.query) if self.parsed.query else '')) - self.stream.set_nodelay(True) - self.connection = HTTP1Connection( - self.stream, True, - HTTP1ConnectionParameters( - no_keep_alive=True, - max_header_size=self.max_header_size, - decompress=self.request.decompress_response), - self._sockaddr) + self.connection = self._create_connection(stream) start_line = httputil.RequestStartLine(self.request.method, - req_path, 'HTTP/1.1') + req_path, '') self.connection.write_headers(start_line, self.request.headers) if self.request.expect_100_continue: self._read_response() else: self._write_body(True) + def _create_connection(self, stream): + stream.set_nodelay(True) + connection = HTTP1Connection( + stream, True, + HTTP1ConnectionParameters( + no_keep_alive=True, + max_header_size=self.max_header_size, + max_body_size=self.max_body_size, + decompress=self.request.decompress_response), + self._sockaddr) + return connection + def _write_body(self, start_read): if self.request.body is not None: self.connection.write(self.request.body) - self.connection.finish() elif self.request.body_producer is not None: fut = self.request.body_producer(self.connection.write) if is_future(fut): @@ -370,7 +399,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): self._read_response() self.io_loop.add_future(fut, on_body_written) return - self.connection.finish() + self.connection.finish() if start_read: self._read_response() diff --git a/tornado/tcpclient.py b/tornado/tcpclient.py index 0abbea20..f594d91b 100644 --- a/tornado/tcpclient.py +++ b/tornado/tcpclient.py @@ -111,6 +111,7 @@ class _Connector(object): if self.timeout is not None: # If the first attempt failed, don't wait for the # timeout to try an address from the secondary queue. + self.io_loop.remove_timeout(self.timeout) self.on_timeout() return self.clear_timeout() @@ -135,6 +136,9 @@ class _Connector(object): class TCPClient(object): """A non-blocking TCP connection factory. + + .. 
versionchanged:: 4.1 + The ``io_loop`` argument is deprecated. """ def __init__(self, resolver=None, io_loop=None): self.io_loop = io_loop or IOLoop.current() diff --git a/tornado/tcpserver.py b/tornado/tcpserver.py index 427acec5..c9d148a8 100644 --- a/tornado/tcpserver.py +++ b/tornado/tcpserver.py @@ -41,14 +41,15 @@ class TCPServer(object): To use `TCPServer`, define a subclass which overrides the `handle_stream` method. - To make this server serve SSL traffic, send the ssl_options dictionary - argument with the arguments required for the `ssl.wrap_socket` method, - including "certfile" and "keyfile":: + To make this server serve SSL traffic, send the ``ssl_options`` keyword + argument with an `ssl.SSLContext` object. For compatibility with older + versions of Python ``ssl_options`` may also be a dictionary of keyword + arguments for the `ssl.wrap_socket` method.:: - TCPServer(ssl_options={ - "certfile": os.path.join(data_dir, "mydomain.crt"), - "keyfile": os.path.join(data_dir, "mydomain.key"), - }) + ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"), + os.path.join(data_dir, "mydomain.key")) + TCPServer(ssl_options=ssl_ctx) `TCPServer` initialization follows one of three patterns: @@ -56,14 +57,14 @@ class TCPServer(object): server = TCPServer() server.listen(8888) - IOLoop.instance().start() + IOLoop.current().start() 2. `bind`/`start`: simple multi-process:: server = TCPServer() server.bind(8888) server.start(0) # Forks multiple sub-processes - IOLoop.instance().start() + IOLoop.current().start() When using this interface, an `.IOLoop` must *not* be passed to the `TCPServer` constructor. 
`start` will always start @@ -75,7 +76,7 @@ class TCPServer(object): tornado.process.fork_processes(0) server = TCPServer() server.add_sockets(sockets) - IOLoop.instance().start() + IOLoop.current().start() The `add_sockets` interface is more complicated, but it can be used with `tornado.process.fork_processes` to give you more @@ -95,7 +96,7 @@ class TCPServer(object): self._pending_sockets = [] self._started = False self.max_buffer_size = max_buffer_size - self.read_chunk_size = None + self.read_chunk_size = read_chunk_size # Verify the SSL options. Otherwise we don't get errors until clients # connect. This doesn't verify that the keys are legitimate, but @@ -212,7 +213,20 @@ class TCPServer(object): sock.close() def handle_stream(self, stream, address): - """Override to handle a new `.IOStream` from an incoming connection.""" + """Override to handle a new `.IOStream` from an incoming connection. + + This method may be a coroutine; if so any exceptions it raises + asynchronously will be logged. Accepting of incoming connections + will not be blocked by this coroutine. + + If this `TCPServer` is configured for SSL, ``handle_stream`` + may be called before the SSL handshake has completed. Use + `.SSLIOStream.wait_for_handshake` if you need to verify the client's + certificate or use NPN/ALPN. + + .. versionchanged:: 4.2 + Added the option for this method to be a coroutine. 
+ """ raise NotImplementedError() def _handle_connection(self, connection, address): @@ -252,6 +266,8 @@ class TCPServer(object): stream = IOStream(connection, io_loop=self.io_loop, max_buffer_size=self.max_buffer_size, read_chunk_size=self.read_chunk_size) - self.handle_stream(stream, address) + future = self.handle_stream(stream, address) + if future is not None: + self.io_loop.add_future(future, lambda f: f.result()) except Exception: app_log.error("Error in connection callback", exc_info=True) diff --git a/tornado/test/README b/tornado/test/README deleted file mode 100644 index 33edba98..00000000 --- a/tornado/test/README +++ /dev/null @@ -1,4 +0,0 @@ -Test coverage is almost non-existent, but it's a start. Be sure to -set PYTHONPATH appropriately (generally to the root directory of your -tornado checkout) when running tests to make sure you're getting the -version of the tornado package that you expect. \ No newline at end of file diff --git a/tornado/test/asyncio_test.py b/tornado/test/asyncio_test.py new file mode 100644 index 00000000..1be0e54f --- /dev/null +++ b/tornado/test/asyncio_test.py @@ -0,0 +1,69 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +from __future__ import absolute_import, division, print_function, with_statement + +import sys +import textwrap + +from tornado import gen +from tornado.testing import AsyncTestCase, gen_test +from tornado.test.util import unittest + +try: + from tornado.platform.asyncio import asyncio, AsyncIOLoop +except ImportError: + asyncio = None + +skipIfNoSingleDispatch = unittest.skipIf( + gen.singledispatch is None, "singledispatch module not present") + + +@unittest.skipIf(asyncio is None, "asyncio module not present") +class AsyncIOLoopTest(AsyncTestCase): + def get_new_ioloop(self): + io_loop = AsyncIOLoop() + asyncio.set_event_loop(io_loop.asyncio_loop) + return io_loop + + def test_asyncio_callback(self): + # Basic test that the asyncio loop is set up correctly. + asyncio.get_event_loop().call_soon(self.stop) + self.wait() + + @skipIfNoSingleDispatch + @gen_test + def test_asyncio_future(self): + # Test that we can yield an asyncio future from a tornado coroutine. + # Without 'yield from', we must wrap coroutines in asyncio.async. + x = yield asyncio.async( + asyncio.get_event_loop().run_in_executor(None, lambda: 42)) + self.assertEqual(x, 42) + + @unittest.skipIf(sys.version_info < (3, 3), + 'PEP 380 not available') + @skipIfNoSingleDispatch + @gen_test + def test_asyncio_yield_from(self): + # Test that we can use asyncio coroutines with 'yield from' + # instead of asyncio.async(). This requires python 3.3 syntax. 
+ global_namespace = dict(globals(), **locals()) + local_namespace = {} + exec(textwrap.dedent(""" + @gen.coroutine + def f(): + event_loop = asyncio.get_event_loop() + x = yield from event_loop.run_in_executor(None, lambda: 42) + return x + """), global_namespace, local_namespace) + result = yield local_namespace['f']() + self.assertEqual(result, 42) diff --git a/tornado/test/auth_test.py b/tornado/test/auth_test.py index 254e1ae1..541ecf16 100644 --- a/tornado/test/auth_test.py +++ b/tornado/test/auth_test.py @@ -5,7 +5,7 @@ from __future__ import absolute_import, division, print_function, with_statement -from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, GoogleMixin, AuthError +from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, AuthError from tornado.concurrent import Future from tornado.escape import json_decode from tornado import gen @@ -238,28 +238,6 @@ class TwitterServerVerifyCredentialsHandler(RequestHandler): self.write(dict(screen_name='foo', name='Foo')) -class GoogleOpenIdClientLoginHandler(RequestHandler, GoogleMixin): - def initialize(self, test): - self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate') - - @asynchronous - def get(self): - if self.get_argument("openid.mode", None): - self.get_authenticated_user(self.on_user) - return - res = self.authenticate_redirect() - assert isinstance(res, Future) - assert res.done() - - def on_user(self, user): - if user is None: - raise Exception("user is None") - self.finish(user) - - def get_auth_http_client(self): - return self.settings['http_client'] - - class AuthTest(AsyncHTTPTestCase): def get_app(self): return Application( @@ -286,7 +264,6 @@ class AuthTest(AsyncHTTPTestCase): ('/twitter/client/login_gen_coroutine', TwitterClientLoginGenCoroutineHandler, dict(test=self)), ('/twitter/client/show_user', TwitterClientShowUserHandler, dict(test=self)), ('/twitter/client/show_user_future', TwitterClientShowUserFutureHandler, dict(test=self)), 
- ('/google/client/openid_login', GoogleOpenIdClientLoginHandler, dict(test=self)), # simulated servers ('/openid/server/authenticate', OpenIdServerAuthenticateHandler), @@ -436,16 +413,3 @@ class AuthTest(AsyncHTTPTestCase): response = self.fetch('/twitter/client/show_user_future?name=error') self.assertEqual(response.code, 500) self.assertIn(b'Error response HTTP 500', response.body) - - def test_google_redirect(self): - # same as test_openid_redirect - response = self.fetch('/google/client/openid_login', follow_redirects=False) - self.assertEqual(response.code, 302) - self.assertTrue( - '/openid/server/authenticate?' in response.headers['Location']) - - def test_google_get_user(self): - response = self.fetch('/google/client/openid_login?openid.mode=blah&openid.ns.ax=http://openid.net/srv/ax/1.0&openid.ax.type.email=http://axschema.org/contact/email&openid.ax.value.email=foo@example.com', follow_redirects=False) - response.rethrow() - parsed = json_decode(response.body) - self.assertEqual(parsed["email"], "foo@example.com") diff --git a/tornado/test/concurrent_test.py b/tornado/test/concurrent_test.py index 5e93ad6a..bf90ad0e 100644 --- a/tornado/test/concurrent_test.py +++ b/tornado/test/concurrent_test.py @@ -21,13 +21,14 @@ import socket import sys import traceback -from tornado.concurrent import Future, return_future, ReturnValueIgnoredError +from tornado.concurrent import Future, return_future, ReturnValueIgnoredError, run_on_executor from tornado.escape import utf8, to_unicode from tornado import gen from tornado.iostream import IOStream from tornado import stack_context from tornado.tcpserver import TCPServer from tornado.testing import AsyncTestCase, LogTrapTestCase, bind_unused_port, gen_test +from tornado.test.util import unittest try: @@ -334,3 +335,81 @@ class DecoratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase): class GeneratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase): client_class = GeneratorCapClient + + 
+@unittest.skipIf(futures is None, "concurrent.futures module not present") +class RunOnExecutorTest(AsyncTestCase): + @gen_test + def test_no_calling(self): + class Object(object): + def __init__(self, io_loop): + self.io_loop = io_loop + self.executor = futures.thread.ThreadPoolExecutor(1) + + @run_on_executor + def f(self): + return 42 + + o = Object(io_loop=self.io_loop) + answer = yield o.f() + self.assertEqual(answer, 42) + + @gen_test + def test_call_with_no_args(self): + class Object(object): + def __init__(self, io_loop): + self.io_loop = io_loop + self.executor = futures.thread.ThreadPoolExecutor(1) + + @run_on_executor() + def f(self): + return 42 + + o = Object(io_loop=self.io_loop) + answer = yield o.f() + self.assertEqual(answer, 42) + + @gen_test + def test_call_with_io_loop(self): + class Object(object): + def __init__(self, io_loop): + self._io_loop = io_loop + self.executor = futures.thread.ThreadPoolExecutor(1) + + @run_on_executor(io_loop='_io_loop') + def f(self): + return 42 + + o = Object(io_loop=self.io_loop) + answer = yield o.f() + self.assertEqual(answer, 42) + + @gen_test + def test_call_with_executor(self): + class Object(object): + def __init__(self, io_loop): + self.io_loop = io_loop + self.__executor = futures.thread.ThreadPoolExecutor(1) + + @run_on_executor(executor='_Object__executor') + def f(self): + return 42 + + o = Object(io_loop=self.io_loop) + answer = yield o.f() + self.assertEqual(answer, 42) + + @gen_test + def test_call_with_both(self): + class Object(object): + def __init__(self, io_loop): + self._io_loop = io_loop + self.__executor = futures.thread.ThreadPoolExecutor(1) + + @run_on_executor(io_loop='_io_loop', executor='_Object__executor') + def f(self): + return 42 + + o = Object(io_loop=self.io_loop) + answer = yield o.f() + self.assertEqual(answer, 42) diff --git a/tornado/test/curl_httpclient_test.py b/tornado/test/curl_httpclient_test.py index 8d7065df..3ac21f4d 100644 --- a/tornado/test/curl_httpclient_test.py 
+++ b/tornado/test/curl_httpclient_test.py @@ -8,7 +8,7 @@ from tornado.stack_context import ExceptionStackContext from tornado.testing import AsyncHTTPTestCase from tornado.test import httpclient_test from tornado.test.util import unittest -from tornado.web import Application, RequestHandler, URLSpec +from tornado.web import Application, RequestHandler try: diff --git a/tornado/test/escape_test.py b/tornado/test/escape_test.py index f6404288..98a23463 100644 --- a/tornado/test/escape_test.py +++ b/tornado/test/escape_test.py @@ -217,9 +217,8 @@ class EscapeTestCase(unittest.TestCase): self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9") def test_squeeze(self): - self.assertEqual(squeeze(u('sequences of whitespace chars')) - , u('sequences of whitespace chars')) - + self.assertEqual(squeeze(u('sequences of whitespace chars')), u('sequences of whitespace chars')) + def test_recursive_unicode(self): tests = { 'dict': {b"foo": b"bar"}, diff --git a/tornado/test/gen_test.py b/tornado/test/gen_test.py index a15cdf73..fdaa0ec8 100644 --- a/tornado/test/gen_test.py +++ b/tornado/test/gen_test.py @@ -62,6 +62,11 @@ class GenEngineTest(AsyncTestCase): def async_future(self, result, callback): self.io_loop.add_callback(callback, result) + @gen.coroutine + def async_exception(self, e): + yield gen.moment + raise e + def test_no_yield(self): @gen.engine def f(): @@ -385,11 +390,56 @@ class GenEngineTest(AsyncTestCase): results = yield [self.async_future(1), self.async_future(2)] self.assertEqual(results, [1, 2]) + @gen_test + def test_multi_future_duplicate(self): + f = self.async_future(2) + results = yield [self.async_future(1), f, self.async_future(3), f] + self.assertEqual(results, [1, 2, 3, 2]) + @gen_test def test_multi_dict_future(self): results = yield dict(foo=self.async_future(1), bar=self.async_future(2)) self.assertEqual(results, dict(foo=1, bar=2)) + @gen_test + def test_multi_exceptions(self): + with ExpectLog(app_log, "Multiple exceptions in yield list"): 
+ with self.assertRaises(RuntimeError) as cm: + yield gen.Multi([self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2"))]) + self.assertEqual(str(cm.exception), "error 1") + + # With only one exception, no error is logged. + with self.assertRaises(RuntimeError): + yield gen.Multi([self.async_exception(RuntimeError("error 1")), + self.async_future(2)]) + + # Exception logging may be explicitly quieted. + with self.assertRaises(RuntimeError): + yield gen.Multi([self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2"))], + quiet_exceptions=RuntimeError) + + @gen_test + def test_multi_future_exceptions(self): + with ExpectLog(app_log, "Multiple exceptions in yield list"): + with self.assertRaises(RuntimeError) as cm: + yield [self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2"))] + self.assertEqual(str(cm.exception), "error 1") + + # With only one exception, no error is logged. + with self.assertRaises(RuntimeError): + yield [self.async_exception(RuntimeError("error 1")), + self.async_future(2)] + + # Exception logging may be explicitly quieted. 
+ with self.assertRaises(RuntimeError): + yield gen.multi_future( + [self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2"))], + quiet_exceptions=RuntimeError) + def test_arguments(self): @gen.engine def f(): @@ -816,6 +866,7 @@ class GenCoroutineTest(AsyncTestCase): @gen_test def test_moment(self): calls = [] + @gen.coroutine def f(name, yieldable): for i in range(5): @@ -838,6 +889,34 @@ class GenCoroutineTest(AsyncTestCase): yield [f('a', gen.moment), f('b', immediate)] self.assertEqual(''.join(calls), 'abbbbbaaaa') + @gen_test + def test_sleep(self): + yield gen.sleep(0.01) + self.finished = True + + @skipBefore33 + @gen_test + def test_py3_leak_exception_context(self): + class LeakedException(Exception): + pass + + @gen.coroutine + def inner(iteration): + raise LeakedException(iteration) + + try: + yield inner(1) + except LeakedException as e: + self.assertEqual(str(e), "1") + self.assertIsNone(e.__context__) + + try: + yield inner(2) + except LeakedException as e: + self.assertEqual(str(e), "2") + self.assertIsNone(e.__context__) + + self.finished = True class GenSequenceHandler(RequestHandler): @asynchronous @@ -1031,7 +1110,7 @@ class WithTimeoutTest(AsyncTestCase): self.io_loop.add_timeout(datetime.timedelta(seconds=0.1), lambda: future.set_result('asdf')) result = yield gen.with_timeout(datetime.timedelta(seconds=3600), - future) + future, io_loop=self.io_loop) self.assertEqual(result, 'asdf') @gen_test @@ -1039,16 +1118,17 @@ class WithTimeoutTest(AsyncTestCase): future = Future() self.io_loop.add_timeout( datetime.timedelta(seconds=0.1), - lambda: future.set_exception(ZeroDivisionError)) + lambda: future.set_exception(ZeroDivisionError())) with self.assertRaises(ZeroDivisionError): - yield gen.with_timeout(datetime.timedelta(seconds=3600), future) + yield gen.with_timeout(datetime.timedelta(seconds=3600), + future, io_loop=self.io_loop) @gen_test def test_already_resolved(self): future = Future() 
future.set_result('asdf') result = yield gen.with_timeout(datetime.timedelta(seconds=3600), - future) + future, io_loop=self.io_loop) self.assertEqual(result, 'asdf') @unittest.skipIf(futures is None, 'futures module not present') @@ -1067,5 +1147,117 @@ class WithTimeoutTest(AsyncTestCase): executor.submit(lambda: None)) +class WaitIteratorTest(AsyncTestCase): + @gen_test + def test_empty_iterator(self): + g = gen.WaitIterator() + self.assertTrue(g.done(), 'empty generator iterated') + + with self.assertRaises(ValueError): + g = gen.WaitIterator(False, bar=False) + + self.assertEqual(g.current_index, None, "bad nil current index") + self.assertEqual(g.current_future, None, "bad nil current future") + + @gen_test + def test_already_done(self): + f1 = Future() + f2 = Future() + f3 = Future() + f1.set_result(24) + f2.set_result(42) + f3.set_result(84) + + g = gen.WaitIterator(f1, f2, f3) + i = 0 + while not g.done(): + r = yield g.next() + # Order is not guaranteed, but the current implementation + # preserves ordering of already-done Futures. 
+ if i == 0: + self.assertEqual(g.current_index, 0) + self.assertIs(g.current_future, f1) + self.assertEqual(r, 24) + elif i == 1: + self.assertEqual(g.current_index, 1) + self.assertIs(g.current_future, f2) + self.assertEqual(r, 42) + elif i == 2: + self.assertEqual(g.current_index, 2) + self.assertIs(g.current_future, f3) + self.assertEqual(r, 84) + i += 1 + + self.assertEqual(g.current_index, None, "bad nil current index") + self.assertEqual(g.current_future, None, "bad nil current future") + + dg = gen.WaitIterator(f1=f1, f2=f2) + + while not dg.done(): + dr = yield dg.next() + if dg.current_index == "f1": + self.assertTrue(dg.current_future == f1 and dr == 24, + "WaitIterator dict status incorrect") + elif dg.current_index == "f2": + self.assertTrue(dg.current_future == f2 and dr == 42, + "WaitIterator dict status incorrect") + else: + self.fail("got bad WaitIterator index {}".format( + dg.current_index)) + + i += 1 + + self.assertEqual(dg.current_index, None, "bad nil current index") + self.assertEqual(dg.current_future, None, "bad nil current future") + + def finish_coroutines(self, iteration, futures): + if iteration == 3: + futures[2].set_result(24) + elif iteration == 5: + futures[0].set_exception(ZeroDivisionError()) + elif iteration == 8: + futures[1].set_result(42) + futures[3].set_result(84) + + if iteration < 8: + self.io_loop.add_callback(self.finish_coroutines, iteration + 1, futures) + + @gen_test + def test_iterator(self): + futures = [Future(), Future(), Future(), Future()] + + self.finish_coroutines(0, futures) + + g = gen.WaitIterator(*futures) + + i = 0 + while not g.done(): + try: + r = yield g.next() + except ZeroDivisionError: + self.assertIs(g.current_future, futures[0], + 'exception future invalid') + else: + if i == 0: + self.assertEqual(r, 24, 'iterator value incorrect') + self.assertEqual(g.current_index, 2, 'wrong index') + elif i == 2: + self.assertEqual(r, 42, 'iterator value incorrect') + self.assertEqual(g.current_index, 1, 
'wrong index') + elif i == 3: + self.assertEqual(r, 84, 'iterator value incorrect') + self.assertEqual(g.current_index, 3, 'wrong index') + i += 1 + + @gen_test + def test_no_ref(self): + # In this usage, there is no direct hard reference to the + # WaitIterator itself, only the Future it returns. Since + # WaitIterator uses weak references internally to improve GC + # performance, this used to cause problems. + yield gen.with_timeout(datetime.timedelta(seconds=0.1), + gen.WaitIterator(gen.sleep(0)).next()) + + if __name__ == '__main__': unittest.main() diff --git a/tornado/test/gettext_translations/extract_me.py b/tornado/test/gettext_translations/extract_me.py index 75406ecc..45321cce 100644 --- a/tornado/test/gettext_translations/extract_me.py +++ b/tornado/test/gettext_translations/extract_me.py @@ -1,11 +1,16 @@ +# flake8: noqa # Dummy source file to allow creation of the initial .po file in the # same way as a real project. I'm not entirely sure about the real # workflow here, but this seems to work. 
# -# 1) xgettext --language=Python --keyword=_:1,2 -d tornado_test extract_me.py -o tornado_test.po -# 2) Edit tornado_test.po, setting CHARSET and setting msgstr +# 1) xgettext --language=Python --keyword=_:1,2 --keyword=pgettext:1c,2 --keyword=pgettext:1c,2,3 extract_me.py -o tornado_test.po +# 2) Edit tornado_test.po, setting CHARSET, Plural-Forms and setting msgstr # 3) msgfmt tornado_test.po -o tornado_test.mo # 4) Put the file in the proper location: $LANG/LC_MESSAGES from __future__ import absolute_import, division, print_function, with_statement _("school") +pgettext("law", "right") +pgettext("good", "right") +pgettext("organization", "club", "clubs", 1) +pgettext("stick", "club", "clubs", 1) diff --git a/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo b/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo index 089f6c7a..a97bf9c5 100644 Binary files a/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo and b/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo differ diff --git a/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po b/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po index 732ee6da..88d72c86 100644 --- a/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po +++ b/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: PACKAGE VERSION\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2012-06-14 01:10-0700\n" +"POT-Creation-Date: 2015-01-27 11:05+0300\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" @@ -16,7 +16,32 @@ msgstr "" "MIME-Version: 1.0\n" "Content-Type: text/plain; charset=utf-8\n" "Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=2; plural=(n > 1);\n" -#: extract_me.py:1 +#: extract_me.py:11 msgid "school" msgstr "école" + +#: extract_me.py:12 +msgctxt "law" +msgid "right" +msgstr 
"le droit" + +#: extract_me.py:13 +msgctxt "good" +msgid "right" +msgstr "le bien" + +#: extract_me.py:14 +msgctxt "organization" +msgid "club" +msgid_plural "clubs" +msgstr[0] "le club" +msgstr[1] "les clubs" + +#: extract_me.py:15 +msgctxt "stick" +msgid "club" +msgid_plural "clubs" +msgstr[0] "le bâton" +msgstr[1] "les bâtons" diff --git a/tornado/test/httpclient_test.py b/tornado/test/httpclient_test.py index bfb50d87..ecc63e4a 100644 --- a/tornado/test/httpclient_test.py +++ b/tornado/test/httpclient_test.py @@ -12,6 +12,7 @@ import datetime from io import BytesIO from tornado.escape import utf8 +from tornado import gen from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient from tornado.httpserver import HTTPServer from tornado.ioloop import IOLoop @@ -52,9 +53,12 @@ class RedirectHandler(RequestHandler): class ChunkHandler(RequestHandler): + @gen.coroutine def get(self): self.write("asdf") self.flush() + # Wait a bit to ensure the chunks are sent and received separately. + yield gen.sleep(0.01) self.write("qwer") @@ -178,6 +182,8 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase): sock, port = bind_unused_port() with closing(sock): def write_response(stream, request_data): + if b"HTTP/1." not in request_data: + self.skipTest("requires HTTP/1.x") stream.write(b"""\ HTTP/1.1 200 OK Transfer-Encoding: chunked @@ -300,23 +306,26 @@ Transfer-Encoding: chunked chunks = [] def header_callback(header_line): - if header_line.startswith('HTTP/'): + if header_line.startswith('HTTP/1.1 101'): + # Upgrading to HTTP/2 + pass + elif header_line.startswith('HTTP/'): first_line.append(header_line) elif header_line != '\r\n': k, v = header_line.split(':', 1) - headers[k] = v.strip() + headers[k.lower()] = v.strip() def streaming_callback(chunk): # All header callbacks are run before any streaming callbacks, # so the header data is available to process the data as it # comes in. 
- self.assertEqual(headers['Content-Type'], 'text/html; charset=UTF-8') + self.assertEqual(headers['content-type'], 'text/html; charset=UTF-8') chunks.append(chunk) self.fetch('/chunk', header_callback=header_callback, streaming_callback=streaming_callback) - self.assertEqual(len(first_line), 1) - self.assertRegexpMatches(first_line[0], 'HTTP/1.[01] 200 OK\r\n') + self.assertEqual(len(first_line), 1, first_line) + self.assertRegexpMatches(first_line[0], 'HTTP/[0-9]\\.[0-9] 200.*\r\n') self.assertEqual(chunks, [b'asdf', b'qwer']) def test_header_callback_stack_context(self): @@ -327,7 +336,7 @@ Transfer-Encoding: chunked return True def header_callback(header_line): - if header_line.startswith('Content-Type:'): + if header_line.lower().startswith('content-type:'): 1 / 0 with ExceptionStackContext(error_handler): @@ -404,6 +413,11 @@ Transfer-Encoding: chunked self.assertEqual(context.exception.code, 404) self.assertEqual(context.exception.response.code, 404) + @gen_test + def test_future_http_error_no_raise(self): + response = yield self.http_client.fetch(self.get_url('/notfound'), raise_error=False) + self.assertEqual(response.code, 404) + @gen_test def test_reuse_request_from_response(self): # The response.request attribute should be an HTTPRequest, not @@ -454,7 +468,7 @@ Transfer-Encoding: chunked # Twisted's reactor does not. The removeReader call fails and so # do all future removeAll calls (which our tests do at cleanup). 
# - #def test_post_307(self): + # def test_post_307(self): # response = self.fetch("/redirect?status=307&url=/post", # method="POST", body=b"arg1=foo&arg2=bar") # self.assertEqual(response.body, b"Post arg1: foo, arg2: bar") @@ -536,14 +550,19 @@ class SyncHTTPClientTest(unittest.TestCase): def tearDown(self): def stop_server(): self.server.stop() - self.server_ioloop.stop() + # Delay the shutdown of the IOLoop by one iteration because + # the server may still have some cleanup work left when + # the client finishes with the response (this is noticeable + # with http/2, which leaves a Future with an unexamined + # StreamClosedError on the loop). + self.server_ioloop.add_callback(self.server_ioloop.stop) self.server_ioloop.add_callback(stop_server) self.server_thread.join() self.http_client.close() self.server_ioloop.close(all_fds=True) def get_url(self, path): - return 'http://localhost:%d%s' % (self.port, path) + return 'http://127.0.0.1:%d%s' % (self.port, path) def test_sync_client(self): response = self.http_client.fetch(self.get_url('/')) @@ -584,5 +603,5 @@ class HTTPRequestTestCase(unittest.TestCase): def test_if_modified_since(self): http_date = datetime.datetime.utcnow() request = HTTPRequest('http://example.com', if_modified_since=http_date) - self.assertEqual(request.headers, - {'If-Modified-Since': format_timestamp(http_date)}) + self.assertEqual(request.headers, + {'If-Modified-Since': format_timestamp(http_date)}) diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index 156a027b..f05599dd 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -32,6 +32,7 @@ def read_stream_body(stream, callback): """Reads an HTTP response from `stream` and runs callback with its headers and body.""" chunks = [] + class Delegate(HTTPMessageDelegate): def headers_received(self, start_line, headers): self.headers = headers @@ -161,19 +162,22 @@ class BadSSLOptionsTest(unittest.TestCase): application = 
Application() module_dir = os.path.dirname(__file__) existing_certificate = os.path.join(module_dir, 'test.crt') + existing_key = os.path.join(module_dir, 'test.key') - self.assertRaises(ValueError, HTTPServer, application, ssl_options={ - "certfile": "/__mising__.crt", + self.assertRaises((ValueError, IOError), + HTTPServer, application, ssl_options={ + "certfile": "/__mising__.crt", }) - self.assertRaises(ValueError, HTTPServer, application, ssl_options={ - "certfile": existing_certificate, - "keyfile": "/__missing__.key" + self.assertRaises((ValueError, IOError), + HTTPServer, application, ssl_options={ + "certfile": existing_certificate, + "keyfile": "/__missing__.key" }) # This actually works because both files exist HTTPServer(application, ssl_options={ "certfile": existing_certificate, - "keyfile": existing_certificate + "keyfile": existing_key, }) @@ -195,14 +199,14 @@ class HTTPConnectionTest(AsyncHTTPTestCase): def get_app(self): return Application(self.get_handlers()) - def raw_fetch(self, headers, body): + def raw_fetch(self, headers, body, newline=b"\r\n"): with closing(IOStream(socket.socket())) as stream: stream.connect(('127.0.0.1', self.get_http_port()), self.stop) self.wait() stream.write( - b"\r\n".join(headers + - [utf8("Content-Length: %d\r\n" % len(body))]) + - b"\r\n" + body) + newline.join(headers + + [utf8("Content-Length: %d" % len(body))]) + + newline + newline + body) read_stream_body(stream, self.stop) headers, body = self.wait() return body @@ -232,12 +236,19 @@ class HTTPConnectionTest(AsyncHTTPTestCase): self.assertEqual(u("\u00f3"), data["filename"]) self.assertEqual(u("\u00fa"), data["filebody"]) + def test_newlines(self): + # We support both CRLF and bare LF as line separators. 
+ for newline in (b"\r\n", b"\n"): + response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"", + newline=newline) + self.assertEqual(response, b'Hello world') + def test_100_continue(self): # Run through a 100-continue interaction by hand: # When given Expect: 100-continue, we get a 100 response after the # headers, and then the real response after the body. stream = IOStream(socket.socket(), io_loop=self.io_loop) - stream.connect(("localhost", self.get_http_port()), callback=self.stop) + stream.connect(("127.0.0.1", self.get_http_port()), callback=self.stop) self.wait() stream.write(b"\r\n".join([b"POST /hello HTTP/1.1", b"Content-Length: 1024", @@ -374,7 +385,7 @@ class HTTPServerRawTest(AsyncHTTPTestCase): def setUp(self): super(HTTPServerRawTest, self).setUp() self.stream = IOStream(socket.socket()) - self.stream.connect(('localhost', self.get_http_port()), self.stop) + self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop) self.wait() def tearDown(self): @@ -555,7 +566,7 @@ class UnixSocketTest(AsyncTestCase): self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n") self.stream.read_until(b"\r\n", self.stop) response = self.wait() - self.assertEqual(response, b"HTTP/1.0 200 OK\r\n") + self.assertEqual(response, b"HTTP/1.1 200 OK\r\n") self.stream.read_until(b"\r\n\r\n", self.stop) headers = HTTPHeaders.parse(self.wait().decode('latin1')) self.stream.read_bytes(int(headers["Content-Length"]), self.stop) @@ -582,6 +593,7 @@ class KeepAliveTest(AsyncHTTPTestCase): class HelloHandler(RequestHandler): def get(self): self.finish('Hello world') + def post(self): self.finish('Hello world') @@ -623,13 +635,13 @@ class KeepAliveTest(AsyncHTTPTestCase): # The next few methods are a crude manual http client def connect(self): self.stream = IOStream(socket.socket(), io_loop=self.io_loop) - self.stream.connect(('localhost', self.get_http_port()), self.stop) + self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop) self.wait() def read_headers(self): 
self.stream.read_until(b'\r\n', self.stop) first_line = self.wait() - self.assertTrue(first_line.startswith(self.http_version + b' 200'), first_line) + self.assertTrue(first_line.startswith(b'HTTP/1.1 200'), first_line) self.stream.read_until(b'\r\n\r\n', self.stop) header_bytes = self.wait() headers = HTTPHeaders.parse(header_bytes.decode('latin1')) @@ -808,8 +820,8 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase): def get_app(self): class App(HTTPServerConnectionDelegate): - def start_request(self, connection): - return StreamingChunkSizeTest.MessageDelegate(connection) + def start_request(self, server_conn, request_conn): + return StreamingChunkSizeTest.MessageDelegate(request_conn) return App() def fetch_chunk_sizes(self, **kwargs): @@ -856,6 +868,7 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase): def test_chunked_compressed(self): compressed = self.compress(self.BODY) self.assertGreater(len(compressed), 20) + def body_producer(write): write(compressed[:20]) write(compressed[20:]) @@ -900,7 +913,7 @@ class IdleTimeoutTest(AsyncHTTPTestCase): def connect(self): stream = IOStream(socket.socket()) - stream.connect(('localhost', self.get_http_port()), self.stop) + stream.connect(('127.0.0.1', self.get_http_port()), self.stop) self.wait() self.streams.append(stream) return stream @@ -1045,6 +1058,15 @@ class LegacyInterfaceTest(AsyncHTTPTestCase): # delegate interface, and writes its response via request.write # instead of request.connection.write_headers. def handle_request(request): + self.http1 = request.version.startswith("HTTP/1.") + if not self.http1: + # This test will be skipped if we're using HTTP/2, + # so just close it out cleanly using the modern interface. 
+ request.connection.write_headers( + ResponseStartLine('', 200, 'OK'), + HTTPHeaders()) + request.connection.finish() + return message = b"Hello world" request.write(utf8("HTTP/1.1 200 OK\r\n" "Content-Length: %d\r\n\r\n" % len(message))) @@ -1054,4 +1076,6 @@ class LegacyInterfaceTest(AsyncHTTPTestCase): def test_legacy_interface(self): response = self.fetch('/') + if not self.http1: + self.skipTest("requires HTTP/1.x") self.assertEqual(response.body, b"Hello world") diff --git a/tornado/test/httputil_test.py b/tornado/test/httputil_test.py index 5ca5cf9f..6e953601 100644 --- a/tornado/test/httputil_test.py +++ b/tornado/test/httputil_test.py @@ -3,11 +3,13 @@ from __future__ import absolute_import, division, print_function, with_statement from tornado.httputil import url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp, HTTPServerRequest, parse_request_start_line -from tornado.escape import utf8 +from tornado.escape import utf8, native_str from tornado.log import gen_log from tornado.testing import ExpectLog from tornado.test.util import unittest +from tornado.util import u +import copy import datetime import logging import time @@ -228,6 +230,75 @@ Foo: even ("Foo", "bar baz"), ("Foo", "even more lines")]) + def test_unicode_newlines(self): + # Ensure that only \r\n is recognized as a header separator, and not + # the other newline-like unicode characters. + # Characters that are likely to be problematic can be found in + # http://unicode.org/standard/reports/tr13/tr13-5.html + # and cpython's unicodeobject.c (which defines the implementation + # of unicode_type.splitlines(), and uses a different list than TR13). 
+ newlines = [ + u('\u001b'), # ESCAPE (VERTICAL TAB is U+000B) + u('\u001c'), # FILE SEPARATOR + u('\u001d'), # GROUP SEPARATOR + u('\u001e'), # RECORD SEPARATOR + u('\u0085'), # NEXT LINE + u('\u2028'), # LINE SEPARATOR + u('\u2029'), # PARAGRAPH SEPARATOR + ] + for newline in newlines: + # Try the utf8 and latin1 representations of each newline + for encoding in ['utf8', 'latin1']: + try: + try: + encoded = newline.encode(encoding) + except UnicodeEncodeError: + # Some chars cannot be represented in latin1 + continue + data = b'Cookie: foo=' + encoded + b'bar' + # parse() wants a native_str, so decode through latin1 + # in the same way the real parser does. + headers = HTTPHeaders.parse( + native_str(data.decode('latin1'))) + expected = [('Cookie', 'foo=' + + native_str(encoded.decode('latin1')) + 'bar')] + self.assertEqual( + expected, list(headers.get_all())) + except Exception: + gen_log.warning("failed while trying %r in %s", + newline, encoding) + raise + + def test_optional_cr(self): + # Both CRLF and LF should be accepted as separators. CR should not be + # part of the data when followed by LF, but it is a normal char + # otherwise (or should bare CR be an error?) + headers = HTTPHeaders.parse( + 'CRLF: crlf\r\nLF: lf\nCR: cr\rMore: more\r\n') + self.assertEqual(sorted(headers.get_all()), + [('Cr', 'cr\rMore: more'), + ('Crlf', 'crlf'), + ('Lf', 'lf'), + ]) + + def test_copy(self): + all_pairs = [('A', '1'), ('A', '2'), ('B', 'c')] + h1 = HTTPHeaders() + for k, v in all_pairs: + h1.add(k, v) + h2 = h1.copy() + h3 = copy.copy(h1) + h4 = copy.deepcopy(h1) + for headers in [h1, h2, h3, h4]: + # All the copies are identical, no matter how they were + # constructed. + self.assertEqual(list(sorted(headers.get_all())), all_pairs) + for headers in [h2, h3, h4]: + # Neither the dict nor its member lists are reused. 
+ self.assertIsNot(headers, h1) + self.assertIsNot(headers.get_list('A'), h1.get_list('A')) + + class FormatTimestampTest(unittest.TestCase): # Make sure that all the input types are supported. @@ -264,6 +335,10 @@ class HTTPServerRequestTest(unittest.TestCase): # more required parameters slip in. HTTPServerRequest(uri='/') + def test_body_is_a_byte_string(self): + requets = HTTPServerRequest(uri='/') + self.assertIsInstance(requets.body, bytes) + class ParseRequestStartLineTest(unittest.TestCase): METHOD = "GET" diff --git a/tornado/test/import_test.py b/tornado/test/import_test.py index de7cc0b9..1be6427f 100644 --- a/tornado/test/import_test.py +++ b/tornado/test/import_test.py @@ -1,3 +1,4 @@ +# flake8: noqa from __future__ import absolute_import, division, print_function, with_statement from tornado.test.util import unittest diff --git a/tornado/test/ioloop_test.py b/tornado/test/ioloop_test.py index 7eb7594f..f3a0cbdc 100644 --- a/tornado/test/ioloop_test.py +++ b/tornado/test/ioloop_test.py @@ -11,8 +11,9 @@ import threading import time from tornado import gen -from tornado.ioloop import IOLoop, TimeoutError +from tornado.ioloop import IOLoop, TimeoutError, PollIOLoop, PeriodicCallback from tornado.log import app_log +from tornado.platform.select import _Select from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog from tornado.test.util import unittest, skipIfNonUnix, skipOnTravis @@ -23,6 +24,42 @@ except ImportError: futures = None +class FakeTimeSelect(_Select): + def __init__(self): + self._time = 1000 + super(FakeTimeSelect, self).__init__() + + def time(self): + return self._time + + def sleep(self, t): + self._time += t + + def poll(self, timeout): + events = super(FakeTimeSelect, self).poll(0) + if events: + return events + self._time += timeout + return [] + + +class FakeTimeIOLoop(PollIOLoop): + """IOLoop implementation with a fake and 
deterministic clock. + + The clock advances as needed to trigger timeouts immediately. + For use when testing code that involves the passage of time + and no external dependencies. + """ + def initialize(self): + self.fts = FakeTimeSelect() + super(FakeTimeIOLoop, self).initialize(impl=self.fts, + time_func=self.fts.time) + + def sleep(self, t): + """Simulate a blocking sleep by advancing the clock.""" + self.fts.sleep(t) + + class TestIOLoop(AsyncTestCase): @skipOnTravis def test_add_callback_wakeup(self): @@ -180,10 +217,12 @@ class TestIOLoop(AsyncTestCase): # t2 should be cancelled by t1, even though it is already scheduled to # be run before the ioloop even looks at it. now = self.io_loop.time() + def t1(): calls[0] = True self.io_loop.remove_timeout(t2_handle) self.io_loop.add_timeout(now + 0.01, t1) + def t2(): calls[1] = True t2_handle = self.io_loop.add_timeout(now + 0.02, t2) @@ -252,6 +291,7 @@ class TestIOLoop(AsyncTestCase): """The handler callback receives the same fd object it passed in.""" server_sock, port = bind_unused_port() fds = [] + def handle_connection(fd, events): fds.append(fd) conn, addr = server_sock.accept() @@ -274,6 +314,7 @@ class TestIOLoop(AsyncTestCase): def test_mixed_fd_fileobj(self): server_sock, port = bind_unused_port() + def f(fd, events): pass self.io_loop.add_handler(server_sock, f, IOLoop.READ) @@ -288,6 +329,7 @@ class TestIOLoop(AsyncTestCase): """Calling start() twice should raise an error, not deadlock.""" returned_from_start = [False] got_exception = [False] + def callback(): try: self.io_loop.start() @@ -305,7 +347,7 @@ class TestIOLoop(AsyncTestCase): # Use a NullContext to keep the exception from being caught by # AsyncTestCase. 
with NullContext(): - self.io_loop.add_callback(lambda: 1/0) + self.io_loop.add_callback(lambda: 1 / 0) self.io_loop.add_callback(self.stop) with ExpectLog(app_log, "Exception in callback"): self.wait() @@ -316,7 +358,7 @@ class TestIOLoop(AsyncTestCase): @gen.coroutine def callback(): self.io_loop.add_callback(self.stop) - 1/0 + 1 / 0 self.io_loop.add_callback(callback) with ExpectLog(app_log, "Exception in callback"): self.wait() @@ -324,12 +366,12 @@ class TestIOLoop(AsyncTestCase): def test_spawn_callback(self): # An added callback runs in the test's stack_context, so will be # re-arised in wait(). - self.io_loop.add_callback(lambda: 1/0) + self.io_loop.add_callback(lambda: 1 / 0) with self.assertRaises(ZeroDivisionError): self.wait() # A spawned callback is run directly on the IOLoop, so it will be # logged without stopping the test. - self.io_loop.spawn_callback(lambda: 1/0) + self.io_loop.spawn_callback(lambda: 1 / 0) self.io_loop.add_callback(self.stop) with ExpectLog(app_log, "Exception in callback"): self.wait() @@ -344,6 +386,7 @@ class TestIOLoop(AsyncTestCase): # After reading from one fd, remove the other from the IOLoop. chunks = [] + def handle_read(fd, events): chunks.append(fd.recv(1024)) if fd is client: @@ -352,7 +395,7 @@ class TestIOLoop(AsyncTestCase): self.io_loop.remove_handler(client) self.io_loop.add_handler(client, handle_read, self.io_loop.READ) self.io_loop.add_handler(server, handle_read, self.io_loop.READ) - self.io_loop.call_later(0.01, self.stop) + self.io_loop.call_later(0.03, self.stop) self.wait() # Only one fd was read; the other was cleanly removed. 
@@ -520,5 +563,47 @@ class TestIOLoopRunSync(unittest.TestCase): self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01) +class TestPeriodicCallback(unittest.TestCase): + def setUp(self): + self.io_loop = FakeTimeIOLoop() + self.io_loop.make_current() + + def tearDown(self): + self.io_loop.close() + + def test_basic(self): + calls = [] + + def cb(): + calls.append(self.io_loop.time()) + pc = PeriodicCallback(cb, 10000) + pc.start() + self.io_loop.call_later(50, self.io_loop.stop) + self.io_loop.start() + self.assertEqual(calls, [1010, 1020, 1030, 1040, 1050]) + + def test_overrun(self): + sleep_durations = [9, 9, 10, 11, 20, 20, 35, 35, 0, 0] + expected = [ + 1010, 1020, 1030, # first 3 calls on schedule + 1050, 1070, # next 2 delayed one cycle + 1100, 1130, # next 2 delayed 2 cycles + 1170, 1210, # next 2 delayed 3 cycles + 1220, 1230, # then back on schedule. + ] + calls = [] + + def cb(): + calls.append(self.io_loop.time()) + if not sleep_durations: + self.io_loop.stop() + return + self.io_loop.sleep(sleep_durations.pop(0)) + pc = PeriodicCallback(cb, 10000) + pc.start() + self.io_loop.start() + self.assertEqual(calls, expected) + + if __name__ == "__main__": unittest.main() diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py index f51caeaf..45df6b50 100644 --- a/tornado/test/iostream_test.py +++ b/tornado/test/iostream_test.py @@ -7,10 +7,10 @@ from tornado.httputil import HTTPHeaders from tornado.log import gen_log, app_log from tornado.netutil import ssl_wrap_socket from tornado.stack_context import NullContext +from tornado.tcpserver import TCPServer from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test -from tornado.test.util import unittest, skipIfNonUnix +from tornado.test.util import unittest, skipIfNonUnix, refusing_port from tornado.web import RequestHandler, Application -import lib.certifi import errno import logging import os @@ -51,18 +51,18 @@ 
class TestIOStreamWebMixin(object): def test_read_until_close(self): stream = self._make_client_iostream() - stream.connect(('localhost', self.get_http_port()), callback=self.stop) + stream.connect(('127.0.0.1', self.get_http_port()), callback=self.stop) self.wait() stream.write(b"GET / HTTP/1.0\r\n\r\n") stream.read_until_close(self.stop) data = self.wait() - self.assertTrue(data.startswith(b"HTTP/1.0 200")) + self.assertTrue(data.startswith(b"HTTP/1.1 200")) self.assertTrue(data.endswith(b"Hello")) def test_read_zero_bytes(self): self.stream = self._make_client_iostream() - self.stream.connect(("localhost", self.get_http_port()), + self.stream.connect(("127.0.0.1", self.get_http_port()), callback=self.stop) self.wait() self.stream.write(b"GET / HTTP/1.0\r\n\r\n") @@ -70,7 +70,7 @@ class TestIOStreamWebMixin(object): # normal read self.stream.read_bytes(9, self.stop) data = self.wait() - self.assertEqual(data, b"HTTP/1.0 ") + self.assertEqual(data, b"HTTP/1.1 ") # zero bytes self.stream.read_bytes(0, self.stop) @@ -91,7 +91,7 @@ class TestIOStreamWebMixin(object): def connected_callback(): connected[0] = True self.stop() - stream.connect(("localhost", self.get_http_port()), + stream.connect(("127.0.0.1", self.get_http_port()), callback=connected_callback) # unlike the previous tests, try to write before the connection # is complete. @@ -121,11 +121,11 @@ class TestIOStreamWebMixin(object): """Basic test of IOStream's ability to return Futures.""" stream = self._make_client_iostream() connect_result = yield stream.connect( - ("localhost", self.get_http_port())) + ("127.0.0.1", self.get_http_port())) self.assertIs(connect_result, stream) yield stream.write(b"GET / HTTP/1.0\r\n\r\n") first_line = yield stream.read_until(b"\r\n") - self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n") + self.assertEqual(first_line, b"HTTP/1.1 200 OK\r\n") # callback=None is equivalent to no callback. 
header_data = yield stream.read_until(b"\r\n\r\n", callback=None) headers = HTTPHeaders.parse(header_data.decode('latin1')) @@ -137,7 +137,7 @@ class TestIOStreamWebMixin(object): @gen_test def test_future_close_while_reading(self): stream = self._make_client_iostream() - yield stream.connect(("localhost", self.get_http_port())) + yield stream.connect(("127.0.0.1", self.get_http_port())) yield stream.write(b"GET / HTTP/1.0\r\n\r\n") with self.assertRaises(StreamClosedError): yield stream.read_bytes(1024 * 1024) @@ -147,7 +147,7 @@ class TestIOStreamWebMixin(object): def test_future_read_until_close(self): # Ensure that the data comes through before the StreamClosedError. stream = self._make_client_iostream() - yield stream.connect(("localhost", self.get_http_port())) + yield stream.connect(("127.0.0.1", self.get_http_port())) yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n") yield stream.read_until(b"\r\n\r\n") body = yield stream.read_until_close() @@ -217,17 +217,18 @@ class TestIOStreamMixin(object): # When a connection is refused, the connect callback should not # be run. 
(The kqueue IOLoop used to behave differently from the # epoll IOLoop in this respect) - server_socket, port = bind_unused_port() - server_socket.close() + cleanup_func, port = refusing_port() + self.addCleanup(cleanup_func) stream = IOStream(socket.socket(), self.io_loop) self.connect_called = False def connect_callback(): self.connect_called = True + self.stop() stream.set_close_callback(self.stop) # log messages vary by platform and ioloop implementation with ExpectLog(gen_log, ".*", required=False): - stream.connect(("localhost", port), connect_callback) + stream.connect(("127.0.0.1", port), connect_callback) self.wait() self.assertFalse(self.connect_called) self.assertTrue(isinstance(stream.error, socket.error), stream.error) @@ -248,7 +249,8 @@ class TestIOStreamMixin(object): # opendns and some ISPs return bogus addresses for nonexistent # domains instead of the proper error codes). with ExpectLog(gen_log, "Connect error"): - stream.connect(('an invalid domain', 54321)) + stream.connect(('an invalid domain', 54321), callback=self.stop) + self.wait() self.assertTrue(isinstance(stream.error, socket.gaierror), stream.error) def test_read_callback_error(self): @@ -308,6 +310,7 @@ class TestIOStreamMixin(object): def streaming_callback(data): chunks.append(data) self.stop() + def close_callback(data): assert not data, data closed[0] = True @@ -325,6 +328,31 @@ class TestIOStreamMixin(object): server.close() client.close() + def test_streaming_until_close_future(self): + server, client = self.make_iostream_pair() + try: + chunks = [] + + @gen.coroutine + def client_task(): + yield client.read_until_close(streaming_callback=chunks.append) + + @gen.coroutine + def server_task(): + yield server.write(b"1234") + yield gen.sleep(0.01) + yield server.write(b"5678") + server.close() + + @gen.coroutine + def f(): + yield [client_task(), server_task()] + self.io_loop.run_sync(f) + self.assertEqual(chunks, [b"1234", b"5678"]) + finally: + server.close() + client.close() + 
def test_delayed_close_callback(self): # The scenario: Server closes the connection while there is a pending # read that can be served out of buffered data. The client does not @@ -353,6 +381,7 @@ class TestIOStreamMixin(object): def test_future_delayed_close_callback(self): # Same as test_delayed_close_callback, but with the future interface. server, client = self.make_iostream_pair() + # We can't call make_iostream_pair inside a gen_test function # because the ioloop is not reentrant. @gen_test @@ -532,6 +561,7 @@ class TestIOStreamMixin(object): # and IOStream._maybe_add_error_listener. server, client = self.make_iostream_pair() closed = [False] + def close_callback(): closed[0] = True self.stop() @@ -724,6 +754,26 @@ class TestIOStreamMixin(object): server.close() client.close() + def test_flow_control(self): + MB = 1024 * 1024 + server, client = self.make_iostream_pair(max_buffer_size=5 * MB) + try: + # Client writes more than the server will accept. + client.write(b"a" * 10 * MB) + # The server pauses while reading. + server.read_bytes(MB, self.stop) + self.wait() + self.io_loop.call_later(0.1, self.stop) + self.wait() + # The client's writes have been blocked; the server can + # continue to read gradually. 
+ for i in range(9): + server.read_bytes(MB, self.stop) + self.wait() + finally: + server.close() + client.close() + class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase): def _make_client_iostream(self): @@ -732,7 +782,8 @@ class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase): class TestIOStreamWebHTTPS(TestIOStreamWebMixin, AsyncHTTPSTestCase): def _make_client_iostream(self): - return SSLIOStream(socket.socket(), io_loop=self.io_loop) + return SSLIOStream(socket.socket(), io_loop=self.io_loop, + ssl_options=dict(cert_reqs=ssl.CERT_NONE)) class TestIOStream(TestIOStreamMixin, AsyncTestCase): @@ -752,7 +803,9 @@ class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase): return SSLIOStream(connection, io_loop=self.io_loop, **kwargs) def _make_client_iostream(self, connection, **kwargs): - return SSLIOStream(connection, io_loop=self.io_loop, **kwargs) + return SSLIOStream(connection, io_loop=self.io_loop, + ssl_options=dict(cert_reqs=ssl.CERT_NONE), + **kwargs) # This will run some tests that are basically redundant but it's the @@ -820,10 +873,10 @@ class TestIOStreamStartTLS(AsyncTestCase): recv_line = yield self.client_stream.read_until(b"\r\n") self.assertEqual(line, recv_line) - def client_start_tls(self, ssl_options=None): + def client_start_tls(self, ssl_options=None, server_hostname=None): client_stream = self.client_stream self.client_stream = None - return client_stream.start_tls(False, ssl_options) + return client_stream.start_tls(False, ssl_options, server_hostname) def server_start_tls(self, ssl_options=None): server_stream = self.server_stream @@ -842,7 +895,7 @@ class TestIOStreamStartTLS(AsyncTestCase): yield self.server_send_line(b"250 STARTTLS\r\n") yield self.client_send_line(b"STARTTLS\r\n") yield self.server_send_line(b"220 Go ahead\r\n") - client_future = self.client_start_tls() + client_future = self.client_start_tls(dict(cert_reqs=ssl.CERT_NONE)) server_future = self.server_start_tls(_server_ssl_options()) 
self.client_stream = yield client_future self.server_stream = yield server_future @@ -853,12 +906,123 @@ class TestIOStreamStartTLS(AsyncTestCase): @gen_test def test_handshake_fail(self): - self.server_start_tls(_server_ssl_options()) - client_future = self.client_start_tls( - dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where())) + server_future = self.server_start_tls(_server_ssl_options()) + # Certificates are verified with the default configuration. + client_future = self.client_start_tls(server_hostname="localhost") with ExpectLog(gen_log, "SSL Error"): with self.assertRaises(ssl.SSLError): yield client_future + with self.assertRaises((ssl.SSLError, socket.error)): + yield server_future + + @unittest.skipIf(not hasattr(ssl, 'create_default_context'), + 'ssl.create_default_context not present') + @gen_test + def test_check_hostname(self): + # Test that server_hostname parameter to start_tls is being used. + # The check_hostname functionality is only available in python 2.7 and + # up and in python 3.4 and up. 
+ server_future = self.server_start_tls(_server_ssl_options()) + client_future = self.client_start_tls( + ssl.create_default_context(), + server_hostname=b'127.0.0.1') + with ExpectLog(gen_log, "SSL Error"): + with self.assertRaises(ssl.SSLError): + yield client_future + with self.assertRaises((ssl.SSLError, socket.error)): + yield server_future + + +class WaitForHandshakeTest(AsyncTestCase): + @gen.coroutine + def connect_to_server(self, server_cls): + server = client = None + try: + sock, port = bind_unused_port() + server = server_cls(ssl_options=_server_ssl_options()) + server.add_socket(sock) + + client = SSLIOStream(socket.socket(), + ssl_options=dict(cert_reqs=ssl.CERT_NONE)) + yield client.connect(('127.0.0.1', port)) + self.assertIsNotNone(client.socket.cipher()) + finally: + if server is not None: + server.stop() + if client is not None: + client.close() + + @gen_test + def test_wait_for_handshake_callback(self): + test = self + handshake_future = Future() + + class TestServer(TCPServer): + def handle_stream(self, stream, address): + # The handshake has not yet completed. + test.assertIsNone(stream.socket.cipher()) + self.stream = stream + stream.wait_for_handshake(self.handshake_done) + + def handshake_done(self): + # Now the handshake is done and ssl information is available. 
+ test.assertIsNotNone(self.stream.socket.cipher()) + handshake_future.set_result(None) + + yield self.connect_to_server(TestServer) + yield handshake_future + + @gen_test + def test_wait_for_handshake_future(self): + test = self + handshake_future = Future() + + class TestServer(TCPServer): + def handle_stream(self, stream, address): + test.assertIsNone(stream.socket.cipher()) + test.io_loop.spawn_callback(self.handle_connection, stream) + + @gen.coroutine + def handle_connection(self, stream): + yield stream.wait_for_handshake() + handshake_future.set_result(None) + + yield self.connect_to_server(TestServer) + yield handshake_future + + @gen_test + def test_wait_for_handshake_already_waiting_error(self): + test = self + handshake_future = Future() + + class TestServer(TCPServer): + def handle_stream(self, stream, address): + stream.wait_for_handshake(self.handshake_done) + test.assertRaises(RuntimeError, stream.wait_for_handshake) + + def handshake_done(self): + handshake_future.set_result(None) + + yield self.connect_to_server(TestServer) + yield handshake_future + + @gen_test + def test_wait_for_handshake_already_connected(self): + handshake_future = Future() + + class TestServer(TCPServer): + def handle_stream(self, stream, address): + self.stream = stream + stream.wait_for_handshake(self.handshake_done) + + def handshake_done(self): + self.stream.wait_for_handshake(self.handshake2_done) + + def handshake2_done(self): + handshake_future.set_result(None) + + yield self.connect_to_server(TestServer) + yield handshake_future @skipIfNonUnix diff --git a/tornado/test/locale_test.py b/tornado/test/locale_test.py index d12ad52f..31c57a61 100644 --- a/tornado/test/locale_test.py +++ b/tornado/test/locale_test.py @@ -41,6 +41,12 @@ class TranslationLoaderTest(unittest.TestCase): locale = tornado.locale.get("fr_FR") self.assertTrue(isinstance(locale, tornado.locale.GettextLocale)) self.assertEqual(locale.translate("school"), u("\u00e9cole")) + 
self.assertEqual(locale.pgettext("law", "right"), u("le droit")) + self.assertEqual(locale.pgettext("good", "right"), u("le bien")) + self.assertEqual(locale.pgettext("organization", "club", "clubs", 1), u("le club")) + self.assertEqual(locale.pgettext("organization", "club", "clubs", 2), u("les clubs")) + self.assertEqual(locale.pgettext("stick", "club", "clubs", 1), u("le b\xe2ton")) + self.assertEqual(locale.pgettext("stick", "club", "clubs", 2), u("les b\xe2tons")) class LocaleDataTest(unittest.TestCase): @@ -58,7 +64,7 @@ class EnglishTest(unittest.TestCase): self.assertEqual(locale.format_date(date, full_format=True), 'April 28, 2013 at 6:35 pm') - self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(seconds=2), full_format=False), + self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(seconds=2), full_format=False), '2 seconds ago') self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(minutes=2), full_format=False), '2 minutes ago') diff --git a/tornado/test/locks_test.py b/tornado/test/locks_test.py new file mode 100644 index 00000000..90bdafaa --- /dev/null +++ b/tornado/test/locks_test.py @@ -0,0 +1,480 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +from datetime import timedelta + +from tornado import gen, locks +from tornado.gen import TimeoutError +from tornado.testing import gen_test, AsyncTestCase +from tornado.test.util import unittest + + +class ConditionTest(AsyncTestCase): + def setUp(self): + super(ConditionTest, self).setUp() + self.history = [] + + def record_done(self, future, key): + """Record the resolution of a Future returned by Condition.wait.""" + def callback(_): + if not future.result(): + # wait() resolved to False, meaning it timed out. + self.history.append('timeout') + else: + self.history.append(key) + future.add_done_callback(callback) + + def test_repr(self): + c = locks.Condition() + self.assertIn('Condition', repr(c)) + self.assertNotIn('waiters', repr(c)) + c.wait() + self.assertIn('waiters', repr(c)) + + @gen_test + def test_notify(self): + c = locks.Condition() + self.io_loop.call_later(0.01, c.notify) + yield c.wait() + + def test_notify_1(self): + c = locks.Condition() + self.record_done(c.wait(), 'wait1') + self.record_done(c.wait(), 'wait2') + c.notify(1) + self.history.append('notify1') + c.notify(1) + self.history.append('notify2') + self.assertEqual(['wait1', 'notify1', 'wait2', 'notify2'], + self.history) + + def test_notify_n(self): + c = locks.Condition() + for i in range(6): + self.record_done(c.wait(), i) + + c.notify(3) + + # Callbacks execute in the order they were registered. + self.assertEqual(list(range(3)), self.history) + c.notify(1) + self.assertEqual(list(range(4)), self.history) + c.notify(2) + self.assertEqual(list(range(6)), self.history) + + def test_notify_all(self): + c = locks.Condition() + for i in range(4): + self.record_done(c.wait(), i) + + c.notify_all() + self.history.append('notify_all') + + # Callbacks execute in the order they were registered. 
+ self.assertEqual( + list(range(4)) + ['notify_all'], + self.history) + + @gen_test + def test_wait_timeout(self): + c = locks.Condition() + wait = c.wait(timedelta(seconds=0.01)) + self.io_loop.call_later(0.02, c.notify) # Too late. + yield gen.sleep(0.03) + self.assertFalse((yield wait)) + + @gen_test + def test_wait_timeout_preempted(self): + c = locks.Condition() + + # This fires before the wait times out. + self.io_loop.call_later(0.01, c.notify) + wait = c.wait(timedelta(seconds=0.02)) + yield gen.sleep(0.03) + yield wait # No TimeoutError. + + @gen_test + def test_notify_n_with_timeout(self): + # Register callbacks 0, 1, 2, and 3. Callback 1 has a timeout. + # Wait for that timeout to expire, then do notify(2) and make + # sure everyone runs. Verifies that a timed-out callback does + # not count against the 'n' argument to notify(). + c = locks.Condition() + self.record_done(c.wait(), 0) + self.record_done(c.wait(timedelta(seconds=0.01)), 1) + self.record_done(c.wait(), 2) + self.record_done(c.wait(), 3) + + # Wait for callback 1 to time out. + yield gen.sleep(0.02) + self.assertEqual(['timeout'], self.history) + + c.notify(2) + yield gen.sleep(0.01) + self.assertEqual(['timeout', 0, 2], self.history) + self.assertEqual(['timeout', 0, 2], self.history) + c.notify() + self.assertEqual(['timeout', 0, 2, 3], self.history) + + @gen_test + def test_notify_all_with_timeout(self): + c = locks.Condition() + self.record_done(c.wait(), 0) + self.record_done(c.wait(timedelta(seconds=0.01)), 1) + self.record_done(c.wait(), 2) + + # Wait for callback 1 to time out. + yield gen.sleep(0.02) + self.assertEqual(['timeout'], self.history) + + c.notify_all() + self.assertEqual(['timeout', 0, 2], self.history) + + @gen_test + def test_nested_notify(self): + # Ensure no notifications lost, even if notify() is reentered by a + # waiter calling notify(). + c = locks.Condition() + + # Three waiters. + futures = [c.wait() for _ in range(3)] + + # First and second futures resolved. 
Second future reenters notify(), + # resolving third future. + futures[1].add_done_callback(lambda _: c.notify()) + c.notify(2) + self.assertTrue(all(f.done() for f in futures)) + + @gen_test + def test_garbage_collection(self): + # Test that timed-out waiters are occasionally cleaned from the queue. + c = locks.Condition() + for _ in range(101): + c.wait(timedelta(seconds=0.01)) + + future = c.wait() + self.assertEqual(102, len(c._waiters)) + + # Let first 101 waiters time out, triggering a collection. + yield gen.sleep(0.02) + self.assertEqual(1, len(c._waiters)) + + # Final waiter is still active. + self.assertFalse(future.done()) + c.notify() + self.assertTrue(future.done()) + + +class EventTest(AsyncTestCase): + def test_repr(self): + event = locks.Event() + self.assertTrue('clear' in str(event)) + self.assertFalse('set' in str(event)) + event.set() + self.assertFalse('clear' in str(event)) + self.assertTrue('set' in str(event)) + + def test_event(self): + e = locks.Event() + future_0 = e.wait() + e.set() + future_1 = e.wait() + e.clear() + future_2 = e.wait() + + self.assertTrue(future_0.done()) + self.assertTrue(future_1.done()) + self.assertFalse(future_2.done()) + + @gen_test + def test_event_timeout(self): + e = locks.Event() + with self.assertRaises(TimeoutError): + yield e.wait(timedelta(seconds=0.01)) + + # After a timed-out waiter, normal operation works. 
+ self.io_loop.add_timeout(timedelta(seconds=0.01), e.set) + yield e.wait(timedelta(seconds=1)) + + def test_event_set_multiple(self): + e = locks.Event() + e.set() + e.set() + self.assertTrue(e.is_set()) + + def test_event_wait_clear(self): + e = locks.Event() + f0 = e.wait() + e.clear() + f1 = e.wait() + e.set() + self.assertTrue(f0.done()) + self.assertTrue(f1.done()) + + +class SemaphoreTest(AsyncTestCase): + def test_negative_value(self): + self.assertRaises(ValueError, locks.Semaphore, value=-1) + + def test_repr(self): + sem = locks.Semaphore() + self.assertIn('Semaphore', repr(sem)) + self.assertIn('unlocked,value:1', repr(sem)) + sem.acquire() + self.assertIn('locked', repr(sem)) + self.assertNotIn('waiters', repr(sem)) + sem.acquire() + self.assertIn('waiters', repr(sem)) + + def test_acquire(self): + sem = locks.Semaphore() + f0 = sem.acquire() + self.assertTrue(f0.done()) + + # Wait for release(). + f1 = sem.acquire() + self.assertFalse(f1.done()) + f2 = sem.acquire() + sem.release() + self.assertTrue(f1.done()) + self.assertFalse(f2.done()) + sem.release() + self.assertTrue(f2.done()) + + sem.release() + # Now acquire() is instant. + self.assertTrue(sem.acquire().done()) + self.assertEqual(0, len(sem._waiters)) + + @gen_test + def test_acquire_timeout(self): + sem = locks.Semaphore(2) + yield sem.acquire() + yield sem.acquire() + acquire = sem.acquire(timedelta(seconds=0.01)) + self.io_loop.call_later(0.02, sem.release) # Too late. + yield gen.sleep(0.3) + with self.assertRaises(gen.TimeoutError): + yield acquire + + sem.acquire() + f = sem.acquire() + self.assertFalse(f.done()) + sem.release() + self.assertTrue(f.done()) + + @gen_test + def test_acquire_timeout_preempted(self): + sem = locks.Semaphore(1) + yield sem.acquire() + + # This fires before the wait times out. + self.io_loop.call_later(0.01, sem.release) + acquire = sem.acquire(timedelta(seconds=0.02)) + yield gen.sleep(0.03) + yield acquire # No TimeoutError. 
+ + def test_release_unacquired(self): + # Unbounded releases are allowed, and increment the semaphore's value. + sem = locks.Semaphore() + sem.release() + sem.release() + + # Now the counter is 3. We can acquire three times before blocking. + self.assertTrue(sem.acquire().done()) + self.assertTrue(sem.acquire().done()) + self.assertTrue(sem.acquire().done()) + self.assertFalse(sem.acquire().done()) + + @gen_test + def test_garbage_collection(self): + # Test that timed-out waiters are occasionally cleaned from the queue. + sem = locks.Semaphore(value=0) + futures = [sem.acquire(timedelta(seconds=0.01)) for _ in range(101)] + + future = sem.acquire() + self.assertEqual(102, len(sem._waiters)) + + # Let first 101 waiters time out, triggering a collection. + yield gen.sleep(0.02) + self.assertEqual(1, len(sem._waiters)) + + # Final waiter is still active. + self.assertFalse(future.done()) + sem.release() + self.assertTrue(future.done()) + + # Prevent "Future exception was never retrieved" messages. + for future in futures: + self.assertRaises(TimeoutError, future.result) + + +class SemaphoreContextManagerTest(AsyncTestCase): + @gen_test + def test_context_manager(self): + sem = locks.Semaphore() + with (yield sem.acquire()) as yielded: + self.assertTrue(yielded is None) + + # Semaphore was released and can be acquired again. + self.assertTrue(sem.acquire().done()) + + @gen_test + def test_context_manager_exception(self): + sem = locks.Semaphore() + with self.assertRaises(ZeroDivisionError): + with (yield sem.acquire()): + 1 / 0 + + # Semaphore was released and can be acquired again. + self.assertTrue(sem.acquire().done()) + + @gen_test + def test_context_manager_timeout(self): + sem = locks.Semaphore() + with (yield sem.acquire(timedelta(seconds=0.01))): + pass + + # Semaphore was released and can be acquired again. 
+ self.assertTrue(sem.acquire().done()) + + @gen_test + def test_context_manager_timeout_error(self): + sem = locks.Semaphore(value=0) + with self.assertRaises(gen.TimeoutError): + with (yield sem.acquire(timedelta(seconds=0.01))): + pass + + # Counter is still 0. + self.assertFalse(sem.acquire().done()) + + @gen_test + def test_context_manager_contended(self): + sem = locks.Semaphore() + history = [] + + @gen.coroutine + def f(index): + with (yield sem.acquire()): + history.append('acquired %d' % index) + yield gen.sleep(0.01) + history.append('release %d' % index) + + yield [f(i) for i in range(2)] + + expected_history = [] + for i in range(2): + expected_history.extend(['acquired %d' % i, 'release %d' % i]) + + self.assertEqual(expected_history, history) + + @gen_test + def test_yield_sem(self): + # Ensure we catch a "with (yield sem)", which should be + # "with (yield sem.acquire())". + with self.assertRaises(gen.BadYieldError): + with (yield locks.Semaphore()): + pass + + def test_context_manager_misuse(self): + # Ensure we catch a "with sem", which should be + # "with (yield sem.acquire())". + with self.assertRaises(RuntimeError): + with locks.Semaphore(): + pass + + +class BoundedSemaphoreTest(AsyncTestCase): + def test_release_unacquired(self): + sem = locks.BoundedSemaphore() + self.assertRaises(ValueError, sem.release) + # Value is 0. + sem.acquire() + # Block on acquire(). + future = sem.acquire() + self.assertFalse(future.done()) + sem.release() + self.assertTrue(future.done()) + # Value is 1. + sem.release() + self.assertRaises(ValueError, sem.release) + + +class LockTests(AsyncTestCase): + def test_repr(self): + lock = locks.Lock() + # No errors. 
+ repr(lock) + lock.acquire() + repr(lock) + + def test_acquire_release(self): + lock = locks.Lock() + self.assertTrue(lock.acquire().done()) + future = lock.acquire() + self.assertFalse(future.done()) + lock.release() + self.assertTrue(future.done()) + + @gen_test + def test_acquire_fifo(self): + lock = locks.Lock() + self.assertTrue(lock.acquire().done()) + N = 5 + history = [] + + @gen.coroutine + def f(idx): + with (yield lock.acquire()): + history.append(idx) + + futures = [f(i) for i in range(N)] + self.assertFalse(any(future.done() for future in futures)) + lock.release() + yield futures + self.assertEqual(list(range(N)), history) + + @gen_test + def test_acquire_timeout(self): + lock = locks.Lock() + lock.acquire() + with self.assertRaises(gen.TimeoutError): + yield lock.acquire(timeout=timedelta(seconds=0.01)) + + # Still locked. + self.assertFalse(lock.acquire().done()) + + def test_multi_release(self): + lock = locks.Lock() + self.assertRaises(RuntimeError, lock.release) + lock.acquire() + lock.release() + self.assertRaises(RuntimeError, lock.release) + + @gen_test + def test_yield_lock(self): + # Ensure we catch a "with (yield lock)", which should be + # "with (yield lock.acquire())". + with self.assertRaises(gen.BadYieldError): + with (yield locks.Lock()): + pass + + def test_context_manager_misuse(self): + # Ensure we catch a "with lock", which should be + # "with (yield lock.acquire())". 
+ with self.assertRaises(RuntimeError): + with locks.Lock(): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/tornado/test/netutil_test.py b/tornado/test/netutil_test.py index 1df1e320..7d9cad34 100644 --- a/tornado/test/netutil_test.py +++ b/tornado/test/netutil_test.py @@ -67,10 +67,12 @@ class _ResolverErrorTestMixin(object): yield self.resolver.resolve('an invalid domain', 80, socket.AF_UNSPEC) + def _failing_getaddrinfo(*args): """Dummy implementation of getaddrinfo for use in mocks""" raise socket.gaierror("mock: lookup failed") + @skipIfNoNetwork class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin): def setUp(self): diff --git a/tornado/test/process_test.py b/tornado/test/process_test.py index de727607..58cc410b 100644 --- a/tornado/test/process_test.py +++ b/tornado/test/process_test.py @@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop from tornado.log import gen_log from tornado.process import fork_processes, task_id, Subprocess from tornado.simple_httpclient import SimpleAsyncHTTPClient -from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase +from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase, gen_test from tornado.test.util import unittest, skipIfNonUnix from tornado.web import RequestHandler, Application @@ -85,7 +85,7 @@ class ProcessTest(unittest.TestCase): self.assertEqual(id, task_id()) server = HTTPServer(self.get_app()) server.add_sockets([sock]) - IOLoop.instance().start() + IOLoop.current().start() elif id == 2: self.assertEqual(id, task_id()) sock.close() @@ -200,6 +200,16 @@ class SubprocessTest(AsyncTestCase): self.assertEqual(ret, 0) self.assertEqual(subproc.returncode, ret) + @gen_test + def test_sigchild_future(self): + skip_if_twisted() + Subprocess.initialize() + self.addCleanup(Subprocess.uninitialize) + subproc = Subprocess([sys.executable, '-c', 'pass']) + ret = yield subproc.wait_for_exit() + self.assertEqual(ret, 0) + self.assertEqual(subproc.returncode, ret) + def 
test_sigchild_signal(self): skip_if_twisted() Subprocess.initialize(io_loop=self.io_loop) @@ -212,3 +222,22 @@ class SubprocessTest(AsyncTestCase): ret = self.wait() self.assertEqual(subproc.returncode, ret) self.assertEqual(ret, -signal.SIGTERM) + + @gen_test + def test_wait_for_exit_raise(self): + skip_if_twisted() + Subprocess.initialize() + self.addCleanup(Subprocess.uninitialize) + subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)']) + with self.assertRaises(subprocess.CalledProcessError) as cm: + yield subproc.wait_for_exit() + self.assertEqual(cm.exception.returncode, 1) + + @gen_test + def test_wait_for_exit_raise_disabled(self): + skip_if_twisted() + Subprocess.initialize() + self.addCleanup(Subprocess.uninitialize) + subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)']) + ret = yield subproc.wait_for_exit(raise_error=False) + self.assertEqual(ret, 1) diff --git a/tornado/test/queues_test.py b/tornado/test/queues_test.py new file mode 100644 index 00000000..f2ffb646 --- /dev/null +++ b/tornado/test/queues_test.py @@ -0,0 +1,403 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +from datetime import timedelta +from random import random + +from tornado import gen, queues +from tornado.gen import TimeoutError +from tornado.testing import gen_test, AsyncTestCase +from tornado.test.util import unittest + + +class QueueBasicTest(AsyncTestCase): + def test_repr_and_str(self): + q = queues.Queue(maxsize=1) + self.assertIn(hex(id(q)), repr(q)) + self.assertNotIn(hex(id(q)), str(q)) + q.get() + + for q_str in repr(q), str(q): + self.assertTrue(q_str.startswith('= logging.ERROR: + self.error_count += 1 + elif record.levelno >= logging.WARNING: + self.warning_count += 1 + return True + + def main(): # The -W command-line option does not work in a virtualenv with # python 3 (as of virtualenv 1.7), so configure warnings @@ -92,12 +112,21 @@ def main(): # 2.7 and 3.2 warnings.filterwarnings("ignore", category=DeprecationWarning, message="Please use assert.* instead") + # unittest2 0.6 on py26 reports these as PendingDeprecationWarnings + # instead of DeprecationWarnings. + warnings.filterwarnings("ignore", category=PendingDeprecationWarning, + message="Please use assert.* instead") + # Twisted 15.0.0 triggers some warnings on py3 with -bb. + warnings.filterwarnings("ignore", category=BytesWarning, + module=r"twisted\..*") logging.getLogger("tornado.access").setLevel(logging.CRITICAL) define('httpclient', type=str, default=None, callback=lambda s: AsyncHTTPClient.configure( s, defaults=dict(allow_ipv6=False))) + define('httpserver', type=str, default=None, + callback=HTTPServer.configure) define('ioloop', type=str, default=None) define('ioloop_time_monotonic', default=False) define('resolver', type=str, default=None, @@ -121,6 +150,10 @@ def main(): IOLoop.configure(options.ioloop, **kwargs) add_parse_callback(configure_ioloop) + log_counter = LogCounter() + add_parse_callback( + lambda: logging.getLogger().handlers[0].addFilter(log_counter)) + import tornado.testing kwargs = {} if sys.version_info >= (3, 2): @@ -131,7 +164,16 @@ def main(): # detail. 
http://bugs.python.org/issue15626 kwargs['warnings'] = False kwargs['testRunner'] = TornadoTextTestRunner - tornado.testing.main(**kwargs) + try: + tornado.testing.main(**kwargs) + finally: + # The tests should run clean; consider it a failure if they logged + # any warnings or errors. We'd like to ban info logs too, but + # we can't count them cleanly due to interactions with LogTrapTestCase. + if log_counter.warning_count > 0 or log_counter.error_count > 0: + logging.error("logged %d warnings and %d errors", + log_counter.warning_count, log_counter.error_count) + sys.exit(1) if __name__ == '__main__': main() diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py index e3fab57a..c0de22b7 100644 --- a/tornado/test/simple_httpclient_test.py +++ b/tornado/test/simple_httpclient_test.py @@ -8,19 +8,20 @@ import logging import os import re import socket +import ssl import sys from tornado import gen from tornado.httpclient import AsyncHTTPClient -from tornado.httputil import HTTPHeaders +from tornado.httputil import HTTPHeaders, ResponseStartLine from tornado.ioloop import IOLoop from tornado.log import gen_log from tornado.netutil import Resolver, bind_sockets from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler from tornado.test import httpclient_test -from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog -from tornado.test.util import skipOnTravis, skipIfNoIPv6 +from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog +from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port, unittest from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body @@ -97,15 +98,18 @@ class HostEchoHandler(RequestHandler): class NoContentLengthHandler(RequestHandler): - @gen.coroutine + @asynchronous 
def get(self): - # Emulate the old HTTP/1.0 behavior of returning a body with no - # content-length. Tornado handles content-length at the framework - # level so we have to go around it. - stream = self.request.connection.stream - yield stream.write(b"HTTP/1.0 200 OK\r\n\r\n" - b"hello") - stream.close() + if self.request.version.startswith('HTTP/1'): + # Emulate the old HTTP/1.0 behavior of returning a body with no + # content-length. Tornado handles content-length at the framework + # level so we have to go around it. + stream = self.request.connection.detach() + stream.write(b"HTTP/1.0 200 OK\r\n\r\n" + b"hello") + stream.close() + else: + self.finish('HTTP/1 required') class EchoPostHandler(RequestHandler): @@ -191,9 +195,6 @@ class SimpleHTTPClientTestMixin(object): response = self.wait() response.rethrow() - def test_default_certificates_exist(self): - open(_default_ca_certs()).close() - def test_gzip(self): # All the tests in this file should be using gzip, but this test # ensures that it is in fact getting compressed. 
@@ -235,9 +236,16 @@ class SimpleHTTPClientTestMixin(object): @skipOnTravis def test_request_timeout(self): - response = self.fetch('/trigger?wake=false', request_timeout=0.1) + timeout = 0.1 + timeout_min, timeout_max = 0.099, 0.15 + if os.name == 'nt': + timeout = 0.5 + timeout_min, timeout_max = 0.4, 0.6 + + response = self.fetch('/trigger?wake=false', request_timeout=timeout) self.assertEqual(response.code, 599) - self.assertTrue(0.099 < response.request_time < 0.15, response.request_time) + self.assertTrue(timeout_min < response.request_time < timeout_max, + response.request_time) self.assertEqual(str(response.error), "HTTP 599: Timeout") # trigger the hanging request to let it clean up after itself self.triggers.popleft()() @@ -315,10 +323,10 @@ class SimpleHTTPClientTestMixin(object): self.assertTrue(host_re.match(response.body), response.body) def test_connection_refused(self): - server_socket, port = bind_unused_port() - server_socket.close() + cleanup_func, port = refusing_port() + self.addCleanup(cleanup_func) with ExpectLog(gen_log, ".*", required=False): - self.http_client.fetch("http://localhost:%d/" % port, self.stop) + self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop) response = self.wait() self.assertEqual(599, response.code) @@ -352,7 +360,10 @@ class SimpleHTTPClientTestMixin(object): def test_no_content_length(self): response = self.fetch("/no_content_length") - self.assertEquals(b"hello", response.body) + if response.body == b"HTTP/1 required": + self.skipTest("requires HTTP/1.x") + else: + self.assertEquals(b"hello", response.body) def sync_body_producer(self, write): write(b'1234') @@ -425,6 +436,33 @@ class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase): defaults=dict(validate_cert=False), **kwargs) + def test_ssl_options(self): + resp = self.fetch("/hello", ssl_options={}) + self.assertEqual(resp.body, b"Hello world!") + + @unittest.skipIf(not hasattr(ssl, 'SSLContext'), + 'ssl.SSLContext not 
present') + def test_ssl_context(self): + resp = self.fetch("/hello", + ssl_options=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) + self.assertEqual(resp.body, b"Hello world!") + + def test_ssl_options_handshake_fail(self): + with ExpectLog(gen_log, "SSL Error|Uncaught exception", + required=False): + resp = self.fetch( + "/hello", ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED)) + self.assertRaises(ssl.SSLError, resp.rethrow) + + @unittest.skipIf(not hasattr(ssl, 'SSLContext'), + 'ssl.SSLContext not present') + def test_ssl_context_handshake_fail(self): + with ExpectLog(gen_log, "SSL Error|Uncaught exception"): + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_REQUIRED + resp = self.fetch("/hello", ssl_options=ctx) + self.assertRaises(ssl.SSLError, resp.rethrow) + class CreateAsyncHTTPClientTestCase(AsyncTestCase): def setUp(self): @@ -460,6 +498,12 @@ class CreateAsyncHTTPClientTestCase(AsyncTestCase): class HTTP100ContinueTestCase(AsyncHTTPTestCase): def respond_100(self, request): + self.http1 = request.version.startswith('HTTP/1.') + if not self.http1: + request.connection.write_headers(ResponseStartLine('', 200, 'OK'), + HTTPHeaders()) + request.connection.finish() + return self.request = request self.request.connection.stream.write( b"HTTP/1.1 100 CONTINUE\r\n\r\n", @@ -476,11 +520,20 @@ class HTTP100ContinueTestCase(AsyncHTTPTestCase): def test_100_continue(self): res = self.fetch('/') + if not self.http1: + self.skipTest("requires HTTP/1.x") self.assertEqual(res.body, b'A') class HTTP204NoContentTestCase(AsyncHTTPTestCase): def respond_204(self, request): + self.http1 = request.version.startswith('HTTP/1.') + if not self.http1: + # Close the request cleanly in HTTP/2; it will be skipped anyway. + request.connection.write_headers(ResponseStartLine('', 200, 'OK'), + HTTPHeaders()) + request.connection.finish() + return # A 204 response never has a body, even if doesn't have a content-length # (which would otherwise mean read-until-close). 
Tornado always # sends a content-length, so we simulate here a server that sends @@ -488,14 +541,18 @@ class HTTP204NoContentTestCase(AsyncHTTPTestCase): # # Tests of a 204 response with a Content-Length header are included # in SimpleHTTPClientTestMixin. - request.connection.stream.write( + stream = request.connection.detach() + stream.write( b"HTTP/1.1 204 No content\r\n\r\n") + stream.close() def get_app(self): return self.respond_204 def test_204_no_content(self): resp = self.fetch('/') + if not self.http1: + self.skipTest("requires HTTP/1.x") self.assertEqual(resp.code, 204) self.assertEqual(resp.body, b'') @@ -574,3 +631,49 @@ class MaxHeaderSizeTest(AsyncHTTPTestCase): with ExpectLog(gen_log, "Unsatisfiable read"): response = self.fetch('/large') self.assertEqual(response.code, 599) + + +class MaxBodySizeTest(AsyncHTTPTestCase): + def get_app(self): + class SmallBody(RequestHandler): + def get(self): + self.write("a"*1024*64) + + class LargeBody(RequestHandler): + def get(self): + self.write("a"*1024*100) + + return Application([('/small', SmallBody), + ('/large', LargeBody)]) + + def get_http_client(self): + return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_body_size=1024*64) + + def test_small_body(self): + response = self.fetch('/small') + response.rethrow() + self.assertEqual(response.body, b'a'*1024*64) + + def test_large_body(self): + with ExpectLog(gen_log, "Malformed HTTP message from None: Content-Length too long"): + response = self.fetch('/large') + self.assertEqual(response.code, 599) + + +class MaxBufferSizeTest(AsyncHTTPTestCase): + def get_app(self): + + class LargeBody(RequestHandler): + def get(self): + self.write("a"*1024*100) + + return Application([('/large', LargeBody)]) + + def get_http_client(self): + # 100KB body with 64KB buffer + return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_body_size=1024*100, max_buffer_size=1024*64) + + def test_large_body(self): + response = self.fetch('/large') + response.rethrow() + 
self.assertEqual(response.body, b'a'*1024*100) diff --git a/tornado/test/tcpclient_test.py b/tornado/test/tcpclient_test.py index 5df4a7ab..1a4201e6 100644 --- a/tornado/test/tcpclient_test.py +++ b/tornado/test/tcpclient_test.py @@ -24,8 +24,8 @@ from tornado.concurrent import Future from tornado.netutil import bind_sockets, Resolver from tornado.tcpclient import TCPClient, _Connector from tornado.tcpserver import TCPServer -from tornado.testing import AsyncTestCase, bind_unused_port, gen_test -from tornado.test.util import skipIfNoIPv6, unittest +from tornado.testing import AsyncTestCase, gen_test +from tornado.test.util import skipIfNoIPv6, unittest, refusing_port # Fake address families for testing. Used in place of AF_INET # and AF_INET6 because some installations do not have AF_INET6. @@ -120,8 +120,8 @@ class TCPClientTest(AsyncTestCase): @gen_test def test_refused_ipv4(self): - sock, port = bind_unused_port() - sock.close() + cleanup_func, port = refusing_port() + self.addCleanup(cleanup_func) with self.assertRaises(IOError): yield self.client.connect('127.0.0.1', port) diff --git a/tornado/test/tcpserver_test.py b/tornado/test/tcpserver_test.py new file mode 100644 index 00000000..84c95076 --- /dev/null +++ b/tornado/test/tcpserver_test.py @@ -0,0 +1,38 @@ +import socket + +from tornado import gen +from tornado.iostream import IOStream +from tornado.log import app_log +from tornado.stack_context import NullContext +from tornado.tcpserver import TCPServer +from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test + + +class TCPServerTest(AsyncTestCase): + @gen_test + def test_handle_stream_coroutine_logging(self): + # handle_stream may be a coroutine and any exception in its + # Future will be logged. 
+ class TestServer(TCPServer): + @gen.coroutine + def handle_stream(self, stream, address): + yield gen.moment + stream.close() + 1/0 + + server = client = None + try: + sock, port = bind_unused_port() + with NullContext(): + server = TestServer() + server.add_socket(sock) + client = IOStream(socket.socket()) + with ExpectLog(app_log, "Exception in callback"): + yield client.connect(('localhost', port)) + yield client.read_until_close() + yield gen.moment + finally: + if server is not None: + server.stop() + if client is not None: + client.close() diff --git a/tornado/test/twisted_test.py b/tornado/test/twisted_test.py index 2922a61e..8ace993d 100644 --- a/tornado/test/twisted_test.py +++ b/tornado/test/twisted_test.py @@ -19,15 +19,18 @@ Unittest for the twisted-style reactor. from __future__ import absolute_import, division, print_function, with_statement +import logging import os import shutil import signal +import sys import tempfile import threading +import warnings try: import fcntl - from twisted.internet.defer import Deferred + from twisted.internet.defer import Deferred, inlineCallbacks, returnValue from twisted.internet.interfaces import IReadDescriptor, IWriteDescriptor from twisted.internet.protocol import Protocol from twisted.python import log @@ -40,10 +43,12 @@ except ImportError: # The core of Twisted 12.3.0 is available on python 3, but twisted.web is not # so test for it separately. try: - from twisted.web.client import Agent + from twisted.web.client import Agent, readBody from twisted.web.resource import Resource from twisted.web.server import Site - have_twisted_web = True + # As of Twisted 15.0.0, twisted.web is present but fails our + # tests due to internal str/bytes errors. 
+ have_twisted_web = sys.version_info < (3,) except ImportError: have_twisted_web = False @@ -52,6 +57,8 @@ try: except ImportError: import _thread as thread # py3 +from tornado.escape import utf8 +from tornado import gen from tornado.httpclient import AsyncHTTPClient from tornado.httpserver import HTTPServer from tornado.ioloop import IOLoop @@ -65,6 +72,9 @@ from tornado.web import RequestHandler, Application skipIfNoTwisted = unittest.skipUnless(have_twisted, "twisted module not present") +skipIfNoSingleDispatch = unittest.skipIf( + gen.singledispatch is None, "singledispatch module not present") + def save_signal_handlers(): saved = {} @@ -407,7 +417,7 @@ class CompatibilityTests(unittest.TestCase): # http://twistedmatrix.com/documents/current/web/howto/client.html chunks = [] client = Agent(self.reactor) - d = client.request('GET', url) + d = client.request(b'GET', utf8(url)) class Accumulator(Protocol): def __init__(self, finished): @@ -425,37 +435,98 @@ class CompatibilityTests(unittest.TestCase): return finished d.addCallback(callback) - def shutdown(ignored): - self.stop_loop() + def shutdown(failure): + if hasattr(self, 'stop_loop'): + self.stop_loop() + elif failure is not None: + # loop hasn't been initialized yet; try our best to + # get an error message out. (the runner() interaction + # should probably be refactored). + try: + failure.raiseException() + except: + logging.error('exception before starting loop', exc_info=True) d.addBoth(shutdown) runner() self.assertTrue(chunks) return ''.join(chunks) + def twisted_coroutine_fetch(self, url, runner): + body = [None] + + @gen.coroutine + def f(): + # This is simpler than the non-coroutine version, but it cheats + # by reading the body in one blob instead of streaming it with + # a Protocol. 
+ client = Agent(self.reactor) + response = yield client.request(b'GET', utf8(url)) + with warnings.catch_warnings(): + # readBody has a buggy DeprecationWarning in Twisted 15.0: + # https://twistedmatrix.com/trac/changeset/43379 + warnings.simplefilter('ignore', category=DeprecationWarning) + body[0] = yield readBody(response) + self.stop_loop() + self.io_loop.add_callback(f) + runner() + return body[0] + def testTwistedServerTornadoClientIOLoop(self): self.start_twisted_server() response = self.tornado_fetch( - 'http://localhost:%d' % self.twisted_port, self.run_ioloop) + 'http://127.0.0.1:%d' % self.twisted_port, self.run_ioloop) self.assertEqual(response.body, 'Hello from twisted!') def testTwistedServerTornadoClientReactor(self): self.start_twisted_server() response = self.tornado_fetch( - 'http://localhost:%d' % self.twisted_port, self.run_reactor) + 'http://127.0.0.1:%d' % self.twisted_port, self.run_reactor) self.assertEqual(response.body, 'Hello from twisted!') def testTornadoServerTwistedClientIOLoop(self): self.start_tornado_server() response = self.twisted_fetch( - 'http://localhost:%d' % self.tornado_port, self.run_ioloop) + 'http://127.0.0.1:%d' % self.tornado_port, self.run_ioloop) self.assertEqual(response, 'Hello from tornado!') def testTornadoServerTwistedClientReactor(self): self.start_tornado_server() response = self.twisted_fetch( - 'http://localhost:%d' % self.tornado_port, self.run_reactor) + 'http://127.0.0.1:%d' % self.tornado_port, self.run_reactor) self.assertEqual(response, 'Hello from tornado!') + @skipIfNoSingleDispatch + def testTornadoServerTwistedCoroutineClientIOLoop(self): + self.start_tornado_server() + response = self.twisted_coroutine_fetch( + 'http://127.0.0.1:%d' % self.tornado_port, self.run_ioloop) + self.assertEqual(response, 'Hello from tornado!') + + +@skipIfNoTwisted +@skipIfNoSingleDispatch +class ConvertDeferredTest(unittest.TestCase): + def test_success(self): + @inlineCallbacks + def fn(): + if False: + # 
inlineCallbacks doesn't work with regular functions; + # must have a yield even if it's unreachable. + yield + returnValue(42) + f = gen.convert_yielded(fn()) + self.assertEqual(f.result(), 42) + + def test_failure(self): + @inlineCallbacks + def fn(): + if False: + yield + 1 / 0 + f = gen.convert_yielded(fn()) + with self.assertRaises(ZeroDivisionError): + f.result() + if have_twisted: # Import and run as much of twisted's test suite as possible. @@ -483,7 +554,7 @@ if have_twisted: 'test_changeUID', ], # Process tests appear to work on OSX 10.7, but not 10.6 - #'twisted.internet.test.test_process.PTYProcessTestsBuilder': [ + # 'twisted.internet.test.test_process.PTYProcessTestsBuilder': [ # 'test_systemCallUninterruptedByChildExit', # ], 'twisted.internet.test.test_tcp.TCPClientTestsBuilder': [ @@ -502,7 +573,7 @@ if have_twisted: 'twisted.internet.test.test_threads.ThreadTestsBuilder': [], 'twisted.internet.test.test_time.TimeTestsBuilder': [], # Extra third-party dependencies (pyOpenSSL) - #'twisted.internet.test.test_tls.SSLClientTestsMixin': [], + # 'twisted.internet.test.test_tls.SSLClientTestsMixin': [], 'twisted.internet.test.test_udp.UDPServerTestsBuilder': [], 'twisted.internet.test.test_unix.UNIXTestsBuilder': [ # Platform-specific. These tests would be skipped automatically @@ -588,13 +659,13 @@ if have_twisted: correctly. In some tests another TornadoReactor is layered on top of the whole stack. """ - def initialize(self): + def initialize(self, **kwargs): # When configured to use LayeredTwistedIOLoop we can't easily # get the next-best IOLoop implementation, so use the lowest common # denominator. 
self.real_io_loop = SelectIOLoop() reactor = TornadoReactor(io_loop=self.real_io_loop) - super(LayeredTwistedIOLoop, self).initialize(reactor=reactor) + super(LayeredTwistedIOLoop, self).initialize(reactor=reactor, **kwargs) self.add_callback(self.make_current) def close(self, all_fds=False): diff --git a/tornado/test/util.py b/tornado/test/util.py index d31bbba3..9dd9c0ce 100644 --- a/tornado/test/util.py +++ b/tornado/test/util.py @@ -4,6 +4,8 @@ import os import socket import sys +from tornado.testing import bind_unused_port + # Encapsulate the choice of unittest or unittest2 here. # To be used as 'from tornado.test.util import unittest'. if sys.version_info < (2, 7): @@ -28,3 +30,23 @@ skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ, 'network access disabled') skipIfNoIPv6 = unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present') + + +def refusing_port(): + """Returns a local port number that will refuse all connections. + + Return value is (cleanup_func, port); the cleanup function + must be called to free the port to be reused. + """ + # On travis-ci, port numbers are reassigned frequently. To avoid + # collisions with other tests, we use an open client-side socket's + # ephemeral port number to ensure that nothing can listen on that + # port. 
+ server_socket, port = bind_unused_port() + server_socket.setblocking(1) + client_socket = socket.socket() + client_socket.connect(("127.0.0.1", port)) + conn, client_addr = server_socket.accept() + conn.close() + server_socket.close() + return (client_socket.close, client_addr[1]) diff --git a/tornado/test/util_test.py b/tornado/test/util_test.py index 1cd78fe4..0936c89a 100644 --- a/tornado/test/util_test.py +++ b/tornado/test/util_test.py @@ -3,8 +3,9 @@ from __future__ import absolute_import, division, print_function, with_statement import sys import datetime +import tornado.escape from tornado.escape import utf8 -from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer, timedelta_to_seconds +from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer, timedelta_to_seconds, import_object from tornado.test.util import unittest try: @@ -45,13 +46,15 @@ class TestConfigurable(Configurable): class TestConfig1(TestConfigurable): - def initialize(self, a=None): + def initialize(self, pos_arg=None, a=None): self.a = a + self.pos_arg = pos_arg class TestConfig2(TestConfigurable): - def initialize(self, b=None): + def initialize(self, pos_arg=None, b=None): self.b = b + self.pos_arg = pos_arg class ConfigurableTest(unittest.TestCase): @@ -101,9 +104,10 @@ class ConfigurableTest(unittest.TestCase): self.assertIsInstance(obj, TestConfig1) self.assertEqual(obj.a, 3) - obj = TestConfigurable(a=4) + obj = TestConfigurable(42, a=4) self.assertIsInstance(obj, TestConfig1) self.assertEqual(obj.a, 4) + self.assertEqual(obj.pos_arg, 42) self.checkSubclasses() # args bound in configure don't apply when using the subclass directly @@ -116,9 +120,10 @@ class ConfigurableTest(unittest.TestCase): self.assertIsInstance(obj, TestConfig2) self.assertEqual(obj.b, 5) - obj = TestConfigurable(b=6) + obj = TestConfigurable(42, b=6) self.assertIsInstance(obj, TestConfig2) self.assertEqual(obj.b, 6) + self.assertEqual(obj.pos_arg, 42) 
self.checkSubclasses() # args bound in configure don't apply when using the subclass directly @@ -177,3 +182,20 @@ class TimedeltaToSecondsTest(unittest.TestCase): def test_timedelta_to_seconds(self): time_delta = datetime.timedelta(hours=1) self.assertEqual(timedelta_to_seconds(time_delta), 3600.0) + + +class ImportObjectTest(unittest.TestCase): + def test_import_member(self): + self.assertIs(import_object('tornado.escape.utf8'), utf8) + + def test_import_member_unicode(self): + self.assertIs(import_object(u('tornado.escape.utf8')), utf8) + + def test_import_module(self): + self.assertIs(import_object('tornado.escape'), tornado.escape) + + def test_import_module_unicode(self): + # The internal implementation of __import__ differs depending on + # whether the thing being imported is a module or not. + # This variant requires a byte string in python 2. + self.assertIs(import_object(u('tornado.escape')), tornado.escape) diff --git a/tornado/test/web_test.py b/tornado/test/web_test.py index 55c9c9e8..96edd6c2 100644 --- a/tornado/test/web_test.py +++ b/tornado/test/web_test.py @@ -11,7 +11,7 @@ from tornado.template import DictLoader from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test from tornado.test.util import unittest from tornado.util import u, ObjectDict, unicode_type, timedelta_to_seconds -from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler +from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler, get_signature_key_version import 
binascii import contextlib @@ -71,10 +71,14 @@ class HelloHandler(RequestHandler): class CookieTestRequestHandler(RequestHandler): # stub out enough methods to make the secure_cookie functions work - def __init__(self): + def __init__(self, cookie_secret='0123456789', key_version=None): # don't call super.__init__ self._cookies = {} - self.application = ObjectDict(settings=dict(cookie_secret='0123456789')) + if key_version is None: + self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret)) + else: + self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret, + key_version=key_version)) def get_cookie(self, name): return self._cookies.get(name) @@ -128,6 +132,51 @@ class SecureCookieV1Test(unittest.TestCase): self.assertEqual(handler.get_secure_cookie('foo', min_version=1), b'\xe9') +# See SignedValueTest below for more. +class SecureCookieV2Test(unittest.TestCase): + KEY_VERSIONS = { + 0: 'ajklasdf0ojaisdf', + 1: 'aslkjasaolwkjsdf' + } + + def test_round_trip(self): + handler = CookieTestRequestHandler() + handler.set_secure_cookie('foo', b'bar', version=2) + self.assertEqual(handler.get_secure_cookie('foo', min_version=2), b'bar') + + def test_key_version_roundtrip(self): + handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS, + key_version=0) + handler.set_secure_cookie('foo', b'bar') + self.assertEqual(handler.get_secure_cookie('foo'), b'bar') + + def test_key_version_roundtrip_differing_version(self): + handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS, + key_version=1) + handler.set_secure_cookie('foo', b'bar') + self.assertEqual(handler.get_secure_cookie('foo'), b'bar') + + def test_key_version_increment_version(self): + handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS, + key_version=0) + handler.set_secure_cookie('foo', b'bar') + new_handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS, + key_version=1) + new_handler._cookies = handler._cookies + 
self.assertEqual(new_handler.get_secure_cookie('foo'), b'bar') + + def test_key_version_invalidate_version(self): + handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS, + key_version=0) + handler.set_secure_cookie('foo', b'bar') + new_key_versions = self.KEY_VERSIONS.copy() + new_key_versions.pop(0) + new_handler = CookieTestRequestHandler(cookie_secret=new_key_versions, + key_version=1) + new_handler._cookies = handler._cookies + self.assertEqual(new_handler.get_secure_cookie('foo'), None) + + class CookieTest(WebTestCase): def get_handlers(self): class SetCookieHandler(RequestHandler): @@ -171,6 +220,13 @@ class CookieTest(WebTestCase): def get(self): self.set_cookie("foo", "bar", expires_days=10) + class SetCookieFalsyFlags(RequestHandler): + def get(self): + self.set_cookie("a", "1", secure=True) + self.set_cookie("b", "1", secure=False) + self.set_cookie("c", "1", httponly=True) + self.set_cookie("d", "1", httponly=False) + return [("/set", SetCookieHandler), ("/get", GetCookieHandler), ("/set_domain", SetCookieDomainHandler), @@ -178,6 +234,7 @@ class CookieTest(WebTestCase): ("/set_overwrite", SetCookieOverwriteHandler), ("/set_max_age", SetCookieMaxAgeHandler), ("/set_expires_days", SetCookieExpiresDaysHandler), + ("/set_falsy_flags", SetCookieFalsyFlags) ] def test_set_cookie(self): @@ -237,7 +294,7 @@ class CookieTest(WebTestCase): headers = response.headers.get_list("Set-Cookie") self.assertEqual(sorted(headers), ["foo=bar; Max-Age=10; Path=/"]) - + def test_set_cookie_expires_days(self): response = self.fetch("/set_expires_days") header = response.headers.get("Set-Cookie") @@ -248,7 +305,17 @@ class CookieTest(WebTestCase): header_expires = datetime.datetime( *email.utils.parsedate(match.groupdict()["expires"])[:6]) self.assertTrue(abs(timedelta_to_seconds(expires - header_expires)) < 10) - + + def test_set_cookie_false_flags(self): + response = self.fetch("/set_falsy_flags") + headers = sorted(response.headers.get_list("Set-Cookie")) + # 
The secure and httponly headers are capitalized in py35 and + # lowercase in older versions. + self.assertEqual(headers[0].lower(), 'a=1; path=/; secure') + self.assertEqual(headers[1].lower(), 'b=1; path=/') + self.assertEqual(headers[2].lower(), 'c=1; httponly; path=/') + self.assertEqual(headers[3].lower(), 'd=1; path=/') + class AuthRedirectRequestHandler(RequestHandler): def initialize(self, login_url): @@ -305,7 +372,7 @@ class ConnectionCloseTest(WebTestCase): def test_connection_close(self): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) - s.connect(("localhost", self.get_http_port())) + s.connect(("127.0.0.1", self.get_http_port())) self.stream = IOStream(s, io_loop=self.io_loop) self.stream.write(b"GET / HTTP/1.0\r\n\r\n") self.wait() @@ -379,6 +446,12 @@ class RequestEncodingTest(WebTestCase): path_args=["a/b", "c/d"], args={})) + def test_error(self): + # Percent signs (encoded as %25) should not mess up printf-style + # messages in logs + with ExpectLog(gen_log, ".*Invalid unicode"): + self.fetch("/group/?arg=%25%e9") + class TypeCheckHandler(RequestHandler): def prepare(self): @@ -579,6 +652,7 @@ class WSGISafeWebTest(WebTestCase): url("/redirect", RedirectHandler), url("/web_redirect_permanent", WebRedirectHandler, {"url": "/web_redirect_newpath"}), url("/web_redirect", WebRedirectHandler, {"url": "/web_redirect_newpath", "permanent": False}), + url("//web_redirect_double_slash", WebRedirectHandler, {"url": '/web_redirect_newpath'}), url("/header_injection", HeaderInjectionHandler), url("/get_argument", GetArgumentHandler), url("/get_arguments", GetArgumentsHandler), @@ -712,6 +786,11 @@ js_embed() self.assertEqual(response.code, 302) self.assertEqual(response.headers['Location'], '/web_redirect_newpath') + def test_web_redirect_double_slash(self): + response = self.fetch("//web_redirect_double_slash", follow_redirects=False) + self.assertEqual(response.code, 301) + self.assertEqual(response.headers['Location'], '/web_redirect_newpath') + 
def test_header_injection(self): response = self.fetch("/header_injection") self.assertEqual(response.body, b"ok") @@ -1517,6 +1596,22 @@ class ExceptionHandlerTest(SimpleHandlerTestCase): self.assertEqual(response.code, 403) +@wsgi_safe +class BuggyLoggingTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + 1/0 + + def log_exception(self, typ, value, tb): + 1/0 + + def test_buggy_log_exception(self): + # Something gets logged even though the application's + # logger is broken. + with ExpectLog(app_log, '.*'): + self.fetch('/') + + @wsgi_safe class UIMethodUIModuleTest(SimpleHandlerTestCase): """Test that UI methods and modules are created correctly and @@ -1533,6 +1628,7 @@ class UIMethodUIModuleTest(SimpleHandlerTestCase): def my_ui_method(handler, x): return "In my_ui_method(%s) with handler value %s." % ( x, handler.value()) + class MyModule(UIModule): def render(self, x): return "In MyModule(%s) with handler value %s." % ( @@ -1907,7 +2003,7 @@ class StreamingRequestBodyTest(WebTestCase): def connect(self, url, connection_close): # Use a raw connection so we can control the sending of data. s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) - s.connect(("localhost", self.get_http_port())) + s.connect(("127.0.0.1", self.get_http_port())) stream = IOStream(s, io_loop=self.io_loop) stream.write(b"GET " + url + b" HTTP/1.1\r\n") if connection_close: @@ -1988,8 +2084,10 @@ class StreamingRequestFlowControlTest(WebTestCase): @gen.coroutine def prepare(self): - with self.in_method('prepare'): - yield gen.Task(IOLoop.current().add_callback) + # Note that asynchronous prepare() does not block data_received, + # so we don't use in_method here. 
+ self.methods.append('prepare') + yield gen.Task(IOLoop.current().add_callback) @gen.coroutine def data_received(self, data): @@ -2051,9 +2149,10 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase): # When the content-length is too high, the connection is simply # closed without completing the response. An error is logged on # the server. - with ExpectLog(app_log, "Uncaught exception"): + with ExpectLog(app_log, "(Uncaught exception|Exception in callback)"): with ExpectLog(gen_log, - "Cannot send error response after headers written"): + "(Cannot send error response after headers written" + "|Failed to flush partial response)"): response = self.fetch("/high") self.assertEqual(response.code, 599) self.assertEqual(str(self.server_error), @@ -2063,9 +2162,10 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase): # When the content-length is too low, the connection is closed # without writing the last chunk, so the client never sees the request # complete (which would be a framing error). - with ExpectLog(app_log, "Uncaught exception"): + with ExpectLog(app_log, "(Uncaught exception|Exception in callback)"): with ExpectLog(gen_log, - "Cannot send error response after headers written"): + "(Cannot send error response after headers written" + "|Failed to flush partial response)"): response = self.fetch("/low") self.assertEqual(response.code, 599) self.assertEqual(str(self.server_error), @@ -2075,21 +2175,28 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase): class ClientCloseTest(SimpleHandlerTestCase): class Handler(RequestHandler): def get(self): - # Simulate a connection closed by the client during - # request processing. 
The client will see an error, but the - # server should respond gracefully (without logging errors - # because we were unable to write out as many bytes as - # Content-Length said we would) - self.request.connection.stream.close() - self.write('hello') + if self.request.version.startswith('HTTP/1'): + # Simulate a connection closed by the client during + # request processing. The client will see an error, but the + # server should respond gracefully (without logging errors + # because we were unable to write out as many bytes as + # Content-Length said we would) + self.request.connection.stream.close() + self.write('hello') + else: + # TODO: add a HTTP2-compatible version of this test. + self.write('requires HTTP/1.x') def test_client_close(self): response = self.fetch('/') + if response.body == b'requires HTTP/1.x': + self.skipTest('requires HTTP/1.x') self.assertEqual(response.code, 599) class SignedValueTest(unittest.TestCase): SECRET = "It's a secret to everybody" + SECRET_DICT = {0: "asdfbasdf", 1: "12312312", 2: "2342342"} def past(self): return self.present() - 86400 * 32 @@ -2151,6 +2258,7 @@ class SignedValueTest(unittest.TestCase): def test_payload_tampering(self): # These cookies are variants of the one in test_known_values. 
sig = "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152" + def validate(prefix): return (b'value' == decode_signed_value(SignedValueTest.SECRET, "key", @@ -2165,6 +2273,7 @@ class SignedValueTest(unittest.TestCase): def test_signature_tampering(self): prefix = "2|1:0|10:1300000000|3:key|8:dmFsdWU=|" + def validate(sig): return (b'value' == decode_signed_value(SignedValueTest.SECRET, "key", @@ -2194,6 +2303,43 @@ class SignedValueTest(unittest.TestCase): clock=self.present) self.assertEqual(value, decoded) + def test_key_versioning_read_write_default_key(self): + value = b"\xe9" + signed = create_signed_value(SignedValueTest.SECRET_DICT, + "key", value, clock=self.present, + key_version=0) + decoded = decode_signed_value(SignedValueTest.SECRET_DICT, + "key", signed, clock=self.present) + self.assertEqual(value, decoded) + + def test_key_versioning_read_write_non_default_key(self): + value = b"\xe9" + signed = create_signed_value(SignedValueTest.SECRET_DICT, + "key", value, clock=self.present, + key_version=1) + decoded = decode_signed_value(SignedValueTest.SECRET_DICT, + "key", signed, clock=self.present) + self.assertEqual(value, decoded) + + def test_key_versioning_invalid_key(self): + value = b"\xe9" + signed = create_signed_value(SignedValueTest.SECRET_DICT, + "key", value, clock=self.present, + key_version=0) + newkeys = SignedValueTest.SECRET_DICT.copy() + newkeys.pop(0) + decoded = decode_signed_value(newkeys, + "key", signed, clock=self.present) + self.assertEqual(None, decoded) + + def test_key_version_retrieval(self): + value = b"\xe9" + signed = create_signed_value(SignedValueTest.SECRET_DICT, + "key", value, clock=self.present, + key_version=1) + key_version = get_signature_key_version(signed) + self.assertEqual(1, key_version) + @wsgi_safe class XSRFTest(SimpleHandlerTestCase): @@ -2296,7 +2442,7 @@ class XSRFTest(SimpleHandlerTestCase): token2 = self.get_token() # Each token can be used to authenticate its own request. 
for token in (self.xsrf_token, token2): - response = self.fetch( + response = self.fetch( "/", method="POST", body=urllib_parse.urlencode(dict(_xsrf=token)), headers=self.cookie_headers(token)) @@ -2372,6 +2518,7 @@ class FinishExceptionTest(SimpleHandlerTestCase): self.assertEqual(b'authentication required', response.body) +@wsgi_safe class DecoratorTest(WebTestCase): def get_handlers(self): class RemoveSlashHandler(RequestHandler): @@ -2405,3 +2552,85 @@ class DecoratorTest(WebTestCase): response = self.fetch("/addslash?foo=bar", follow_redirects=False) self.assertEqual(response.code, 301) self.assertEqual(response.headers['Location'], "/addslash/?foo=bar") + + +@wsgi_safe +class CacheTest(WebTestCase): + def get_handlers(self): + class EtagHandler(RequestHandler): + def get(self, computed_etag): + self.write(computed_etag) + + def compute_etag(self): + return self._write_buffer[0] + + return [ + ('/etag/(.*)', EtagHandler) + ] + + def test_wildcard_etag(self): + computed_etag = '"xyzzy"' + etags = '*' + self._test_etag(computed_etag, etags, 304) + + def test_strong_etag_match(self): + computed_etag = '"xyzzy"' + etags = '"xyzzy"' + self._test_etag(computed_etag, etags, 304) + + def test_multiple_strong_etag_match(self): + computed_etag = '"xyzzy1"' + etags = '"xyzzy1", "xyzzy2"' + self._test_etag(computed_etag, etags, 304) + + def test_strong_etag_not_match(self): + computed_etag = '"xyzzy"' + etags = '"xyzzy1"' + self._test_etag(computed_etag, etags, 200) + + def test_multiple_strong_etag_not_match(self): + computed_etag = '"xyzzy"' + etags = '"xyzzy1", "xyzzy2"' + self._test_etag(computed_etag, etags, 200) + + def test_weak_etag_match(self): + computed_etag = '"xyzzy1"' + etags = 'W/"xyzzy1"' + self._test_etag(computed_etag, etags, 304) + + def test_multiple_weak_etag_match(self): + computed_etag = '"xyzzy2"' + etags = 'W/"xyzzy1", W/"xyzzy2"' + self._test_etag(computed_etag, etags, 304) + + def test_weak_etag_not_match(self): + computed_etag = '"xyzzy2"' + 
etags = 'W/"xyzzy1"' + self._test_etag(computed_etag, etags, 200) + + def test_multiple_weak_etag_not_match(self): + computed_etag = '"xyzzy3"' + etags = 'W/"xyzzy1", W/"xyzzy2"' + self._test_etag(computed_etag, etags, 200) + + def _test_etag(self, computed_etag, etags, status_code): + response = self.fetch( + '/etag/' + computed_etag, + headers={'If-None-Match': etags} + ) + self.assertEqual(response.code, status_code) + + +@wsgi_safe +class RequestSummaryTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + # remote_ip is optional, although it's set by + # both HTTPServer and WSGIAdapter. + # Clobber it to make sure it doesn't break logging. + self.request.remote_ip = None + self.finish(self._request_summary()) + + def test_missing_remote_ip(self): + resp = self.fetch("/") + self.assertEqual(resp.body, b"GET / (None)") diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index e1e3ea70..6b182d07 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -12,7 +12,7 @@ from tornado.web import Application, RequestHandler from tornado.util import u try: - import tornado.websocket + import tornado.websocket # noqa from tornado.util import _websocket_mask_python except ImportError: # The unittest module presents misleading errors on ImportError @@ -53,7 +53,7 @@ class EchoHandler(TestWebSocketHandler): class ErrorInOnMessageHandler(TestWebSocketHandler): def on_message(self, message): - 1/0 + 1 / 0 class HeaderHandler(TestWebSocketHandler): @@ -75,6 +75,7 @@ class NonWebSocketHandler(RequestHandler): class CloseReasonHandler(TestWebSocketHandler): def open(self): + self.on_close_called = False self.close(1001, "goodbye") @@ -91,7 +92,7 @@ class WebSocketBaseTestCase(AsyncHTTPTestCase): @gen.coroutine def ws_connect(self, path, compression_options=None): ws = yield websocket_connect( - 'ws://localhost:%d%s' % (self.get_http_port(), path), + 'ws://127.0.0.1:%d%s' % (self.get_http_port(), path), 
compression_options=compression_options) raise gen.Return(ws) @@ -105,6 +106,7 @@ class WebSocketBaseTestCase(AsyncHTTPTestCase): ws.close() yield self.close_future + class WebSocketTest(WebSocketBaseTestCase): def get_app(self): self.close_future = Future() @@ -135,7 +137,7 @@ class WebSocketTest(WebSocketBaseTestCase): def test_websocket_callbacks(self): websocket_connect( - 'ws://localhost:%d/echo' % self.get_http_port(), + 'ws://127.0.0.1:%d/echo' % self.get_http_port(), io_loop=self.io_loop, callback=self.stop) ws = self.wait().result() ws.write_message('hello') @@ -189,14 +191,14 @@ class WebSocketTest(WebSocketBaseTestCase): with self.assertRaises(IOError): with ExpectLog(gen_log, ".*"): yield websocket_connect( - 'ws://localhost:%d/' % port, + 'ws://127.0.0.1:%d/' % port, io_loop=self.io_loop, connect_timeout=3600) @gen_test def test_websocket_close_buffered_data(self): ws = yield websocket_connect( - 'ws://localhost:%d/echo' % self.get_http_port()) + 'ws://127.0.0.1:%d/echo' % self.get_http_port()) ws.write_message('hello') ws.write_message('world') # Close the underlying stream. @@ -207,7 +209,7 @@ class WebSocketTest(WebSocketBaseTestCase): def test_websocket_headers(self): # Ensure that arbitrary headers can be passed through websocket_connect. ws = yield websocket_connect( - HTTPRequest('ws://localhost:%d/header' % self.get_http_port(), + HTTPRequest('ws://127.0.0.1:%d/header' % self.get_http_port(), headers={'X-Test': 'hello'})) response = yield ws.read_message() self.assertEqual(response, 'hello') @@ -221,6 +223,8 @@ class WebSocketTest(WebSocketBaseTestCase): self.assertIs(msg, None) self.assertEqual(ws.close_code, 1001) self.assertEqual(ws.close_reason, "goodbye") + # The on_close callback is called no matter which side closed. 
+ yield self.close_future @gen_test def test_client_close_reason(self): @@ -243,11 +247,11 @@ class WebSocketTest(WebSocketBaseTestCase): def test_check_origin_valid_no_path(self): port = self.get_http_port() - url = 'ws://localhost:%d/echo' % port - headers = {'Origin': 'http://localhost:%d' % port} + url = 'ws://127.0.0.1:%d/echo' % port + headers = {'Origin': 'http://127.0.0.1:%d' % port} ws = yield websocket_connect(HTTPRequest(url, headers=headers), - io_loop=self.io_loop) + io_loop=self.io_loop) ws.write_message('hello') response = yield ws.read_message() self.assertEqual(response, 'hello') @@ -257,11 +261,11 @@ class WebSocketTest(WebSocketBaseTestCase): def test_check_origin_valid_with_path(self): port = self.get_http_port() - url = 'ws://localhost:%d/echo' % port - headers = {'Origin': 'http://localhost:%d/something' % port} + url = 'ws://127.0.0.1:%d/echo' % port + headers = {'Origin': 'http://127.0.0.1:%d/something' % port} ws = yield websocket_connect(HTTPRequest(url, headers=headers), - io_loop=self.io_loop) + io_loop=self.io_loop) ws.write_message('hello') response = yield ws.read_message() self.assertEqual(response, 'hello') @@ -271,8 +275,8 @@ class WebSocketTest(WebSocketBaseTestCase): def test_check_origin_invalid_partial_url(self): port = self.get_http_port() - url = 'ws://localhost:%d/echo' % port - headers = {'Origin': 'localhost:%d' % port} + url = 'ws://127.0.0.1:%d/echo' % port + headers = {'Origin': '127.0.0.1:%d' % port} with self.assertRaises(HTTPError) as cm: yield websocket_connect(HTTPRequest(url, headers=headers), @@ -283,8 +287,8 @@ class WebSocketTest(WebSocketBaseTestCase): def test_check_origin_invalid(self): port = self.get_http_port() - url = 'ws://localhost:%d/echo' % port - # Host is localhost, which should not be accessible from some other + url = 'ws://127.0.0.1:%d/echo' % port + # Host is 127.0.0.1, which should not be accessible from some other # domain headers = {'Origin': 'http://somewhereelse.com'} diff --git 
a/tornado/testing.py b/tornado/testing.py index 4d85abe9..93f0dbe1 100644 --- a/tornado/testing.py +++ b/tornado/testing.py @@ -19,6 +19,7 @@ try: from tornado.simple_httpclient import SimpleAsyncHTTPClient from tornado.ioloop import IOLoop, TimeoutError from tornado import netutil + from tornado.process import Subprocess except ImportError: # These modules are not importable on app engine. Parts of this module # won't work, but e.g. LogTrapTestCase and main() will. @@ -28,6 +29,7 @@ except ImportError: IOLoop = None netutil = None SimpleAsyncHTTPClient = None + Subprocess = None from tornado.log import gen_log, app_log from tornado.stack_context import ExceptionStackContext from tornado.util import raise_exc_info, basestring_type @@ -214,6 +216,8 @@ class AsyncTestCase(unittest.TestCase): self.io_loop.make_current() def tearDown(self): + # Clean up Subprocess, so it can be used again with a new ioloop. + Subprocess.uninitialize() self.io_loop.clear_current() if (not IOLoop.initialized() or self.io_loop is not IOLoop.instance()): @@ -413,10 +417,8 @@ class AsyncHTTPSTestCase(AsyncHTTPTestCase): Interface is generally the same as `AsyncHTTPTestCase`. """ def get_http_client(self): - # Some versions of libcurl have deadlock bugs with ssl, - # so always run these tests with SimpleAsyncHTTPClient. - return SimpleAsyncHTTPClient(io_loop=self.io_loop, force_instance=True, - defaults=dict(validate_cert=False)) + return AsyncHTTPClient(io_loop=self.io_loop, force_instance=True, + defaults=dict(validate_cert=False)) def get_httpserver_options(self): return dict(ssl_options=self.get_ssl_options()) @@ -539,6 +541,9 @@ class LogTrapTestCase(unittest.TestCase): `logging.basicConfig` and the "pretty logging" configured by `tornado.options`. It is not compatible with other log buffering mechanisms, such as those provided by some test runners. + + .. deprecated:: 4.1 + Use the unittest module's ``--buffer`` option instead, or `.ExpectLog`. 
""" def run(self, result=None): logger = logging.getLogger() diff --git a/tornado/util.py b/tornado/util.py index 34c4b072..606ced19 100644 --- a/tornado/util.py +++ b/tornado/util.py @@ -78,6 +78,25 @@ class GzipDecompressor(object): return self.decompressobj.flush() +# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for +# literal strings, and alternative solutions like "from __future__ import +# unicode_literals" have other problems (see PEP 414). u() can be applied +# to ascii strings that include \u escapes (but they must not contain +# literal non-ascii characters). +if not isinstance(b'', type('')): + def u(s): + return s + unicode_type = str + basestring_type = str +else: + def u(s): + return s.decode('unicode_escape') + # These names don't exist in py3, so use noqa comments to disable + # warnings in flake8. + unicode_type = unicode # noqa + basestring_type = basestring # noqa + + def import_object(name): """Imports an object by name. @@ -96,6 +115,9 @@ def import_object(name): ... ImportError: No module named missing_module """ + if isinstance(name, unicode_type) and str is not unicode_type: + # On python 2 a byte string is required. + name = name.encode('utf-8') if name.count('.') == 0: return __import__(name, None, None) @@ -107,22 +129,6 @@ def import_object(name): raise ImportError("No module named %s" % parts[-1]) -# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for -# literal strings, and alternative solutions like "from __future__ import -# unicode_literals" have other problems (see PEP 414). u() can be applied -# to ascii strings that include \u escapes (but they must not contain -# literal non-ascii characters). -if type('') is not type(b''): - def u(s): - return s - unicode_type = str - basestring_type = str -else: - def u(s): - return s.decode('unicode_escape') - unicode_type = unicode - basestring_type = basestring - # Deprecated alias that was used before we dropped py25 support. 
# Left here in case anyone outside Tornado is using it. bytes_type = bytes @@ -192,21 +198,21 @@ class Configurable(object): __impl_class = None __impl_kwargs = None - def __new__(cls, **kwargs): + def __new__(cls, *args, **kwargs): base = cls.configurable_base() - args = {} + init_kwargs = {} if cls is base: impl = cls.configured_class() if base.__impl_kwargs: - args.update(base.__impl_kwargs) + init_kwargs.update(base.__impl_kwargs) else: impl = cls - args.update(kwargs) + init_kwargs.update(kwargs) instance = super(Configurable, cls).__new__(impl) # initialize vs __init__ chosen for compatibility with AsyncHTTPClient # singleton magic. If we get rid of that we can switch to __init__ # here too. - instance.initialize(**args) + instance.initialize(*args, **init_kwargs) return instance @classmethod @@ -227,6 +233,9 @@ class Configurable(object): """Initialize a `Configurable` subclass instance. Configurable classes should use `initialize` instead of ``__init__``. + + .. versionchanged:: 4.2 + Now accepts positional arguments in addition to keyword arguments. """ @classmethod @@ -339,7 +348,7 @@ def _websocket_mask_python(mask, data): return unmasked.tostring() if (os.environ.get('TORNADO_NO_EXTENSION') or - os.environ.get('TORNADO_EXTENSION') == '0'): + os.environ.get('TORNADO_EXTENSION') == '0'): # These environment variables exist to make it easier to do performance # comparisons; they are not guaranteed to remain supported in the future. _websocket_mask = _websocket_mask_python diff --git a/tornado/web.py b/tornado/web.py index 9edd719a..9bc12933 100644 --- a/tornado/web.py +++ b/tornado/web.py @@ -19,7 +19,9 @@ features that allow it to scale to large numbers of open connections, making it ideal for `long polling `_. -Here is a simple "Hello, world" example app:: +Here is a simple "Hello, world" example app: + +.. 
testcode:: import tornado.ioloop import tornado.web @@ -33,7 +35,11 @@ Here is a simple "Hello, world" example app:: (r"/", MainHandler), ]) application.listen(8888) - tornado.ioloop.IOLoop.instance().start() + tornado.ioloop.IOLoop.current().start() + +.. testoutput:: + :hide: + See the :doc:`guide` for additional information. @@ -50,7 +56,8 @@ request. """ -from __future__ import absolute_import, division, print_function, with_statement +from __future__ import (absolute_import, division, + print_function, with_statement) import base64 @@ -84,7 +91,9 @@ from tornado.log import access_log, app_log, gen_log from tornado import stack_context from tornado import template from tornado.escape import utf8, _unicode -from tornado.util import import_object, ObjectDict, raise_exc_info, unicode_type, _websocket_mask +from tornado.util import (import_object, ObjectDict, raise_exc_info, + unicode_type, _websocket_mask) +from tornado.httputil import split_host_and_port try: @@ -130,12 +139,11 @@ May be overridden by passing a ``version`` keyword argument. DEFAULT_SIGNED_VALUE_MIN_VERSION = 1 """The oldest signed value accepted by `.RequestHandler.get_secure_cookie`. -May be overrided by passing a ``min_version`` keyword argument. +May be overridden by passing a ``min_version`` keyword argument. .. versionadded:: 3.2.1 """ - class RequestHandler(object): """Subclass this class and define `get()` or `post()` to make a handler. @@ -267,6 +275,7 @@ class RequestHandler(object): if _has_stream_request_body(self.__class__): if not self.request.body.done(): self.request.body.set_exception(iostream.StreamClosedError()) + self.request.body.exception() def clear(self): """Resets all headers and content for this response.""" @@ -382,6 +391,12 @@ class RequestHandler(object): The returned values are always unicode. """ + + # Make sure `get_arguments` isn't accidentally being called with a + # positional argument that's assumed to be a default (like in + # `get_argument`.) 
+ assert isinstance(strip, bool) + return self._get_arguments(name, self.request.arguments, strip) def get_body_argument(self, name, default=_ARG_DEFAULT, strip=True): @@ -398,7 +413,8 @@ class RequestHandler(object): .. versionadded:: 3.2 """ - return self._get_argument(name, default, self.request.body_arguments, strip) + return self._get_argument(name, default, self.request.body_arguments, + strip) def get_body_arguments(self, name, strip=True): """Returns a list of the body arguments with the given name. @@ -425,7 +441,8 @@ class RequestHandler(object): .. versionadded:: 3.2 """ - return self._get_argument(name, default, self.request.query_arguments, strip) + return self._get_argument(name, default, + self.request.query_arguments, strip) def get_query_arguments(self, name, strip=True): """Returns a list of the query arguments with the given name. @@ -480,7 +497,8 @@ class RequestHandler(object): @property def cookies(self): - """An alias for `self.request.cookies <.httputil.HTTPServerRequest.cookies>`.""" + """An alias for + `self.request.cookies <.httputil.HTTPServerRequest.cookies>`.""" return self.request.cookies def get_cookie(self, name, default=None): @@ -522,6 +540,12 @@ class RequestHandler(object): for k, v in kwargs.items(): if k == 'max_age': k = 'max-age' + + # skip falsy values for httponly and secure flags because + # SimpleCookie sets them regardless + if k in ['httponly', 'secure'] and not v: + continue + morsel[k] = v def clear_cookie(self, name, path="/", domain=None): @@ -588,8 +612,15 @@ class RequestHandler(object): and made it the default. 
""" self.require_setting("cookie_secret", "secure cookies") - return create_signed_value(self.application.settings["cookie_secret"], - name, value, version=version) + secret = self.application.settings["cookie_secret"] + key_version = None + if isinstance(secret, dict): + if self.application.settings.get("key_version") is None: + raise Exception("key_version setting must be used for secret_key dicts") + key_version = self.application.settings["key_version"] + + return create_signed_value(secret, name, value, version=version, + key_version=key_version) def get_secure_cookie(self, name, value=None, max_age_days=31, min_version=None): @@ -610,6 +641,17 @@ class RequestHandler(object): name, value, max_age_days=max_age_days, min_version=min_version) + def get_secure_cookie_key_version(self, name, value=None): + """Returns the signing key version of the secure cookie. + + The version is returned as int. + """ + self.require_setting("cookie_secret", "secure cookies") + if value is None: + value = self.get_cookie(name) + return get_signature_key_version(value) + + def redirect(self, url, permanent=False, status=None): """Sends a redirect to the given (optionally relative) URL. @@ -625,8 +667,7 @@ class RequestHandler(object): else: assert isinstance(status, int) and 300 <= status <= 399 self.set_status(status) - self.set_header("Location", urlparse.urljoin(utf8(self.request.uri), - utf8(url))) + self.set_header("Location", utf8(url)) self.finish() def write(self, chunk): @@ -646,16 +687,14 @@ class RequestHandler(object): https://github.com/facebook/tornado/issues/1009 """ if self._finished: - raise RuntimeError("Cannot write() after finish(). 
May be caused " - "by using async operations without the " - "@asynchronous decorator.") + raise RuntimeError("Cannot write() after finish()") if not isinstance(chunk, (bytes, unicode_type, dict)): - raise TypeError("write() only accepts bytes, unicode, and dict objects") + message = "write() only accepts bytes, unicode, and dict objects" + if isinstance(chunk, list): + message += ". Lists not accepted for security reasons; see http://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.write" + raise TypeError(message) if isinstance(chunk, dict): - if 'unwrap_json' in chunk: - chunk = chunk['unwrap_json'] - else: - chunk = escape.json_encode(chunk) + chunk = escape.json_encode(chunk) self.set_header("Content-Type", "application/json; charset=UTF-8") chunk = utf8(chunk) self._write_buffer.append(chunk) @@ -786,6 +825,7 @@ class RequestHandler(object): current_user=self.current_user, locale=self.locale, _=self.locale.translate, + pgettext=self.locale.pgettext, static_url=self.static_url, xsrf_form_html=self.xsrf_form_html, reverse_url=self.reverse_url @@ -830,7 +870,8 @@ class RequestHandler(object): for transform in self._transforms: self._status_code, self._headers, chunk = \ transform.transform_first_chunk( - self._status_code, self._headers, chunk, include_footers) + self._status_code, self._headers, + chunk, include_footers) # Ignore the chunk and only write the headers for HEAD requests if self.request.method == "HEAD": chunk = None @@ -842,7 +883,7 @@ class RequestHandler(object): for cookie in self._new_cookie.values(): self.add_header("Set-Cookie", cookie.OutputString(None)) - start_line = httputil.ResponseStartLine(self.request.version, + start_line = httputil.ResponseStartLine('', self._status_code, self._reason) return self.request.connection.write_headers( @@ -861,9 +902,7 @@ class RequestHandler(object): def finish(self, chunk=None): """Finishes this response, ending the HTTP request.""" if self._finished: - raise RuntimeError("finish() 
called twice. May be caused " - "by using async operations without the " - "@asynchronous decorator.") + raise RuntimeError("finish() called twice") if chunk is not None: self.write(chunk) @@ -915,7 +954,15 @@ class RequestHandler(object): if self._headers_written: gen_log.error("Cannot send error response after headers written") if not self._finished: - self.finish() + # If we get an error between writing headers and finishing, + # we are unlikely to be able to finish due to a + # Content-Length mismatch. Try anyway to release the + # socket. + try: + self.finish() + except Exception: + gen_log.error("Failed to flush partial response", + exc_info=True) return self.clear() @@ -1122,28 +1169,37 @@ class RequestHandler(object): """Convert a cookie string into a the tuple form returned by _get_raw_xsrf_token. """ - m = _signed_value_version_re.match(utf8(cookie)) - if m: - version = int(m.group(1)) - if version == 2: - _, mask, masked_token, timestamp = cookie.split("|") - mask = binascii.a2b_hex(utf8(mask)) - token = _websocket_mask( - mask, binascii.a2b_hex(utf8(masked_token))) - timestamp = int(timestamp) - return version, token, timestamp + + try: + m = _signed_value_version_re.match(utf8(cookie)) + + if m: + version = int(m.group(1)) + if version == 2: + _, mask, masked_token, timestamp = cookie.split("|") + + mask = binascii.a2b_hex(utf8(mask)) + token = _websocket_mask( + mask, binascii.a2b_hex(utf8(masked_token))) + timestamp = int(timestamp) + return version, token, timestamp + else: + # Treat unknown versions as not present instead of failing. + raise Exception("Unknown xsrf cookie version") else: - # Treat unknown versions as not present instead of failing. - return None, None, None - else: - version = 1 - try: - token = binascii.a2b_hex(utf8(cookie)) - except (binascii.Error, TypeError): - token = utf8(cookie) - # We don't have a usable timestamp in older versions. 
- timestamp = int(time.time()) - return (version, token, timestamp) + version = 1 + try: + token = binascii.a2b_hex(utf8(cookie)) + except (binascii.Error, TypeError): + token = utf8(cookie) + # We don't have a usable timestamp in older versions. + timestamp = int(time.time()) + return (version, token, timestamp) + except Exception: + # Catch exceptions and return nothing instead of failing. + gen_log.debug("Uncaught exception in _decode_xsrf_token", + exc_info=True) + return None, None, None def check_xsrf_cookie(self): """Verifies that the ``_xsrf`` cookie matches the ``_xsrf`` argument. @@ -1282,9 +1338,27 @@ class RequestHandler(object): before completing the request. The ``Etag`` header should be set (perhaps with `set_etag_header`) before calling this method. """ - etag = self._headers.get("Etag") - inm = utf8(self.request.headers.get("If-None-Match", "")) - return bool(etag and inm and inm.find(etag) >= 0) + computed_etag = utf8(self._headers.get("Etag", "")) + # Find all weak and strong etag values from If-None-Match header + # because RFC 7232 allows multiple etag values in a single header. + etags = re.findall( + br'\*|(?:W/)?"[^"]*"', + utf8(self.request.headers.get("If-None-Match", "")) + ) + if not computed_etag or not etags: + return False + + match = False + if etags[0] == b'*': + match = True + else: + # Use a weak comparison when comparing entity-tags. 
+ val = lambda x: x[2:] if x.startswith(b'W/') else x + for etag in etags: + if val(etag) == val(computed_etag): + match = True + break + return match def _stack_context_handle_exception(self, type, value, traceback): try: @@ -1344,7 +1418,10 @@ class RequestHandler(object): if self._auto_finish and not self._finished: self.finish() except Exception as e: - self._handle_request_exception(e) + try: + self._handle_request_exception(e) + except Exception: + app_log.error("Exception in exception handler", exc_info=True) if (self._prepared_future is not None and not self._prepared_future.done()): # In case we failed before setting _prepared_future, do it @@ -1369,8 +1446,8 @@ class RequestHandler(object): self.application.log_request(self) def _request_summary(self): - return self.request.method + " " + self.request.uri + \ - " (" + self.request.remote_ip + ")" + return "%s %s (%s)" % (self.request.method, self.request.uri, + self.request.remote_ip) def _handle_request_exception(self, e): if isinstance(e, Finish): @@ -1378,7 +1455,12 @@ class RequestHandler(object): if not self._finished: self.finish() return - self.log_exception(*sys.exc_info()) + try: + self.log_exception(*sys.exc_info()) + except Exception: + # An error here should still get a best-effort send_error() + # to avoid leaking the connection. + app_log.error("Error in exception logger", exc_info=True) if self._finished: # Extra errors after the request has been finished should # be logged, but there is no reason to continue to try and @@ -1441,10 +1523,11 @@ class RequestHandler(object): def asynchronous(method): """Wrap request handler methods with this if they are asynchronous. - This decorator is unnecessary if the method is also decorated with - ``@gen.coroutine`` (it is legal but unnecessary to use the two - decorators together, in which case ``@asynchronous`` must be - first). 
+ This decorator is for callback-style asynchronous methods; for + coroutines, use the ``@gen.coroutine`` decorator without + ``@asynchronous``. (It is legal for legacy reasons to use the two + decorators together provided ``@asynchronous`` is first, but + ``@asynchronous`` will be ignored in this case) This decorator should only be applied to the :ref:`HTTP verb methods `; its behavior is undefined for any other method. @@ -1457,10 +1540,12 @@ def asynchronous(method): method returns. It is up to the request handler to call `self.finish() ` to finish the HTTP request. Without this decorator, the request is automatically - finished when the ``get()`` or ``post()`` method returns. Example:: + finished when the ``get()`` or ``post()`` method returns. Example: - class MyRequestHandler(web.RequestHandler): - @web.asynchronous + .. testcode:: + + class MyRequestHandler(RequestHandler): + @asynchronous def get(self): http = httpclient.AsyncHTTPClient() http.fetch("http://friendfeed.com/", self._on_download) @@ -1469,18 +1554,23 @@ def asynchronous(method): self.write("Downloaded!") self.finish() + .. testoutput:: + :hide: + .. versionadded:: 3.1 The ability to use ``@gen.coroutine`` without ``@asynchronous``. + """ # Delay the IOLoop import because it's not available on app engine. from tornado.ioloop import IOLoop + @functools.wraps(method) def wrapper(self, *args, **kwargs): self._auto_finish = False with stack_context.ExceptionStackContext( self._stack_context_handle_exception): result = method(self, *args, **kwargs) - if isinstance(result, Future): + if is_future(result): # If @asynchronous is used with @gen.coroutine, (but # not @gen.engine), we can automatically finish the # request when the future resolves. Additionally, @@ -1521,7 +1611,7 @@ def stream_request_body(cls): the entire body has been read. 
There is a subtle interaction between ``data_received`` and asynchronous - ``prepare``: The first call to ``data_recieved`` may occur at any point + ``prepare``: The first call to ``data_received`` may occur at any point after the call to ``prepare`` has returned *or yielded*. """ if not issubclass(cls, RequestHandler): @@ -1591,7 +1681,7 @@ class Application(httputil.HTTPServerConnectionDelegate): ]) http_server = httpserver.HTTPServer(application) http_server.listen(8080) - ioloop.IOLoop.instance().start() + ioloop.IOLoop.current().start() The constructor for this class takes in a list of `URLSpec` objects or (regexp, request_class) tuples. When we receive requests, we @@ -1689,7 +1779,7 @@ class Application(httputil.HTTPServerConnectionDelegate): `.TCPServer.bind`/`.TCPServer.start` methods directly. Note that after calling this method you still need to call - ``IOLoop.instance().start()`` to start the server. + ``IOLoop.current().start()`` to start the server. """ # import is here rather than top level because HTTPServer # is not importable on appengine @@ -1732,7 +1822,7 @@ class Application(httputil.HTTPServerConnectionDelegate): self.transforms.append(transform_class) def _get_host_handlers(self, request): - host = request.host.lower().split(':')[0] + host = split_host_and_port(request.host.lower())[0] matches = [] for pattern, handlers in self.handlers: if pattern.match(host): @@ -1773,9 +1863,9 @@ class Application(httputil.HTTPServerConnectionDelegate): except TypeError: pass - def start_request(self, connection): + def start_request(self, server_conn, request_conn): # Modern HTTPServer interface - return _RequestDispatcher(self, connection) + return _RequestDispatcher(self, request_conn) def __call__(self, request): # Legacy HTTPServer interface @@ -1831,7 +1921,8 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate): def headers_received(self, start_line, headers): self.set_request(httputil.HTTPServerRequest( - connection=self.connection, 
start_line=start_line, headers=headers)) + connection=self.connection, start_line=start_line, + headers=headers)) if self.stream_request_body: self.request.body = Future() return self.execute() @@ -1848,7 +1939,9 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate): handlers = app._get_host_handlers(self.request) if not handlers: self.handler_class = RedirectHandler - self.handler_kwargs = dict(url="http://" + app.default_host + "/") + self.handler_kwargs = dict(url="%s://%s/" + % (self.request.protocol, + app.default_host)) return for spec in handlers: match = spec.regex.match(self.request.path) @@ -1914,11 +2007,14 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate): if self.stream_request_body: self.handler._prepared_future = Future() # Note that if an exception escapes handler._execute it will be - # trapped in the Future it returns (which we are ignoring here). + # trapped in the Future it returns (which we are ignoring here, + # leaving it to be logged when the Future is GC'd). # However, that shouldn't happen because _execute has a blanket # except handler, and we cannot easily access the IOLoop here to - # call add_future. - self.handler._execute(transforms, *self.path_args, **self.path_kwargs) + # call add_future (because of the requirement to remain compatible + # with WSGI) + f = self.handler._execute(transforms, *self.path_args, + **self.path_kwargs) # If we are streaming the request body, then execute() is finished # when the handler has prepared to receive the body. 
If not, # it doesn't matter when execute() finishes (so we return None) @@ -1952,6 +2048,8 @@ class HTTPError(Exception): self.log_message = log_message self.args = args self.reason = kwargs.get('reason', None) + if log_message and not args: + self.log_message = log_message.replace('%', '%%') def __str__(self): message = "HTTP %d: %s" % ( @@ -2212,7 +2310,8 @@ class StaticFileHandler(RequestHandler): if content_type: self.set_header("Content-Type", content_type) - cache_time = self.get_cache_time(self.path, self.modified, content_type) + cache_time = self.get_cache_time(self.path, self.modified, + content_type) if cache_time > 0: self.set_header("Expires", datetime.datetime.utcnow() + datetime.timedelta(seconds=cache_time)) @@ -2381,7 +2480,8 @@ class StaticFileHandler(RequestHandler): .. versionadded:: 3.1 """ stat_result = self._stat() - modified = datetime.datetime.utcfromtimestamp(stat_result[stat.ST_MTIME]) + modified = datetime.datetime.utcfromtimestamp( + stat_result[stat.ST_MTIME]) return modified def get_content_type(self): @@ -2624,6 +2724,8 @@ class UIModule(object): UI modules often execute additional queries, and they can include additional CSS and JavaScript that will be included in the output page, which is automatically inserted on page render. + + Subclasses of UIModule must override the `render` method. 
""" def __init__(self, handler): self.handler = handler @@ -2636,31 +2738,45 @@ class UIModule(object): return self.handler.current_user def render(self, *args, **kwargs): - """Overridden in subclasses to return this module's output.""" + """Override in subclasses to return this module's output.""" raise NotImplementedError() def embedded_javascript(self): - """Returns a JavaScript string that will be embedded in the page.""" + """Override to return a JavaScript string + to be embedded in the page.""" return None def javascript_files(self): - """Returns a list of JavaScript files required by this module.""" + """Override to return a list of JavaScript files needed by this module. + + If the return values are relative paths, they will be passed to + `RequestHandler.static_url`; otherwise they will be used as-is. + """ return None def embedded_css(self): - """Returns a CSS string that will be embedded in the page.""" + """Override to return a CSS string + that will be embedded in the page.""" return None def css_files(self): - """Returns a list of CSS files required by this module.""" + """Override to returns a list of CSS files required by this module. + + If the return values are relative paths, they will be passed to + `RequestHandler.static_url`; otherwise they will be used as-is. + """ return None def html_head(self): - """Returns a CSS string that will be put in the element""" + """Override to return an HTML string that will be put in the + element. + """ return None def html_body(self): - """Returns an HTML string that will be put in the element""" + """Override to return an HTML string that will be put at the end of + the element. 
+ """ return None def render_string(self, path, **kwargs): @@ -2862,11 +2978,13 @@ else: return result == 0 -def create_signed_value(secret, name, value, version=None, clock=None): +def create_signed_value(secret, name, value, version=None, clock=None, + key_version=None): if version is None: version = DEFAULT_SIGNED_VALUE_VERSION if clock is None: clock = time.time + timestamp = utf8(str(int(clock()))) value = base64.b64encode(utf8(value)) if version == 1: @@ -2883,7 +3001,7 @@ def create_signed_value(secret, name, value, version=None, clock=None): # # The fields are: # - format version (i.e. 2; no length prefix) - # - key version (currently 0; reserved for future key rotation features) + # - key version (integer, default is 0) # - timestamp (integer seconds since epoch) # - name (not encoded; assumed to be ~alphanumeric) # - value (base64-encoded) @@ -2891,34 +3009,32 @@ def create_signed_value(secret, name, value, version=None, clock=None): def format_field(s): return utf8("%d:" % len(s)) + utf8(s) to_sign = b"|".join([ - b"2|1:0", + b"2", + format_field(str(key_version or 0)), format_field(timestamp), format_field(name), format_field(value), b'']) + + if isinstance(secret, dict): + assert key_version is not None, 'Key version must be set when sign key dict is used' + assert version >= 2, 'Version must be at least 2 for key version support' + secret = secret[key_version] + signature = _create_signature_v2(secret, to_sign) return to_sign + signature else: raise ValueError("Unsupported version %d" % version) -# A leading version number in decimal with no leading zeros, followed by a pipe. +# A leading version number in decimal +# with no leading zeros, followed by a pipe. 
_signed_value_version_re = re.compile(br"^([1-9][0-9]*)\|(.*)$") -def decode_signed_value(secret, name, value, max_age_days=31, clock=None, min_version=None): - if clock is None: - clock = time.time - if min_version is None: - min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION - if min_version > 2: - raise ValueError("Unsupported min_version %d" % min_version) - if not value: - return None - - # Figure out what version this is. Version 1 did not include an +def _get_version(value): + # Figures out what version value is. Version 1 did not include an # explicit version field and started with arbitrary base64 data, # which makes this tricky. - value = utf8(value) m = _signed_value_version_re.match(value) if m is None: version = 1 @@ -2935,13 +3051,31 @@ def decode_signed_value(secret, name, value, max_age_days=31, clock=None, min_ve version = 1 except ValueError: version = 1 + return version + + +def decode_signed_value(secret, name, value, max_age_days=31, + clock=None, min_version=None): + if clock is None: + clock = time.time + if min_version is None: + min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION + if min_version > 2: + raise ValueError("Unsupported min_version %d" % min_version) + if not value: + return None + + value = utf8(value) + version = _get_version(value) if version < min_version: return None if version == 1: - return _decode_signed_value_v1(secret, name, value, max_age_days, clock) + return _decode_signed_value_v1(secret, name, value, + max_age_days, clock) elif version == 2: - return _decode_signed_value_v2(secret, name, value, max_age_days, clock) + return _decode_signed_value_v2(secret, name, value, + max_age_days, clock) else: return None @@ -2964,7 +3098,8 @@ def _decode_signed_value_v1(secret, name, value, max_age_days, clock): # digits from the payload to the timestamp without altering the # signature. For backwards compatibility, sanity-check timestamp # here instead of modifying _cookie_signature. 
- gen_log.warning("Cookie timestamp in future; possible tampering %r", value) + gen_log.warning("Cookie timestamp in future; possible tampering %r", + value) return None if parts[1].startswith(b"0"): gen_log.warning("Tampered cookie %r", value) @@ -2975,7 +3110,7 @@ def _decode_signed_value_v1(secret, name, value, max_age_days, clock): return None -def _decode_signed_value_v2(secret, name, value, max_age_days, clock): +def _decode_fields_v2(value): def _consume_field(s): length, _, rest = s.partition(b':') n = int(length) @@ -2986,16 +3121,28 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock): raise ValueError("malformed v2 signed value field") rest = rest[n + 1:] return field_value, rest + rest = value[2:] # remove version number + key_version, rest = _consume_field(rest) + timestamp, rest = _consume_field(rest) + name_field, rest = _consume_field(rest) + value_field, passed_sig = _consume_field(rest) + return int(key_version), timestamp, name_field, value_field, passed_sig + + +def _decode_signed_value_v2(secret, name, value, max_age_days, clock): try: - key_version, rest = _consume_field(rest) - timestamp, rest = _consume_field(rest) - name_field, rest = _consume_field(rest) - value_field, rest = _consume_field(rest) + key_version, timestamp, name_field, value_field, passed_sig = _decode_fields_v2(value) except ValueError: return None - passed_sig = rest signed_string = value[:-len(passed_sig)] + + if isinstance(secret, dict): + try: + secret = secret[key_version] + except KeyError: + return None + expected_sig = _create_signature_v2(secret, signed_string) if not _time_independent_equals(passed_sig, expected_sig): return None @@ -3011,6 +3158,19 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock): return None +def get_signature_key_version(value): + value = utf8(value) + version = _get_version(value) + if version < 2: + return None + try: + key_version, _, _, _, _ = _decode_fields_v2(value) + except ValueError: + return 
None + + return key_version + + def _create_signature_v1(secret, *parts): hash = hmac.new(utf8(secret), digestmod=hashlib.sha1) for part in parts: diff --git a/tornado/websocket.py b/tornado/websocket.py index d960b0e4..adf238be 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -16,7 +16,8 @@ the protocol (known as "draft 76") and are not compatible with this module. Removed support for the draft 76 protocol version. """ -from __future__ import absolute_import, division, print_function, with_statement +from __future__ import (absolute_import, division, + print_function, with_statement) # Author: Jacob Kristhammar, 2010 import base64 @@ -39,9 +40,9 @@ from tornado.tcpclient import TCPClient from tornado.util import _websocket_mask try: - from urllib.parse import urlparse # py2 + from urllib.parse import urlparse # py2 except ImportError: - from urlparse import urlparse # py3 + from urlparse import urlparse # py3 try: xrange # py2 @@ -74,17 +75,22 @@ class WebSocketHandler(tornado.web.RequestHandler): http://tools.ietf.org/html/rfc6455. Here is an example WebSocket handler that echos back all received messages - back to the client:: + back to the client: - class EchoWebSocket(websocket.WebSocketHandler): + .. testcode:: + + class EchoWebSocket(tornado.websocket.WebSocketHandler): def open(self): - print "WebSocket opened" + print("WebSocket opened") def on_message(self, message): self.write_message(u"You said: " + message) def on_close(self): - print "WebSocket closed" + print("WebSocket closed") + + .. testoutput:: + :hide: WebSockets are not standard HTTP connections. 
The "handshake" is HTTP, but after the handshake, the protocol is @@ -129,6 +135,7 @@ class WebSocketHandler(tornado.web.RequestHandler): self.close_code = None self.close_reason = None self.stream = None + self._on_close_called = False @tornado.web.asynchronous def get(self, *args, **kwargs): @@ -138,16 +145,22 @@ class WebSocketHandler(tornado.web.RequestHandler): # Upgrade header should be present and should be equal to WebSocket if self.request.headers.get("Upgrade", "").lower() != 'websocket': self.set_status(400) - self.finish("Can \"Upgrade\" only to \"WebSocket\".") + log_msg = "Can \"Upgrade\" only to \"WebSocket\"." + self.finish(log_msg) + gen_log.debug(log_msg) return - # Connection header should be upgrade. Some proxy servers/load balancers + # Connection header should be upgrade. + # Some proxy servers/load balancers # might mess with it. headers = self.request.headers - connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(",")) + connection = map(lambda s: s.strip().lower(), + headers.get("Connection", "").split(",")) if 'upgrade' not in connection: self.set_status(400) - self.finish("\"Connection\" must be \"Upgrade\".") + log_msg = "\"Connection\" must be \"Upgrade\"." + self.finish(log_msg) + gen_log.debug(log_msg) return # Handle WebSocket Origin naming convention differences @@ -159,30 +172,29 @@ class WebSocketHandler(tornado.web.RequestHandler): else: origin = self.request.headers.get("Sec-Websocket-Origin", None) - # If there was an origin header, check to make sure it matches # according to check_origin. When the origin is None, we assume it # did not come from a browser and that it can be passed on. 
if origin is not None and not self.check_origin(origin): self.set_status(403) - self.finish("Cross origin websockets not allowed") + log_msg = "Cross origin websockets not allowed" + self.finish(log_msg) + gen_log.debug(log_msg) return self.stream = self.request.connection.detach() self.stream.set_close_callback(self.on_connection_close) - if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"): - self.ws_connection = WebSocketProtocol13( - self, compression_options=self.get_compression_options()) + self.ws_connection = self.get_websocket_protocol() + if self.ws_connection: self.ws_connection.accept_connection() else: if not self.stream.closed(): self.stream.write(tornado.escape.utf8( "HTTP/1.1 426 Upgrade Required\r\n" - "Sec-WebSocket-Version: 8\r\n\r\n")) + "Sec-WebSocket-Version: 7, 8, 13\r\n\r\n")) self.stream.close() - def write_message(self, message, binary=False): """Sends the given message to the client of this Web Socket. @@ -229,7 +241,7 @@ class WebSocketHandler(tornado.web.RequestHandler): """ return None - def open(self): + def open(self, *args, **kwargs): """Invoked when a new WebSocket is opened. The arguments to `open` are extracted from the `tornado.web.URLSpec` @@ -350,6 +362,8 @@ class WebSocketHandler(tornado.web.RequestHandler): if self.ws_connection: self.ws_connection.on_connection_close() self.ws_connection = None + if not self._on_close_called: + self._on_close_called = True self.on_close() def send_error(self, *args, **kwargs): @@ -362,6 +376,13 @@ class WebSocketHandler(tornado.web.RequestHandler): # we can close the connection more gracefully. 
self.stream.close() + def get_websocket_protocol(self): + websocket_version = self.request.headers.get("Sec-WebSocket-Version") + if websocket_version in ("7", "8", "13"): + return WebSocketProtocol13( + self, compression_options=self.get_compression_options()) + + def _wrap_method(method): def _disallow_for_websocket(self, *args, **kwargs): if self.stream is None: @@ -499,7 +520,8 @@ class WebSocketProtocol13(WebSocketProtocol): self._handle_websocket_headers() self._accept_connection() except ValueError: - gen_log.debug("Malformed WebSocket request received", exc_info=True) + gen_log.debug("Malformed WebSocket request received", + exc_info=True) self._abort() return @@ -535,18 +557,19 @@ class WebSocketProtocol13(WebSocketProtocol): selected = self.handler.select_subprotocol(subprotocols) if selected: assert selected in subprotocols - subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected + subprotocol_header = ("Sec-WebSocket-Protocol: %s\r\n" + % selected) extension_header = '' extensions = self._parse_extensions_header(self.request.headers) for ext in extensions: if (ext[0] == 'permessage-deflate' and - self._compression_options is not None): + self._compression_options is not None): # TODO: negotiate parameters if compression_options # specifies limits. self._create_compressors('server', ext[1]) if ('client_max_window_bits' in ext[1] and - ext[1]['client_max_window_bits'] is None): + ext[1]['client_max_window_bits'] is None): # Don't echo an offered client_max_window_bits # parameter with no value. 
del ext[1]['client_max_window_bits'] @@ -591,7 +614,7 @@ class WebSocketProtocol13(WebSocketProtocol): extensions = self._parse_extensions_header(headers) for ext in extensions: if (ext[0] == 'permessage-deflate' and - self._compression_options is not None): + self._compression_options is not None): self._create_compressors('client', ext[1]) else: raise ValueError("unsupported extension %r", ext) @@ -703,7 +726,8 @@ class WebSocketProtocol13(WebSocketProtocol): if self._masked_frame: self.stream.read_bytes(4, self._on_masking_key) else: - self.stream.read_bytes(self._frame_length, self._on_frame_data) + self.stream.read_bytes(self._frame_length, + self._on_frame_data) elif payloadlen == 126: self.stream.read_bytes(2, self._on_frame_length_16) elif payloadlen == 127: @@ -737,7 +761,8 @@ class WebSocketProtocol13(WebSocketProtocol): self._wire_bytes_in += len(data) self._frame_mask = data try: - self.stream.read_bytes(self._frame_length, self._on_masked_frame_data) + self.stream.read_bytes(self._frame_length, + self._on_masked_frame_data) except StreamClosedError: self._abort() @@ -852,12 +877,15 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): This class should not be instantiated directly; use the `websocket_connect` function instead. 
""" - def __init__(self, io_loop, request, compression_options=None): + def __init__(self, io_loop, request, on_message_callback=None, + compression_options=None): self.compression_options = compression_options self.connect_future = TracebackFuture() + self.protocol = None self.read_future = None self.read_queue = collections.deque() self.key = base64.b64encode(os.urandom(16)) + self._on_message_callback = on_message_callback scheme, sep, rest = request.url.partition(':') scheme = {'ws': 'http', 'wss': 'https'}[scheme] @@ -880,7 +908,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): self.tcp_client = TCPClient(io_loop=io_loop) super(WebSocketClientConnection, self).__init__( io_loop, None, request, lambda: None, self._on_http_response, - 104857600, self.tcp_client, 65536) + 104857600, self.tcp_client, 65536, 104857600) def close(self, code=None, reason=None): """Closes the websocket connection. @@ -919,9 +947,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): start_line, headers) self.headers = headers - self.protocol = WebSocketProtocol13( - self, mask_outgoing=True, - compression_options=self.compression_options) + self.protocol = self.get_websocket_protocol() self.protocol._process_server_headers(self.key, self.headers) self.protocol._receive_frame() @@ -946,6 +972,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): def read_message(self, callback=None): """Reads a message from the WebSocket server. + If on_message_callback was specified at WebSocket + initialization, this function will never return messages + Returns a future whose result is the message, or None if the connection is closed. 
If a callback argument is given it will be called with the future when it is @@ -962,7 +991,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): return future def on_message(self, message): - if self.read_future is not None: + if self._on_message_callback: + self._on_message_callback(message) + elif self.read_future is not None: self.read_future.set_result(message) self.read_future = None else: @@ -971,9 +1002,13 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): def on_pong(self, data): pass + def get_websocket_protocol(self): + return WebSocketProtocol13(self, mask_outgoing=True, + compression_options=self.compression_options) + def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None, - compression_options=None): + on_message_callback=None, compression_options=None): """Client-side websocket support. Takes a url and returns a Future whose result is a @@ -982,11 +1017,26 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None, ``compression_options`` is interpreted in the same way as the return value of `.WebSocketHandler.get_compression_options`. + The connection supports two styles of operation. In the coroutine + style, the application typically calls + `~.WebSocketClientConnection.read_message` in a loop:: + + conn = yield websocket_connection(loop) + while True: + msg = yield conn.read_message() + if msg is None: break + # Do something with msg + + In the callback style, pass an ``on_message_callback`` to + ``websocket_connect``. In both styles, a message of ``None`` + indicates that the connection has been closed. + .. versionchanged:: 3.2 Also accepts ``HTTPRequest`` objects in place of urls. .. versionchanged:: 4.1 - Added ``compression_options``. + Added ``compression_options`` and ``on_message_callback``. + The ``io_loop`` argument is deprecated. 
""" if io_loop is None: io_loop = IOLoop.current() @@ -1000,7 +1050,9 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None, request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout) request = httpclient._RequestProxy( request, httpclient.HTTPRequest._DEFAULTS) - conn = WebSocketClientConnection(io_loop, request, compression_options) + conn = WebSocketClientConnection(io_loop, request, + on_message_callback=on_message_callback, + compression_options=compression_options) if callback is not None: io_loop.add_future(conn.connect_future, callback) return conn.connect_future diff --git a/tornado/wsgi.py b/tornado/wsgi.py index f3aa6650..59e6c559 100644 --- a/tornado/wsgi.py +++ b/tornado/wsgi.py @@ -207,7 +207,7 @@ class WSGIAdapter(object): body = environ["wsgi.input"].read( int(headers["Content-Length"])) else: - body = "" + body = b"" protocol = environ["wsgi.url_scheme"] remote_ip = environ.get("REMOTE_ADDR", "") if environ.get("HTTP_HOST"): @@ -253,7 +253,7 @@ class WSGIContainer(object): container = tornado.wsgi.WSGIContainer(simple_app) http_server = tornado.httpserver.HTTPServer(container) http_server.listen(8888) - tornado.ioloop.IOLoop.instance().start() + tornado.ioloop.IOLoop.current().start() This class is intended to let other frameworks (Django, web.py, etc) run on the Tornado HTTP server and I/O loop. 
@@ -284,7 +284,8 @@ class WSGIContainer(object): if not data: raise Exception("WSGI app did not call start_response") - status_code = int(data["status"].split()[0]) + status_code, reason = data["status"].split(' ', 1) + status_code = int(status_code) headers = data["headers"] header_set = set(k.lower() for (k, v) in headers) body = escape.utf8(body) @@ -296,13 +297,12 @@ class WSGIContainer(object): if "server" not in header_set: headers.append(("Server", "TornadoServer/%s" % tornado.version)) - parts = [escape.utf8("HTTP/1.1 " + data["status"] + "\r\n")] + start_line = httputil.ResponseStartLine("HTTP/1.1", status_code, reason) + header_obj = httputil.HTTPHeaders() for key, value in headers: - parts.append(escape.utf8(key) + b": " + escape.utf8(value) + b"\r\n") - parts.append(b"\r\n") - parts.append(body) - request.write(b"".join(parts)) - request.finish() + header_obj.add(key, value) + request.connection.write_headers(start_line, header_obj, chunk=body) + request.connection.finish() self._log(status_code, request) @staticmethod