mirror of
https://github.com/SickGear/SickGear.git
synced 2025-01-22 01:23:43 +00:00
Update Tornado webserver to 4.2.dev1 (609dbb9).
Conflicts: CHANGES.md
This commit is contained in:
parent
bed370f811
commit
84fb3e5df9
65 changed files with 4882 additions and 1164 deletions
|
@ -1,5 +1,6 @@
|
|||
### 0.x.x (2015-xx-xx xx:xx:xx UTC)
|
||||
|
||||
* Update Tornado webserver to 4.2.dev1 (609dbb9)
|
||||
* Change network names to only display on top line of Day by Day layout on Episode View
|
||||
* Reposition country part of network name into the hover over in Day by Day layout
|
||||
* Add ToTV provider
|
||||
|
|
|
@ -25,5 +25,5 @@ from __future__ import absolute_import, division, print_function, with_statement
|
|||
# is zero for an official release, positive for a development branch,
|
||||
# or negative for a release candidate or beta (after the base version
|
||||
# number has been incremented)
|
||||
version = "4.1.dev1"
|
||||
version_info = (4, 1, 0, -100)
|
||||
version = "4.2.dev1"
|
||||
version_info = (4, 2, 0, -100)
|
||||
|
|
488
tornado/auth.py
488
tornado/auth.py
|
@ -32,7 +32,9 @@ They all take slightly different arguments due to the fact all these
|
|||
services implement authentication and authorization slightly differently.
|
||||
See the individual service classes below for complete documentation.
|
||||
|
||||
Example usage for Google OpenID::
|
||||
Example usage for Google OAuth:
|
||||
|
||||
.. testcode::
|
||||
|
||||
class GoogleOAuth2LoginHandler(tornado.web.RequestHandler,
|
||||
tornado.auth.GoogleOAuth2Mixin):
|
||||
|
@ -51,6 +53,10 @@ Example usage for Google OpenID::
|
|||
response_type='code',
|
||||
extra_params={'approval_prompt': 'auto'})
|
||||
|
||||
.. testoutput::
|
||||
:hide:
|
||||
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
All of the callback interfaces in this module are now guaranteed
|
||||
to run their callback with an argument of ``None`` on error.
|
||||
|
@ -69,7 +75,7 @@ import hmac
|
|||
import time
|
||||
import uuid
|
||||
|
||||
from tornado.concurrent import TracebackFuture, chain_future, return_future
|
||||
from tornado.concurrent import TracebackFuture, return_future
|
||||
from tornado import gen
|
||||
from tornado import httpclient
|
||||
from tornado import escape
|
||||
|
@ -123,6 +129,7 @@ def _auth_return_future(f):
|
|||
if callback is not None:
|
||||
future.add_done_callback(
|
||||
functools.partial(_auth_future_to_callback, callback))
|
||||
|
||||
def handle_exception(typ, value, tb):
|
||||
if future.done():
|
||||
return False
|
||||
|
@ -138,9 +145,6 @@ def _auth_return_future(f):
|
|||
class OpenIdMixin(object):
|
||||
"""Abstract implementation of OpenID and Attribute Exchange.
|
||||
|
||||
See `GoogleMixin` below for a customized example (which also
|
||||
includes OAuth support).
|
||||
|
||||
Class attributes:
|
||||
|
||||
* ``_OPENID_ENDPOINT``: the identity provider's URI.
|
||||
|
@ -312,8 +316,7 @@ class OpenIdMixin(object):
|
|||
class OAuthMixin(object):
|
||||
"""Abstract implementation of OAuth 1.0 and 1.0a.
|
||||
|
||||
See `TwitterMixin` and `FriendFeedMixin` below for example implementations,
|
||||
or `GoogleMixin` for an OAuth/OpenID hybrid.
|
||||
See `TwitterMixin` below for an example implementation.
|
||||
|
||||
Class attributes:
|
||||
|
||||
|
@ -565,7 +568,8 @@ class OAuthMixin(object):
|
|||
class OAuth2Mixin(object):
|
||||
"""Abstract implementation of OAuth 2.0.
|
||||
|
||||
See `FacebookGraphMixin` below for an example implementation.
|
||||
See `FacebookGraphMixin` or `GoogleOAuth2Mixin` below for example
|
||||
implementations.
|
||||
|
||||
Class attributes:
|
||||
|
||||
|
@ -629,7 +633,9 @@ class TwitterMixin(OAuthMixin):
|
|||
URL you registered as your application's callback URL.
|
||||
|
||||
When your application is set up, you can use this mixin like this
|
||||
to authenticate the user with Twitter and get access to their stream::
|
||||
to authenticate the user with Twitter and get access to their stream:
|
||||
|
||||
.. testcode::
|
||||
|
||||
class TwitterLoginHandler(tornado.web.RequestHandler,
|
||||
tornado.auth.TwitterMixin):
|
||||
|
@ -641,6 +647,9 @@ class TwitterMixin(OAuthMixin):
|
|||
else:
|
||||
yield self.authorize_redirect()
|
||||
|
||||
.. testoutput::
|
||||
:hide:
|
||||
|
||||
The user object returned by `~OAuthMixin.get_authenticated_user`
|
||||
includes the attributes ``username``, ``name``, ``access_token``,
|
||||
and all of the custom Twitter user attributes described at
|
||||
|
@ -689,7 +698,9 @@ class TwitterMixin(OAuthMixin):
|
|||
`~OAuthMixin.get_authenticated_user`. The user returned through that
|
||||
process includes an 'access_token' attribute that can be used
|
||||
to make authenticated requests via this method. Example
|
||||
usage::
|
||||
usage:
|
||||
|
||||
.. testcode::
|
||||
|
||||
class MainHandler(tornado.web.RequestHandler,
|
||||
tornado.auth.TwitterMixin):
|
||||
|
@ -706,6 +717,9 @@ class TwitterMixin(OAuthMixin):
|
|||
return
|
||||
self.finish("Posted a message!")
|
||||
|
||||
.. testoutput::
|
||||
:hide:
|
||||
|
||||
"""
|
||||
if path.startswith('http:') or path.startswith('https:'):
|
||||
# Raw urls are useful for e.g. search which doesn't follow the
|
||||
|
@ -757,223 +771,6 @@ class TwitterMixin(OAuthMixin):
|
|||
raise gen.Return(user)
|
||||
|
||||
|
||||
class FriendFeedMixin(OAuthMixin):
|
||||
"""FriendFeed OAuth authentication.
|
||||
|
||||
To authenticate with FriendFeed, register your application with
|
||||
FriendFeed at http://friendfeed.com/api/applications. Then copy
|
||||
your Consumer Key and Consumer Secret to the application
|
||||
`~tornado.web.Application.settings` ``friendfeed_consumer_key``
|
||||
and ``friendfeed_consumer_secret``. Use this mixin on the handler
|
||||
for the URL you registered as your application's Callback URL.
|
||||
|
||||
When your application is set up, you can use this mixin like this
|
||||
to authenticate the user with FriendFeed and get access to their feed::
|
||||
|
||||
class FriendFeedLoginHandler(tornado.web.RequestHandler,
|
||||
tornado.auth.FriendFeedMixin):
|
||||
@tornado.gen.coroutine
|
||||
def get(self):
|
||||
if self.get_argument("oauth_token", None):
|
||||
user = yield self.get_authenticated_user()
|
||||
# Save the user using e.g. set_secure_cookie()
|
||||
else:
|
||||
yield self.authorize_redirect()
|
||||
|
||||
The user object returned by `~OAuthMixin.get_authenticated_user()` includes the
|
||||
attributes ``username``, ``name``, and ``description`` in addition to
|
||||
``access_token``. You should save the access token with the user;
|
||||
it is required to make requests on behalf of the user later with
|
||||
`friendfeed_request()`.
|
||||
"""
|
||||
_OAUTH_VERSION = "1.0"
|
||||
_OAUTH_REQUEST_TOKEN_URL = "https://friendfeed.com/account/oauth/request_token"
|
||||
_OAUTH_ACCESS_TOKEN_URL = "https://friendfeed.com/account/oauth/access_token"
|
||||
_OAUTH_AUTHORIZE_URL = "https://friendfeed.com/account/oauth/authorize"
|
||||
_OAUTH_NO_CALLBACKS = True
|
||||
_OAUTH_VERSION = "1.0"
|
||||
|
||||
@_auth_return_future
|
||||
def friendfeed_request(self, path, callback, access_token=None,
|
||||
post_args=None, **args):
|
||||
"""Fetches the given relative API path, e.g., "/bret/friends"
|
||||
|
||||
If the request is a POST, ``post_args`` should be provided. Query
|
||||
string arguments should be given as keyword arguments.
|
||||
|
||||
All the FriendFeed methods are documented at
|
||||
http://friendfeed.com/api/documentation.
|
||||
|
||||
Many methods require an OAuth access token which you can
|
||||
obtain through `~OAuthMixin.authorize_redirect` and
|
||||
`~OAuthMixin.get_authenticated_user`. The user returned
|
||||
through that process includes an ``access_token`` attribute that
|
||||
can be used to make authenticated requests via this
|
||||
method.
|
||||
|
||||
Example usage::
|
||||
|
||||
class MainHandler(tornado.web.RequestHandler,
|
||||
tornado.auth.FriendFeedMixin):
|
||||
@tornado.web.authenticated
|
||||
@tornado.gen.coroutine
|
||||
def get(self):
|
||||
new_entry = yield self.friendfeed_request(
|
||||
"/entry",
|
||||
post_args={"body": "Testing Tornado Web Server"},
|
||||
access_token=self.current_user["access_token"])
|
||||
|
||||
if not new_entry:
|
||||
# Call failed; perhaps missing permission?
|
||||
yield self.authorize_redirect()
|
||||
return
|
||||
self.finish("Posted a message!")
|
||||
|
||||
"""
|
||||
# Add the OAuth resource request signature if we have credentials
|
||||
url = "http://friendfeed-api.com/v2" + path
|
||||
if access_token:
|
||||
all_args = {}
|
||||
all_args.update(args)
|
||||
all_args.update(post_args or {})
|
||||
method = "POST" if post_args is not None else "GET"
|
||||
oauth = self._oauth_request_parameters(
|
||||
url, access_token, all_args, method=method)
|
||||
args.update(oauth)
|
||||
if args:
|
||||
url += "?" + urllib_parse.urlencode(args)
|
||||
callback = functools.partial(self._on_friendfeed_request, callback)
|
||||
http = self.get_auth_http_client()
|
||||
if post_args is not None:
|
||||
http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args),
|
||||
callback=callback)
|
||||
else:
|
||||
http.fetch(url, callback=callback)
|
||||
|
||||
def _on_friendfeed_request(self, future, response):
|
||||
if response.error:
|
||||
future.set_exception(AuthError(
|
||||
"Error response %s fetching %s" % (response.error,
|
||||
response.request.url)))
|
||||
return
|
||||
future.set_result(escape.json_decode(response.body))
|
||||
|
||||
def _oauth_consumer_token(self):
|
||||
self.require_setting("friendfeed_consumer_key", "FriendFeed OAuth")
|
||||
self.require_setting("friendfeed_consumer_secret", "FriendFeed OAuth")
|
||||
return dict(
|
||||
key=self.settings["friendfeed_consumer_key"],
|
||||
secret=self.settings["friendfeed_consumer_secret"])
|
||||
|
||||
@gen.coroutine
|
||||
def _oauth_get_user_future(self, access_token, callback):
|
||||
user = yield self.friendfeed_request(
|
||||
"/feedinfo/" + access_token["username"],
|
||||
include="id,name,description", access_token=access_token)
|
||||
if user:
|
||||
user["username"] = user["id"]
|
||||
callback(user)
|
||||
|
||||
def _parse_user_response(self, callback, user):
|
||||
if user:
|
||||
user["username"] = user["id"]
|
||||
callback(user)
|
||||
|
||||
|
||||
class GoogleMixin(OpenIdMixin, OAuthMixin):
|
||||
"""Google Open ID / OAuth authentication.
|
||||
|
||||
.. deprecated:: 4.0
|
||||
New applications should use `GoogleOAuth2Mixin`
|
||||
below instead of this class. As of May 19, 2014, Google has stopped
|
||||
supporting registration-free authentication.
|
||||
|
||||
No application registration is necessary to use Google for
|
||||
authentication or to access Google resources on behalf of a user.
|
||||
|
||||
Google implements both OpenID and OAuth in a hybrid mode. If you
|
||||
just need the user's identity, use
|
||||
`~OpenIdMixin.authenticate_redirect`. If you need to make
|
||||
requests to Google on behalf of the user, use
|
||||
`authorize_redirect`. On return, parse the response with
|
||||
`~OpenIdMixin.get_authenticated_user`. We send a dict containing
|
||||
the values for the user, including ``email``, ``name``, and
|
||||
``locale``.
|
||||
|
||||
Example usage::
|
||||
|
||||
class GoogleLoginHandler(tornado.web.RequestHandler,
|
||||
tornado.auth.GoogleMixin):
|
||||
@tornado.gen.coroutine
|
||||
def get(self):
|
||||
if self.get_argument("openid.mode", None):
|
||||
user = yield self.get_authenticated_user()
|
||||
# Save the user with e.g. set_secure_cookie()
|
||||
else:
|
||||
yield self.authenticate_redirect()
|
||||
"""
|
||||
_OPENID_ENDPOINT = "https://www.google.com/accounts/o8/ud"
|
||||
_OAUTH_ACCESS_TOKEN_URL = "https://www.google.com/accounts/OAuthGetAccessToken"
|
||||
|
||||
@return_future
|
||||
def authorize_redirect(self, oauth_scope, callback_uri=None,
|
||||
ax_attrs=["name", "email", "language", "username"],
|
||||
callback=None):
|
||||
"""Authenticates and authorizes for the given Google resource.
|
||||
|
||||
Some of the available resources which can be used in the ``oauth_scope``
|
||||
argument are:
|
||||
|
||||
* Gmail Contacts - http://www.google.com/m8/feeds/
|
||||
* Calendar - http://www.google.com/calendar/feeds/
|
||||
* Finance - http://finance.google.com/finance/feeds/
|
||||
|
||||
You can authorize multiple resources by separating the resource
|
||||
URLs with a space.
|
||||
|
||||
.. versionchanged:: 3.1
|
||||
Returns a `.Future` and takes an optional callback. These are
|
||||
not strictly necessary as this method is synchronous,
|
||||
but they are supplied for consistency with
|
||||
`OAuthMixin.authorize_redirect`.
|
||||
"""
|
||||
callback_uri = callback_uri or self.request.uri
|
||||
args = self._openid_args(callback_uri, ax_attrs=ax_attrs,
|
||||
oauth_scope=oauth_scope)
|
||||
self.redirect(self._OPENID_ENDPOINT + "?" + urllib_parse.urlencode(args))
|
||||
callback()
|
||||
|
||||
@_auth_return_future
|
||||
def get_authenticated_user(self, callback):
|
||||
"""Fetches the authenticated user data upon redirect."""
|
||||
# Look to see if we are doing combined OpenID/OAuth
|
||||
oauth_ns = ""
|
||||
for name, values in self.request.arguments.items():
|
||||
if name.startswith("openid.ns.") and \
|
||||
values[-1] == b"http://specs.openid.net/extensions/oauth/1.0":
|
||||
oauth_ns = name[10:]
|
||||
break
|
||||
token = self.get_argument("openid." + oauth_ns + ".request_token", "")
|
||||
if token:
|
||||
http = self.get_auth_http_client()
|
||||
token = dict(key=token, secret="")
|
||||
http.fetch(self._oauth_access_token_url(token),
|
||||
functools.partial(self._on_access_token, callback))
|
||||
else:
|
||||
chain_future(OpenIdMixin.get_authenticated_user(self),
|
||||
callback)
|
||||
|
||||
def _oauth_consumer_token(self):
|
||||
self.require_setting("google_consumer_key", "Google OAuth")
|
||||
self.require_setting("google_consumer_secret", "Google OAuth")
|
||||
return dict(
|
||||
key=self.settings["google_consumer_key"],
|
||||
secret=self.settings["google_consumer_secret"])
|
||||
|
||||
def _oauth_get_user_future(self, access_token):
|
||||
return OpenIdMixin.get_authenticated_user(self)
|
||||
|
||||
|
||||
class GoogleOAuth2Mixin(OAuth2Mixin):
|
||||
"""Google authentication using OAuth2.
|
||||
|
||||
|
@ -1001,7 +798,9 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
|
|||
def get_authenticated_user(self, redirect_uri, code, callback):
|
||||
"""Handles the login for the Google user, returning a user object.
|
||||
|
||||
Example usage::
|
||||
Example usage:
|
||||
|
||||
.. testcode::
|
||||
|
||||
class GoogleOAuth2LoginHandler(tornado.web.RequestHandler,
|
||||
tornado.auth.GoogleOAuth2Mixin):
|
||||
|
@ -1019,6 +818,10 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
|
|||
scope=['profile', 'email'],
|
||||
response_type='code',
|
||||
extra_params={'approval_prompt': 'auto'})
|
||||
|
||||
.. testoutput::
|
||||
:hide:
|
||||
|
||||
"""
|
||||
http = self.get_auth_http_client()
|
||||
body = urllib_parse.urlencode({
|
||||
|
@ -1051,217 +854,6 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
|
|||
return httpclient.AsyncHTTPClient()
|
||||
|
||||
|
||||
class FacebookMixin(object):
|
||||
"""Facebook Connect authentication.
|
||||
|
||||
.. deprecated:: 1.1
|
||||
New applications should use `FacebookGraphMixin`
|
||||
below instead of this class. This class does not support the
|
||||
Future-based interface seen on other classes in this module.
|
||||
|
||||
To authenticate with Facebook, register your application with
|
||||
Facebook at http://www.facebook.com/developers/apps.php. Then
|
||||
copy your API Key and Application Secret to the application settings
|
||||
``facebook_api_key`` and ``facebook_secret``.
|
||||
|
||||
When your application is set up, you can use this mixin like this
|
||||
to authenticate the user with Facebook::
|
||||
|
||||
class FacebookHandler(tornado.web.RequestHandler,
|
||||
tornado.auth.FacebookMixin):
|
||||
@tornado.web.asynchronous
|
||||
def get(self):
|
||||
if self.get_argument("session", None):
|
||||
self.get_authenticated_user(self._on_auth)
|
||||
return
|
||||
yield self.authenticate_redirect()
|
||||
|
||||
def _on_auth(self, user):
|
||||
if not user:
|
||||
raise tornado.web.HTTPError(500, "Facebook auth failed")
|
||||
# Save the user using, e.g., set_secure_cookie()
|
||||
|
||||
The user object returned by `get_authenticated_user` includes the
|
||||
attributes ``facebook_uid`` and ``name`` in addition to session attributes
|
||||
like ``session_key``. You should save the session key with the user; it is
|
||||
required to make requests on behalf of the user later with
|
||||
`facebook_request`.
|
||||
"""
|
||||
@return_future
|
||||
def authenticate_redirect(self, callback_uri=None, cancel_uri=None,
|
||||
extended_permissions=None, callback=None):
|
||||
"""Authenticates/installs this app for the current user.
|
||||
|
||||
.. versionchanged:: 3.1
|
||||
Returns a `.Future` and takes an optional callback. These are
|
||||
not strictly necessary as this method is synchronous,
|
||||
but they are supplied for consistency with
|
||||
`OAuthMixin.authorize_redirect`.
|
||||
"""
|
||||
self.require_setting("facebook_api_key", "Facebook Connect")
|
||||
callback_uri = callback_uri or self.request.uri
|
||||
args = {
|
||||
"api_key": self.settings["facebook_api_key"],
|
||||
"v": "1.0",
|
||||
"fbconnect": "true",
|
||||
"display": "page",
|
||||
"next": urlparse.urljoin(self.request.full_url(), callback_uri),
|
||||
"return_session": "true",
|
||||
}
|
||||
if cancel_uri:
|
||||
args["cancel_url"] = urlparse.urljoin(
|
||||
self.request.full_url(), cancel_uri)
|
||||
if extended_permissions:
|
||||
if isinstance(extended_permissions, (unicode_type, bytes)):
|
||||
extended_permissions = [extended_permissions]
|
||||
args["req_perms"] = ",".join(extended_permissions)
|
||||
self.redirect("http://www.facebook.com/login.php?" +
|
||||
urllib_parse.urlencode(args))
|
||||
callback()
|
||||
|
||||
def authorize_redirect(self, extended_permissions, callback_uri=None,
|
||||
cancel_uri=None, callback=None):
|
||||
"""Redirects to an authorization request for the given FB resource.
|
||||
|
||||
The available resource names are listed at
|
||||
http://wiki.developers.facebook.com/index.php/Extended_permission.
|
||||
The most common resource types include:
|
||||
|
||||
* publish_stream
|
||||
* read_stream
|
||||
* email
|
||||
* sms
|
||||
|
||||
extended_permissions can be a single permission name or a list of
|
||||
names. To get the session secret and session key, call
|
||||
get_authenticated_user() just as you would with
|
||||
authenticate_redirect().
|
||||
|
||||
.. versionchanged:: 3.1
|
||||
Returns a `.Future` and takes an optional callback. These are
|
||||
not strictly necessary as this method is synchronous,
|
||||
but they are supplied for consistency with
|
||||
`OAuthMixin.authorize_redirect`.
|
||||
"""
|
||||
return self.authenticate_redirect(callback_uri, cancel_uri,
|
||||
extended_permissions,
|
||||
callback=callback)
|
||||
|
||||
def get_authenticated_user(self, callback):
|
||||
"""Fetches the authenticated Facebook user.
|
||||
|
||||
The authenticated user includes the special Facebook attributes
|
||||
'session_key' and 'facebook_uid' in addition to the standard
|
||||
user attributes like 'name'.
|
||||
"""
|
||||
self.require_setting("facebook_api_key", "Facebook Connect")
|
||||
session = escape.json_decode(self.get_argument("session"))
|
||||
self.facebook_request(
|
||||
method="facebook.users.getInfo",
|
||||
callback=functools.partial(
|
||||
self._on_get_user_info, callback, session),
|
||||
session_key=session["session_key"],
|
||||
uids=session["uid"],
|
||||
fields="uid,first_name,last_name,name,locale,pic_square,"
|
||||
"profile_url,username")
|
||||
|
||||
def facebook_request(self, method, callback, **args):
|
||||
"""Makes a Facebook API REST request.
|
||||
|
||||
We automatically include the Facebook API key and signature, but
|
||||
it is the callers responsibility to include 'session_key' and any
|
||||
other required arguments to the method.
|
||||
|
||||
The available Facebook methods are documented here:
|
||||
http://wiki.developers.facebook.com/index.php/API
|
||||
|
||||
Here is an example for the stream.get() method::
|
||||
|
||||
class MainHandler(tornado.web.RequestHandler,
|
||||
tornado.auth.FacebookMixin):
|
||||
@tornado.web.authenticated
|
||||
@tornado.web.asynchronous
|
||||
def get(self):
|
||||
self.facebook_request(
|
||||
method="stream.get",
|
||||
callback=self._on_stream,
|
||||
session_key=self.current_user["session_key"])
|
||||
|
||||
def _on_stream(self, stream):
|
||||
if stream is None:
|
||||
# Not authorized to read the stream yet?
|
||||
self.redirect(self.authorize_redirect("read_stream"))
|
||||
return
|
||||
self.render("stream.html", stream=stream)
|
||||
|
||||
"""
|
||||
self.require_setting("facebook_api_key", "Facebook Connect")
|
||||
self.require_setting("facebook_secret", "Facebook Connect")
|
||||
if not method.startswith("facebook."):
|
||||
method = "facebook." + method
|
||||
args["api_key"] = self.settings["facebook_api_key"]
|
||||
args["v"] = "1.0"
|
||||
args["method"] = method
|
||||
args["call_id"] = str(long(time.time() * 1e6))
|
||||
args["format"] = "json"
|
||||
args["sig"] = self._signature(args)
|
||||
url = "http://api.facebook.com/restserver.php?" + \
|
||||
urllib_parse.urlencode(args)
|
||||
http = self.get_auth_http_client()
|
||||
http.fetch(url, callback=functools.partial(
|
||||
self._parse_response, callback))
|
||||
|
||||
def _on_get_user_info(self, callback, session, users):
|
||||
if users is None:
|
||||
callback(None)
|
||||
return
|
||||
callback({
|
||||
"name": users[0]["name"],
|
||||
"first_name": users[0]["first_name"],
|
||||
"last_name": users[0]["last_name"],
|
||||
"uid": users[0]["uid"],
|
||||
"locale": users[0]["locale"],
|
||||
"pic_square": users[0]["pic_square"],
|
||||
"profile_url": users[0]["profile_url"],
|
||||
"username": users[0].get("username"),
|
||||
"session_key": session["session_key"],
|
||||
"session_expires": session.get("expires"),
|
||||
})
|
||||
|
||||
def _parse_response(self, callback, response):
|
||||
if response.error:
|
||||
gen_log.warning("HTTP error from Facebook: %s", response.error)
|
||||
callback(None)
|
||||
return
|
||||
try:
|
||||
json = escape.json_decode(response.body)
|
||||
except Exception:
|
||||
gen_log.warning("Invalid JSON from Facebook: %r", response.body)
|
||||
callback(None)
|
||||
return
|
||||
if isinstance(json, dict) and json.get("error_code"):
|
||||
gen_log.warning("Facebook error: %d: %r", json["error_code"],
|
||||
json.get("error_msg"))
|
||||
callback(None)
|
||||
return
|
||||
callback(json)
|
||||
|
||||
def _signature(self, args):
|
||||
parts = ["%s=%s" % (n, args[n]) for n in sorted(args.keys())]
|
||||
body = "".join(parts) + self.settings["facebook_secret"]
|
||||
if isinstance(body, unicode_type):
|
||||
body = body.encode("utf-8")
|
||||
return hashlib.md5(body).hexdigest()
|
||||
|
||||
def get_auth_http_client(self):
|
||||
"""Returns the `.AsyncHTTPClient` instance to be used for auth requests.
|
||||
|
||||
May be overridden by subclasses to use an HTTP client other than
|
||||
the default.
|
||||
"""
|
||||
return httpclient.AsyncHTTPClient()
|
||||
|
||||
|
||||
class FacebookGraphMixin(OAuth2Mixin):
|
||||
"""Facebook authentication using the new Graph API and OAuth2."""
|
||||
_OAUTH_ACCESS_TOKEN_URL = "https://graph.facebook.com/oauth/access_token?"
|
||||
|
@ -1274,9 +866,12 @@ class FacebookGraphMixin(OAuth2Mixin):
|
|||
code, callback, extra_fields=None):
|
||||
"""Handles the login for the Facebook user, returning a user object.
|
||||
|
||||
Example usage::
|
||||
Example usage:
|
||||
|
||||
class FacebookGraphLoginHandler(LoginHandler, tornado.auth.FacebookGraphMixin):
|
||||
.. testcode::
|
||||
|
||||
class FacebookGraphLoginHandler(tornado.web.RequestHandler,
|
||||
tornado.auth.FacebookGraphMixin):
|
||||
@tornado.gen.coroutine
|
||||
def get(self):
|
||||
if self.get_argument("code", False):
|
||||
|
@ -1291,6 +886,10 @@ class FacebookGraphMixin(OAuth2Mixin):
|
|||
redirect_uri='/auth/facebookgraph/',
|
||||
client_id=self.settings["facebook_api_key"],
|
||||
extra_params={"scope": "read_stream,offline_access"})
|
||||
|
||||
.. testoutput::
|
||||
:hide:
|
||||
|
||||
"""
|
||||
http = self.get_auth_http_client()
|
||||
args = {
|
||||
|
@ -1358,7 +957,9 @@ class FacebookGraphMixin(OAuth2Mixin):
|
|||
process includes an ``access_token`` attribute that can be
|
||||
used to make authenticated requests via this method.
|
||||
|
||||
Example usage::
|
||||
Example usage:
|
||||
|
||||
..testcode::
|
||||
|
||||
class MainHandler(tornado.web.RequestHandler,
|
||||
tornado.auth.FacebookGraphMixin):
|
||||
|
@ -1376,6 +977,9 @@ class FacebookGraphMixin(OAuth2Mixin):
|
|||
return
|
||||
self.finish("Posted a message!")
|
||||
|
||||
.. testoutput::
|
||||
:hide:
|
||||
|
||||
The given path is relative to ``self._FACEBOOK_BASE_URL``,
|
||||
by default "https://graph.facebook.com".
|
||||
|
||||
|
|
|
@ -100,6 +100,14 @@ try:
|
|||
except ImportError:
|
||||
signal = None
|
||||
|
||||
# os.execv is broken on Windows and can't properly parse command line
|
||||
# arguments and executable name if they contain whitespaces. subprocess
|
||||
# fixes that behavior.
|
||||
# This distinction is also important because when we use execv, we want to
|
||||
# close the IOLoop and all its file descriptors, to guard against any
|
||||
# file descriptors that were not set CLOEXEC. When execv is not available,
|
||||
# we must not close the IOLoop because we want the process to exit cleanly.
|
||||
_has_execv = sys.platform != 'win32'
|
||||
|
||||
_watched_files = set()
|
||||
_reload_hooks = []
|
||||
|
@ -108,13 +116,18 @@ _io_loops = weakref.WeakKeyDictionary()
|
|||
|
||||
|
||||
def start(io_loop=None, check_time=500):
|
||||
"""Begins watching source files for changes using the given `.IOLoop`. """
|
||||
"""Begins watching source files for changes.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
The ``io_loop`` argument is deprecated.
|
||||
"""
|
||||
io_loop = io_loop or ioloop.IOLoop.current()
|
||||
if io_loop in _io_loops:
|
||||
return
|
||||
_io_loops[io_loop] = True
|
||||
if len(_io_loops) > 1:
|
||||
gen_log.warning("tornado.autoreload started more than once in the same process")
|
||||
if _has_execv:
|
||||
add_reload_hook(functools.partial(io_loop.close, all_fds=True))
|
||||
modify_times = {}
|
||||
callback = functools.partial(_reload_on_update, modify_times)
|
||||
|
@ -162,7 +175,7 @@ def _reload_on_update(modify_times):
|
|||
# processes restarted themselves, they'd all restart and then
|
||||
# all call fork_processes again.
|
||||
return
|
||||
for module in sys.modules.values():
|
||||
for module in list(sys.modules.values()):
|
||||
# Some modules play games with sys.modules (e.g. email/__init__.py
|
||||
# in the standard library), and occasionally this can cause strange
|
||||
# failures in getattr. Just ignore anything that's not an ordinary
|
||||
|
@ -211,10 +224,7 @@ def _reload():
|
|||
not os.environ.get("PYTHONPATH", "").startswith(path_prefix)):
|
||||
os.environ["PYTHONPATH"] = (path_prefix +
|
||||
os.environ.get("PYTHONPATH", ""))
|
||||
if sys.platform == 'win32':
|
||||
# os.execv is broken on Windows and can't properly parse command line
|
||||
# arguments and executable name if they contain whitespaces. subprocess
|
||||
# fixes that behavior.
|
||||
if not _has_execv:
|
||||
subprocess.Popen([sys.executable] + sys.argv)
|
||||
sys.exit(0)
|
||||
else:
|
||||
|
@ -234,7 +244,10 @@ def _reload():
|
|||
# this error specifically.
|
||||
os.spawnv(os.P_NOWAIT, sys.executable,
|
||||
[sys.executable] + sys.argv)
|
||||
sys.exit(0)
|
||||
# At this point the IOLoop has been closed and finally
|
||||
# blocks will experience errors if we allow the stack to
|
||||
# unwind, so just exit uncleanly.
|
||||
os._exit(0)
|
||||
|
||||
_USAGE = """\
|
||||
Usage:
|
||||
|
|
|
@ -25,11 +25,13 @@ module.
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
|
||||
import functools
|
||||
import platform
|
||||
import traceback
|
||||
import sys
|
||||
|
||||
from tornado.log import app_log
|
||||
from tornado.stack_context import ExceptionStackContext, wrap
|
||||
from tornado.util import raise_exc_info, ArgReplacer
|
||||
from tornado.log import app_log
|
||||
|
||||
try:
|
||||
from concurrent import futures
|
||||
|
@ -37,9 +39,90 @@ except ImportError:
|
|||
futures = None
|
||||
|
||||
|
||||
# Can the garbage collector handle cycles that include __del__ methods?
|
||||
# This is true in cpython beginning with version 3.4 (PEP 442).
|
||||
_GC_CYCLE_FINALIZERS = (platform.python_implementation() == 'CPython' and
|
||||
sys.version_info >= (3, 4))
|
||||
|
||||
|
||||
class ReturnValueIgnoredError(Exception):
|
||||
pass
|
||||
|
||||
# This class and associated code in the future object is derived
|
||||
# from the Trollius project, a backport of asyncio to Python 2.x - 3.x
|
||||
|
||||
|
||||
class _TracebackLogger(object):
|
||||
"""Helper to log a traceback upon destruction if not cleared.
|
||||
|
||||
This solves a nasty problem with Futures and Tasks that have an
|
||||
exception set: if nobody asks for the exception, the exception is
|
||||
never logged. This violates the Zen of Python: 'Errors should
|
||||
never pass silently. Unless explicitly silenced.'
|
||||
|
||||
However, we don't want to log the exception as soon as
|
||||
set_exception() is called: if the calling code is written
|
||||
properly, it will get the exception and handle it properly. But
|
||||
we *do* want to log it if result() or exception() was never called
|
||||
-- otherwise developers waste a lot of time wondering why their
|
||||
buggy code fails silently.
|
||||
|
||||
An earlier attempt added a __del__() method to the Future class
|
||||
itself, but this backfired because the presence of __del__()
|
||||
prevents garbage collection from breaking cycles. A way out of
|
||||
this catch-22 is to avoid having a __del__() method on the Future
|
||||
class itself, but instead to have a reference to a helper object
|
||||
with a __del__() method that logs the traceback, where we ensure
|
||||
that the helper object doesn't participate in cycles, and only the
|
||||
Future has a reference to it.
|
||||
|
||||
The helper object is added when set_exception() is called. When
|
||||
the Future is collected, and the helper is present, the helper
|
||||
object is also collected, and its __del__() method will log the
|
||||
traceback. When the Future's result() or exception() method is
|
||||
called (and a helper object is present), it removes the the helper
|
||||
object, after calling its clear() method to prevent it from
|
||||
logging.
|
||||
|
||||
One downside is that we do a fair amount of work to extract the
|
||||
traceback from the exception, even when it is never logged. It
|
||||
would seem cheaper to just store the exception object, but that
|
||||
references the traceback, which references stack frames, which may
|
||||
reference the Future, which references the _TracebackLogger, and
|
||||
then the _TracebackLogger would be included in a cycle, which is
|
||||
what we're trying to avoid! As an optimization, we don't
|
||||
immediately format the exception; we only do the work when
|
||||
activate() is called, which call is delayed until after all the
|
||||
Future's callbacks have run. Since usually a Future has at least
|
||||
one callback (typically set by 'yield From') and usually that
|
||||
callback extracts the callback, thereby removing the need to
|
||||
format the exception.
|
||||
|
||||
PS. I don't claim credit for this solution. I first heard of it
|
||||
in a discussion about closing files when they are collected.
|
||||
"""
|
||||
|
||||
__slots__ = ('exc_info', 'formatted_tb')
|
||||
|
||||
def __init__(self, exc_info):
|
||||
self.exc_info = exc_info
|
||||
self.formatted_tb = None
|
||||
|
||||
def activate(self):
|
||||
exc_info = self.exc_info
|
||||
if exc_info is not None:
|
||||
self.exc_info = None
|
||||
self.formatted_tb = traceback.format_exception(*exc_info)
|
||||
|
||||
def clear(self):
|
||||
self.exc_info = None
|
||||
self.formatted_tb = None
|
||||
|
||||
def __del__(self):
|
||||
if self.formatted_tb:
|
||||
app_log.error('Future exception was never retrieved: %s',
|
||||
''.join(self.formatted_tb).rstrip())
|
||||
|
||||
|
||||
class Future(object):
|
||||
"""Placeholder for an asynchronous result.
|
||||
|
@ -68,12 +151,23 @@ class Future(object):
|
|||
if that package was available and fall back to the thread-unsafe
|
||||
implementation if it was not.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
If a `.Future` contains an error but that error is never observed
|
||||
(by calling ``result()``, ``exception()``, or ``exc_info()``),
|
||||
a stack trace will be logged when the `.Future` is garbage collected.
|
||||
This normally indicates an error in the application, but in cases
|
||||
where it results in undesired logging it may be necessary to
|
||||
suppress the logging by ensuring that the exception is observed:
|
||||
``f.add_done_callback(lambda f: f.exception())``.
|
||||
"""
|
||||
def __init__(self):
|
||||
self._done = False
|
||||
self._result = None
|
||||
self._exception = None
|
||||
self._exc_info = None
|
||||
|
||||
self._log_traceback = False # Used for Python >= 3.4
|
||||
self._tb_logger = None # Used for Python <= 3.3
|
||||
|
||||
self._callbacks = []
|
||||
|
||||
def cancel(self):
|
||||
|
@ -100,16 +194,21 @@ class Future(object):
|
|||
"""Returns True if the future has finished running."""
|
||||
return self._done
|
||||
|
||||
def _clear_tb_log(self):
|
||||
self._log_traceback = False
|
||||
if self._tb_logger is not None:
|
||||
self._tb_logger.clear()
|
||||
self._tb_logger = None
|
||||
|
||||
def result(self, timeout=None):
|
||||
"""If the operation succeeded, return its result. If it failed,
|
||||
re-raise its exception.
|
||||
"""
|
||||
self._clear_tb_log()
|
||||
if self._result is not None:
|
||||
return self._result
|
||||
if self._exc_info is not None:
|
||||
raise_exc_info(self._exc_info)
|
||||
elif self._exception is not None:
|
||||
raise self._exception
|
||||
self._check_done()
|
||||
return self._result
|
||||
|
||||
|
@ -117,8 +216,9 @@ class Future(object):
|
|||
"""If the operation raised an exception, return the `Exception`
|
||||
object. Otherwise returns None.
|
||||
"""
|
||||
if self._exception is not None:
|
||||
return self._exception
|
||||
self._clear_tb_log()
|
||||
if self._exc_info is not None:
|
||||
return self._exc_info[1]
|
||||
else:
|
||||
self._check_done()
|
||||
return None
|
||||
|
@ -147,14 +247,17 @@ class Future(object):
|
|||
|
||||
def set_exception(self, exception):
|
||||
"""Sets the exception of a ``Future.``"""
|
||||
self._exception = exception
|
||||
self._set_done()
|
||||
self.set_exc_info(
|
||||
(exception.__class__,
|
||||
exception,
|
||||
getattr(exception, '__traceback__', None)))
|
||||
|
||||
def exc_info(self):
|
||||
"""Returns a tuple in the same format as `sys.exc_info` or None.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
"""
|
||||
self._clear_tb_log()
|
||||
return self._exc_info
|
||||
|
||||
def set_exc_info(self, exc_info):
|
||||
|
@ -165,7 +268,18 @@ class Future(object):
|
|||
.. versionadded:: 4.0
|
||||
"""
|
||||
self._exc_info = exc_info
|
||||
self.set_exception(exc_info[1])
|
||||
self._log_traceback = True
|
||||
if not _GC_CYCLE_FINALIZERS:
|
||||
self._tb_logger = _TracebackLogger(exc_info)
|
||||
|
||||
try:
|
||||
self._set_done()
|
||||
finally:
|
||||
# Activate the logger after all callbacks have had a
|
||||
# chance to call result() or exception().
|
||||
if self._log_traceback and self._tb_logger is not None:
|
||||
self._tb_logger.activate()
|
||||
self._exc_info = exc_info
|
||||
|
||||
def _check_done(self):
|
||||
if not self._done:
|
||||
|
@ -177,10 +291,25 @@ class Future(object):
|
|||
try:
|
||||
cb(self)
|
||||
except Exception:
|
||||
app_log.exception('exception calling callback %r for %r',
|
||||
app_log.exception('Exception in callback %r for %r',
|
||||
cb, self)
|
||||
self._callbacks = None
|
||||
|
||||
# On Python 3.3 or older, objects with a destructor part of a reference
|
||||
# cycle are never destroyed. It's no longer the case on Python 3.4 thanks to
|
||||
# the PEP 442.
|
||||
if _GC_CYCLE_FINALIZERS:
|
||||
def __del__(self):
|
||||
if not self._log_traceback:
|
||||
# set_exception() was not called, or result() or exception()
|
||||
# has consumed the exception
|
||||
return
|
||||
|
||||
tb = traceback.format_exception(*self._exc_info)
|
||||
|
||||
app_log.error('Future %r exception was never retrieved: %s',
|
||||
self, ''.join(tb).rstrip())
|
||||
|
||||
TracebackFuture = Future
|
||||
|
||||
if futures is None:
|
||||
|
@ -208,24 +337,42 @@ class DummyExecutor(object):
|
|||
dummy_executor = DummyExecutor()
|
||||
|
||||
|
||||
def run_on_executor(fn):
|
||||
def run_on_executor(*args, **kwargs):
|
||||
"""Decorator to run a synchronous method asynchronously on an executor.
|
||||
|
||||
The decorated method may be called with a ``callback`` keyword
|
||||
argument and returns a future.
|
||||
|
||||
This decorator should be used only on methods of objects with attributes
|
||||
``executor`` and ``io_loop``.
|
||||
The `.IOLoop` and executor to be used are determined by the ``io_loop``
|
||||
and ``executor`` attributes of ``self``. To use different attributes,
|
||||
pass keyword arguments to the decorator::
|
||||
|
||||
@run_on_executor(executor='_thread_pool')
|
||||
def foo(self):
|
||||
pass
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
Added keyword arguments to use alternative attributes.
|
||||
"""
|
||||
def run_on_executor_decorator(fn):
|
||||
executor = kwargs.get("executor", "executor")
|
||||
io_loop = kwargs.get("io_loop", "io_loop")
|
||||
@functools.wraps(fn)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
callback = kwargs.pop("callback", None)
|
||||
future = self.executor.submit(fn, self, *args, **kwargs)
|
||||
future = getattr(self, executor).submit(fn, self, *args, **kwargs)
|
||||
if callback:
|
||||
self.io_loop.add_future(future,
|
||||
lambda future: callback(future.result()))
|
||||
getattr(self, io_loop).add_future(
|
||||
future, lambda future: callback(future.result()))
|
||||
return future
|
||||
return wrapper
|
||||
if args and kwargs:
|
||||
raise ValueError("cannot combine positional and keyword args")
|
||||
if len(args) == 1:
|
||||
return run_on_executor_decorator(args[0])
|
||||
elif len(args) != 0:
|
||||
raise ValueError("expected 1 argument, got %d", len(args))
|
||||
return run_on_executor_decorator
|
||||
|
||||
|
||||
_NO_RESULT = object()
|
||||
|
@ -250,7 +397,9 @@ def return_future(f):
|
|||
wait for the function to complete (perhaps by yielding it in a
|
||||
`.gen.engine` function, or passing it to `.IOLoop.add_future`).
|
||||
|
||||
Usage::
|
||||
Usage:
|
||||
|
||||
.. testcode::
|
||||
|
||||
@return_future
|
||||
def future_func(arg1, arg2, callback):
|
||||
|
@ -262,6 +411,8 @@ def return_future(f):
|
|||
yield future_func(arg1, arg2)
|
||||
callback()
|
||||
|
||||
..
|
||||
|
||||
Note that ``@return_future`` and ``@gen.engine`` can be applied to the
|
||||
same function, provided ``@return_future`` appears first. However,
|
||||
consider using ``@gen.coroutine`` instead of this combination.
|
||||
|
@ -293,7 +444,7 @@ def return_future(f):
|
|||
# If the initial synchronous part of f() raised an exception,
|
||||
# go ahead and raise it to the caller directly without waiting
|
||||
# for them to inspect the Future.
|
||||
raise_exc_info(exc_info)
|
||||
future.result()
|
||||
|
||||
# If the caller passed in a callback, schedule it to be called
|
||||
# when the future resolves. It is important that this happens
|
||||
|
|
|
@ -28,12 +28,13 @@ from io import BytesIO
|
|||
|
||||
from tornado import httputil
|
||||
from tornado import ioloop
|
||||
from tornado.log import gen_log
|
||||
from tornado import stack_context
|
||||
|
||||
from tornado.escape import utf8, native_str
|
||||
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main
|
||||
|
||||
curl_log = logging.getLogger('tornado.curl_httpclient')
|
||||
|
||||
|
||||
class CurlAsyncHTTPClient(AsyncHTTPClient):
|
||||
def initialize(self, io_loop, max_clients=10, defaults=None):
|
||||
|
@ -207,8 +208,24 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
|
|||
"callback": callback,
|
||||
"curl_start_time": time.time(),
|
||||
}
|
||||
self._curl_setup_request(curl, request, curl.info["buffer"],
|
||||
try:
|
||||
self._curl_setup_request(
|
||||
curl, request, curl.info["buffer"],
|
||||
curl.info["headers"])
|
||||
except Exception as e:
|
||||
# If there was an error in setup, pass it on
|
||||
# to the callback. Note that allowing the
|
||||
# error to escape here will appear to work
|
||||
# most of the time since we are still in the
|
||||
# caller's original stack frame, but when
|
||||
# _process_queue() is called from
|
||||
# _finish_pending_requests the exceptions have
|
||||
# nowhere to go.
|
||||
callback(HTTPResponse(
|
||||
request=request,
|
||||
code=599,
|
||||
error=e))
|
||||
else:
|
||||
self._multi.add_handle(curl)
|
||||
|
||||
if not started:
|
||||
|
@ -257,7 +274,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
|
|||
|
||||
def _curl_create(self):
|
||||
curl = pycurl.Curl()
|
||||
if gen_log.isEnabledFor(logging.DEBUG):
|
||||
if curl_log.isEnabledFor(logging.DEBUG):
|
||||
curl.setopt(pycurl.VERBOSE, 1)
|
||||
curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug)
|
||||
return curl
|
||||
|
@ -288,8 +305,8 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
|
|||
functools.partial(self._curl_header_callback,
|
||||
headers, request.header_callback))
|
||||
if request.streaming_callback:
|
||||
write_function = lambda chunk: self.io_loop.add_callback(
|
||||
request.streaming_callback, chunk)
|
||||
def write_function(chunk):
|
||||
self.io_loop.add_callback(request.streaming_callback, chunk)
|
||||
else:
|
||||
write_function = buffer.write
|
||||
if bytes is str: # py2
|
||||
|
@ -381,6 +398,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
|
|||
% request.method)
|
||||
|
||||
request_buffer = BytesIO(utf8(request.body))
|
||||
|
||||
def ioctl(cmd):
|
||||
if cmd == curl.IOCMD_RESTARTREAD:
|
||||
request_buffer.seek(0)
|
||||
|
@ -403,11 +421,11 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
|
|||
raise ValueError("Unsupported auth_mode %s" % request.auth_mode)
|
||||
|
||||
curl.setopt(pycurl.USERPWD, native_str(userpwd))
|
||||
gen_log.debug("%s %s (username: %r)", request.method, request.url,
|
||||
curl_log.debug("%s %s (username: %r)", request.method, request.url,
|
||||
request.auth_username)
|
||||
else:
|
||||
curl.unsetopt(pycurl.USERPWD)
|
||||
gen_log.debug("%s %s", request.method, request.url)
|
||||
curl_log.debug("%s %s", request.method, request.url)
|
||||
|
||||
if request.client_cert is not None:
|
||||
curl.setopt(pycurl.SSLCERT, request.client_cert)
|
||||
|
@ -415,6 +433,9 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
|
|||
if request.client_key is not None:
|
||||
curl.setopt(pycurl.SSLKEY, request.client_key)
|
||||
|
||||
if request.ssl_options is not None:
|
||||
raise ValueError("ssl_options not supported in curl_httpclient")
|
||||
|
||||
if threading.activeCount() > 1:
|
||||
# libcurl/pycurl is not thread-safe by default. When multiple threads
|
||||
# are used, signals should be disabled. This has the side effect
|
||||
|
@ -448,12 +469,12 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
|
|||
def _curl_debug(self, debug_type, debug_msg):
|
||||
debug_types = ('I', '<', '>', '<', '>')
|
||||
if debug_type == 0:
|
||||
gen_log.debug('%s', debug_msg.strip())
|
||||
curl_log.debug('%s', debug_msg.strip())
|
||||
elif debug_type in (1, 2):
|
||||
for line in debug_msg.splitlines():
|
||||
gen_log.debug('%s %s', debug_types[debug_type], line)
|
||||
curl_log.debug('%s %s', debug_types[debug_type], line)
|
||||
elif debug_type == 4:
|
||||
gen_log.debug('%s %r', debug_types[debug_type], debug_msg)
|
||||
curl_log.debug('%s %r', debug_types[debug_type], debug_msg)
|
||||
|
||||
|
||||
class CurlError(HTTPError):
|
||||
|
|
|
@ -82,7 +82,7 @@ def json_encode(value):
|
|||
# JSON permits but does not require forward slashes to be escaped.
|
||||
# This is useful when json data is emitted in a <script> tag
|
||||
# in HTML, as it prevents </script> tags from prematurely terminating
|
||||
# the javscript. Some json libraries do this escaping by default,
|
||||
# the javascript. Some json libraries do this escaping by default,
|
||||
# although python's standard library does not, so we do it here.
|
||||
# http://stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped
|
||||
return json.dumps(value).replace("</", "<\\/")
|
||||
|
|
349
tornado/gen.py
349
tornado/gen.py
|
@ -3,7 +3,9 @@ work in an asynchronous environment. Code using the ``gen`` module
|
|||
is technically asynchronous, but it is written as a single generator
|
||||
instead of a collection of separate functions.
|
||||
|
||||
For example, the following asynchronous handler::
|
||||
For example, the following asynchronous handler:
|
||||
|
||||
.. testcode::
|
||||
|
||||
class AsyncHandler(RequestHandler):
|
||||
@asynchronous
|
||||
|
@ -16,7 +18,12 @@ For example, the following asynchronous handler::
|
|||
do_something_with_response(response)
|
||||
self.render("template.html")
|
||||
|
||||
could be written with ``gen`` as::
|
||||
.. testoutput::
|
||||
:hide:
|
||||
|
||||
could be written with ``gen`` as:
|
||||
|
||||
.. testcode::
|
||||
|
||||
class GenAsyncHandler(RequestHandler):
|
||||
@gen.coroutine
|
||||
|
@ -26,12 +33,17 @@ could be written with ``gen`` as::
|
|||
do_something_with_response(response)
|
||||
self.render("template.html")
|
||||
|
||||
.. testoutput::
|
||||
:hide:
|
||||
|
||||
Most asynchronous functions in Tornado return a `.Future`;
|
||||
yielding this object returns its `~.Future.result`.
|
||||
|
||||
You can also yield a list or dict of ``Futures``, which will be
|
||||
started at the same time and run in parallel; a list or dict of results will
|
||||
be returned when they are all finished::
|
||||
be returned when they are all finished:
|
||||
|
||||
.. testcode::
|
||||
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
|
@ -43,8 +55,24 @@ be returned when they are all finished::
|
|||
response3 = response_dict['response3']
|
||||
response4 = response_dict['response4']
|
||||
|
||||
.. testoutput::
|
||||
:hide:
|
||||
|
||||
If the `~functools.singledispatch` library is available (standard in
|
||||
Python 3.4, available via the `singledispatch
|
||||
<https://pypi.python.org/pypi/singledispatch>`_ package on older
|
||||
versions), additional types of objects may be yielded. Tornado includes
|
||||
support for ``asyncio.Future`` and Twisted's ``Deferred`` class when
|
||||
``tornado.platform.asyncio`` and ``tornado.platform.twisted`` are imported.
|
||||
See the `convert_yielded` function to extend this mechanism.
|
||||
|
||||
.. versionchanged:: 3.2
|
||||
Dict support added.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
Support added for yielding ``asyncio`` Futures and Twisted Deferreds
|
||||
via ``singledispatch``.
|
||||
|
||||
"""
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
|
||||
|
@ -53,10 +81,21 @@ import functools
|
|||
import itertools
|
||||
import sys
|
||||
import types
|
||||
import weakref
|
||||
|
||||
from tornado.concurrent import Future, TracebackFuture, is_future, chain_future
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.log import app_log
|
||||
from tornado import stack_context
|
||||
from tornado.util import raise_exc_info
|
||||
|
||||
try:
|
||||
from functools import singledispatch # py34+
|
||||
except ImportError as e:
|
||||
try:
|
||||
from singledispatch import singledispatch # backport
|
||||
except ImportError:
|
||||
singledispatch = None
|
||||
|
||||
|
||||
class KeyReuseError(Exception):
|
||||
|
@ -101,9 +140,11 @@ def engine(func):
|
|||
which use ``self.finish()`` in place of a callback argument.
|
||||
"""
|
||||
func = _make_coroutine_wrapper(func, replace_callback=False)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
future = func(*args, **kwargs)
|
||||
|
||||
def final_callback(future):
|
||||
if future.result() is not None:
|
||||
raise ReturnValueIgnoredError(
|
||||
|
@ -241,6 +282,113 @@ class Return(Exception):
|
|||
self.value = value
|
||||
|
||||
|
||||
class WaitIterator(object):
|
||||
"""Provides an iterator to yield the results of futures as they finish.
|
||||
|
||||
Yielding a set of futures like this:
|
||||
|
||||
``results = yield [future1, future2]``
|
||||
|
||||
pauses the coroutine until both ``future1`` and ``future2``
|
||||
return, and then restarts the coroutine with the results of both
|
||||
futures. If either future is an exception, the expression will
|
||||
raise that exception and all the results will be lost.
|
||||
|
||||
If you need to get the result of each future as soon as possible,
|
||||
or if you need the result of some futures even if others produce
|
||||
errors, you can use ``WaitIterator``::
|
||||
|
||||
wait_iterator = gen.WaitIterator(future1, future2)
|
||||
while not wait_iterator.done():
|
||||
try:
|
||||
result = yield wait_iterator.next()
|
||||
except Exception as e:
|
||||
print("Error {} from {}".format(e, wait_iterator.current_future))
|
||||
else:
|
||||
print("Result {} received from {} at {}".format(
|
||||
result, wait_iterator.current_future,
|
||||
wait_iterator.current_index))
|
||||
|
||||
Because results are returned as soon as they are available the
|
||||
output from the iterator *will not be in the same order as the
|
||||
input arguments*. If you need to know which future produced the
|
||||
current result, you can use the attributes
|
||||
``WaitIterator.current_future``, or ``WaitIterator.current_index``
|
||||
to get the index of the future from the input list. (if keyword
|
||||
arguments were used in the construction of the `WaitIterator`,
|
||||
``current_index`` will use the corresponding keyword).
|
||||
|
||||
.. versionadded:: 4.1
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
if args and kwargs:
|
||||
raise ValueError(
|
||||
"You must provide args or kwargs, not both")
|
||||
|
||||
if kwargs:
|
||||
self._unfinished = dict((f, k) for (k, f) in kwargs.items())
|
||||
futures = list(kwargs.values())
|
||||
else:
|
||||
self._unfinished = dict((f, i) for (i, f) in enumerate(args))
|
||||
futures = args
|
||||
|
||||
self._finished = collections.deque()
|
||||
self.current_index = self.current_future = None
|
||||
self._running_future = None
|
||||
|
||||
# Use a weak reference to self to avoid cycles that may delay
|
||||
# garbage collection.
|
||||
self_ref = weakref.ref(self)
|
||||
for future in futures:
|
||||
future.add_done_callback(functools.partial(
|
||||
self._done_callback, self_ref))
|
||||
|
||||
def done(self):
|
||||
"""Returns True if this iterator has no more results."""
|
||||
if self._finished or self._unfinished:
|
||||
return False
|
||||
# Clear the 'current' values when iteration is done.
|
||||
self.current_index = self.current_future = None
|
||||
return True
|
||||
|
||||
def next(self):
|
||||
"""Returns a `.Future` that will yield the next available result.
|
||||
|
||||
Note that this `.Future` will not be the same object as any of
|
||||
the inputs.
|
||||
"""
|
||||
self._running_future = TracebackFuture()
|
||||
# As long as there is an active _running_future, we must
|
||||
# ensure that the WaitIterator is not GC'd (due to the
|
||||
# use of weak references in __init__). Add a callback that
|
||||
# references self so there is a hard reference that will be
|
||||
# cleared automatically when this Future finishes.
|
||||
self._running_future.add_done_callback(lambda f: self)
|
||||
|
||||
if self._finished:
|
||||
self._return_result(self._finished.popleft())
|
||||
|
||||
return self._running_future
|
||||
|
||||
@staticmethod
|
||||
def _done_callback(self_ref, done):
|
||||
self = self_ref()
|
||||
if self is not None:
|
||||
if self._running_future and not self._running_future.done():
|
||||
self._return_result(done)
|
||||
else:
|
||||
self._finished.append(done)
|
||||
|
||||
def _return_result(self, done):
|
||||
"""Called set the returned future's state that of the future
|
||||
we yielded, and set the current future for the iterator.
|
||||
"""
|
||||
chain_future(done, self._running_future)
|
||||
|
||||
self.current_future = done
|
||||
self.current_index = self._unfinished.pop(done)
|
||||
|
||||
|
||||
class YieldPoint(object):
|
||||
"""Base class for objects that may be yielded from the generator.
|
||||
|
||||
|
@ -355,11 +503,13 @@ def Task(func, *args, **kwargs):
|
|||
yielded.
|
||||
"""
|
||||
future = Future()
|
||||
|
||||
def handle_exception(typ, value, tb):
|
||||
if future.done():
|
||||
return False
|
||||
future.set_exc_info((typ, value, tb))
|
||||
return True
|
||||
|
||||
def set_result(result):
|
||||
if future.done():
|
||||
return
|
||||
|
@ -371,6 +521,11 @@ def Task(func, *args, **kwargs):
|
|||
|
||||
class YieldFuture(YieldPoint):
|
||||
def __init__(self, future, io_loop=None):
|
||||
"""Adapts a `.Future` to the `YieldPoint` interface.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
The ``io_loop`` argument is deprecated.
|
||||
"""
|
||||
self.future = future
|
||||
self.io_loop = io_loop or IOLoop.current()
|
||||
|
||||
|
@ -382,7 +537,7 @@ class YieldFuture(YieldPoint):
|
|||
self.io_loop.add_future(self.future, runner.result_callback(self.key))
|
||||
else:
|
||||
self.runner = None
|
||||
self.result = self.future.result()
|
||||
self.result_fn = self.future.result
|
||||
|
||||
def is_ready(self):
|
||||
if self.runner is not None:
|
||||
|
@ -394,7 +549,7 @@ class YieldFuture(YieldPoint):
|
|||
if self.runner is not None:
|
||||
return self.runner.pop_result(self.key).result()
|
||||
else:
|
||||
return self.result
|
||||
return self.result_fn()
|
||||
|
||||
|
||||
class Multi(YieldPoint):
|
||||
|
@ -408,8 +563,18 @@ class Multi(YieldPoint):
|
|||
Instead of a list, the argument may also be a dictionary whose values are
|
||||
Futures, in which case a parallel dictionary is returned mapping the same
|
||||
keys to their results.
|
||||
|
||||
It is not normally necessary to call this class directly, as it
|
||||
will be created automatically as needed. However, calling it directly
|
||||
allows you to use the ``quiet_exceptions`` argument to control
|
||||
the logging of multiple exceptions.
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
If multiple ``YieldPoints`` fail, any exceptions after the first
|
||||
(which is raised) will be logged. Added the ``quiet_exceptions``
|
||||
argument to suppress this logging for selected exception types.
|
||||
"""
|
||||
def __init__(self, children):
|
||||
def __init__(self, children, quiet_exceptions=()):
|
||||
self.keys = None
|
||||
if isinstance(children, dict):
|
||||
self.keys = list(children.keys())
|
||||
|
@ -421,6 +586,7 @@ class Multi(YieldPoint):
|
|||
self.children.append(i)
|
||||
assert all(isinstance(i, YieldPoint) for i in self.children)
|
||||
self.unfinished_children = set(self.children)
|
||||
self.quiet_exceptions = quiet_exceptions
|
||||
|
||||
def start(self, runner):
|
||||
for i in self.children:
|
||||
|
@ -433,14 +599,27 @@ class Multi(YieldPoint):
|
|||
return not self.unfinished_children
|
||||
|
||||
def get_result(self):
|
||||
result = (i.get_result() for i in self.children)
|
||||
if self.keys is not None:
|
||||
return dict(zip(self.keys, result))
|
||||
result_list = []
|
||||
exc_info = None
|
||||
for f in self.children:
|
||||
try:
|
||||
result_list.append(f.get_result())
|
||||
except Exception as e:
|
||||
if exc_info is None:
|
||||
exc_info = sys.exc_info()
|
||||
else:
|
||||
return list(result)
|
||||
if not isinstance(e, self.quiet_exceptions):
|
||||
app_log.error("Multiple exceptions in yield list",
|
||||
exc_info=True)
|
||||
if exc_info is not None:
|
||||
raise_exc_info(exc_info)
|
||||
if self.keys is not None:
|
||||
return dict(zip(self.keys, result_list))
|
||||
else:
|
||||
return list(result_list)
|
||||
|
||||
|
||||
def multi_future(children):
|
||||
def multi_future(children, quiet_exceptions=()):
|
||||
"""Wait for multiple asynchronous futures in parallel.
|
||||
|
||||
Takes a list of ``Futures`` (but *not* other ``YieldPoints``) and returns
|
||||
|
@ -453,12 +632,21 @@ def multi_future(children):
|
|||
Futures, in which case a parallel dictionary is returned mapping the same
|
||||
keys to their results.
|
||||
|
||||
It is not necessary to call `multi_future` explcitly, since the engine will
|
||||
do so automatically when the generator yields a list of `Futures`.
|
||||
This function is faster than the `Multi` `YieldPoint` because it does not
|
||||
require the creation of a stack context.
|
||||
It is not normally necessary to call `multi_future` explcitly,
|
||||
since the engine will do so automatically when the generator
|
||||
yields a list of ``Futures``. However, calling it directly
|
||||
allows you to use the ``quiet_exceptions`` argument to control
|
||||
the logging of multiple exceptions.
|
||||
|
||||
This function is faster than the `Multi` `YieldPoint` because it
|
||||
does not require the creation of a stack context.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
If multiple ``Futures`` fail, any exceptions after the first (which is
|
||||
raised) will be logged. Added the ``quiet_exceptions``
|
||||
argument to suppress this logging for selected exception types.
|
||||
"""
|
||||
if isinstance(children, dict):
|
||||
keys = list(children.keys())
|
||||
|
@ -471,19 +659,31 @@ def multi_future(children):
|
|||
future = Future()
|
||||
if not children:
|
||||
future.set_result({} if keys is not None else [])
|
||||
|
||||
def callback(f):
|
||||
unfinished_children.remove(f)
|
||||
if not unfinished_children:
|
||||
result_list = []
|
||||
for f in children:
|
||||
try:
|
||||
result_list = [i.result() for i in children]
|
||||
except Exception:
|
||||
future.set_exc_info(sys.exc_info())
|
||||
result_list.append(f.result())
|
||||
except Exception as e:
|
||||
if future.done():
|
||||
if not isinstance(e, quiet_exceptions):
|
||||
app_log.error("Multiple exceptions in yield list",
|
||||
exc_info=True)
|
||||
else:
|
||||
future.set_exc_info(sys.exc_info())
|
||||
if not future.done():
|
||||
if keys is not None:
|
||||
future.set_result(dict(zip(keys, result_list)))
|
||||
else:
|
||||
future.set_result(result_list)
|
||||
|
||||
listening = set()
|
||||
for f in children:
|
||||
if f not in listening:
|
||||
listening.add(f)
|
||||
f.add_done_callback(callback)
|
||||
return future
|
||||
|
||||
|
@ -504,7 +704,7 @@ def maybe_future(x):
|
|||
return fut
|
||||
|
||||
|
||||
def with_timeout(timeout, future, io_loop=None):
|
||||
def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()):
|
||||
"""Wraps a `.Future` in a timeout.
|
||||
|
||||
Raises `TimeoutError` if the input future does not complete before
|
||||
|
@ -512,9 +712,17 @@ def with_timeout(timeout, future, io_loop=None):
|
|||
`.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time
|
||||
relative to `.IOLoop.time`)
|
||||
|
||||
If the wrapped `.Future` fails after it has timed out, the exception
|
||||
will be logged unless it is of a type contained in ``quiet_exceptions``
|
||||
(which may be an exception type or a sequence of types).
|
||||
|
||||
Currently only supports Futures, not other `YieldPoint` classes.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
Added the ``quiet_exceptions`` argument and the logging of unhandled
|
||||
exceptions.
|
||||
"""
|
||||
# TODO: allow yield points in addition to futures?
|
||||
# Tricky to do with stack_context semantics.
|
||||
|
@ -528,9 +736,21 @@ def with_timeout(timeout, future, io_loop=None):
|
|||
chain_future(future, result)
|
||||
if io_loop is None:
|
||||
io_loop = IOLoop.current()
|
||||
|
||||
def error_callback(future):
|
||||
try:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
if not isinstance(e, quiet_exceptions):
|
||||
app_log.error("Exception in Future %r after timeout",
|
||||
future, exc_info=True)
|
||||
|
||||
def timeout_callback():
|
||||
result.set_exception(TimeoutError("Timeout"))
|
||||
# In case the wrapped future goes on to fail, log it.
|
||||
future.add_done_callback(error_callback)
|
||||
timeout_handle = io_loop.add_timeout(
|
||||
timeout,
|
||||
lambda: result.set_exception(TimeoutError("Timeout")))
|
||||
timeout, timeout_callback)
|
||||
if isinstance(future, Future):
|
||||
# We know this future will resolve on the IOLoop, so we don't
|
||||
# need the extra thread-safety of IOLoop.add_future (and we also
|
||||
|
@ -545,6 +765,25 @@ def with_timeout(timeout, future, io_loop=None):
|
|||
return result
|
||||
|
||||
|
||||
def sleep(duration):
|
||||
"""Return a `.Future` that resolves after the given number of seconds.
|
||||
|
||||
When used with ``yield`` in a coroutine, this is a non-blocking
|
||||
analogue to `time.sleep` (which should not be used in coroutines
|
||||
because it is blocking)::
|
||||
|
||||
yield gen.sleep(0.5)
|
||||
|
||||
Note that calling this function on its own does nothing; you must
|
||||
wait on the `.Future` it returns (usually by yielding it).
|
||||
|
||||
.. versionadded:: 4.1
|
||||
"""
|
||||
f = Future()
|
||||
IOLoop.current().call_later(duration, lambda: f.set_result(None))
|
||||
return f
|
||||
|
||||
|
||||
_null_future = Future()
|
||||
_null_future.set_result(None)
|
||||
|
||||
|
@ -638,13 +877,20 @@ class Runner(object):
|
|||
self.future = None
|
||||
try:
|
||||
orig_stack_contexts = stack_context._state.contexts
|
||||
exc_info = None
|
||||
|
||||
try:
|
||||
value = future.result()
|
||||
except Exception:
|
||||
self.had_exception = True
|
||||
yielded = self.gen.throw(*sys.exc_info())
|
||||
exc_info = sys.exc_info()
|
||||
|
||||
if exc_info is not None:
|
||||
yielded = self.gen.throw(*exc_info)
|
||||
exc_info = None
|
||||
else:
|
||||
yielded = self.gen.send(value)
|
||||
|
||||
if stack_context._state.contexts is not orig_stack_contexts:
|
||||
self.gen.throw(
|
||||
stack_context.StackContextInconsistentError(
|
||||
|
@ -678,19 +924,20 @@ class Runner(object):
|
|||
self.running = False
|
||||
|
||||
def handle_yield(self, yielded):
|
||||
if isinstance(yielded, list):
|
||||
if all(is_future(f) for f in yielded):
|
||||
yielded = multi_future(yielded)
|
||||
else:
|
||||
# Lists containing YieldPoints require stack contexts;
|
||||
# other lists are handled via multi_future in convert_yielded.
|
||||
if (isinstance(yielded, list) and
|
||||
any(isinstance(f, YieldPoint) for f in yielded)):
|
||||
yielded = Multi(yielded)
|
||||
elif isinstance(yielded, dict):
|
||||
if all(is_future(f) for f in yielded.values()):
|
||||
yielded = multi_future(yielded)
|
||||
else:
|
||||
elif (isinstance(yielded, dict) and
|
||||
any(isinstance(f, YieldPoint) for f in yielded.values())):
|
||||
yielded = Multi(yielded)
|
||||
|
||||
if isinstance(yielded, YieldPoint):
|
||||
# YieldPoints are too closely coupled to the Runner to go
|
||||
# through the generic convert_yielded mechanism.
|
||||
self.future = TracebackFuture()
|
||||
|
||||
def start_yield_point():
|
||||
try:
|
||||
yielded.start(self)
|
||||
|
@ -702,12 +949,14 @@ class Runner(object):
|
|||
except Exception:
|
||||
self.future = TracebackFuture()
|
||||
self.future.set_exc_info(sys.exc_info())
|
||||
|
||||
if self.stack_context_deactivate is None:
|
||||
# Start a stack context if this is the first
|
||||
# YieldPoint we've seen.
|
||||
with stack_context.ExceptionStackContext(
|
||||
self.handle_exception) as deactivate:
|
||||
self.stack_context_deactivate = deactivate
|
||||
|
||||
def cb():
|
||||
start_yield_point()
|
||||
self.run()
|
||||
|
@ -715,16 +964,17 @@ class Runner(object):
|
|||
return False
|
||||
else:
|
||||
start_yield_point()
|
||||
elif is_future(yielded):
|
||||
self.future = yielded
|
||||
else:
|
||||
try:
|
||||
self.future = convert_yielded(yielded)
|
||||
except BadYieldError:
|
||||
self.future = TracebackFuture()
|
||||
self.future.set_exc_info(sys.exc_info())
|
||||
|
||||
if not self.future.done() or self.future is moment:
|
||||
self.io_loop.add_future(
|
||||
self.future, lambda f: self.run())
|
||||
return False
|
||||
else:
|
||||
self.future = TracebackFuture()
|
||||
self.future.set_exception(BadYieldError(
|
||||
"yielded unknown object %r" % (yielded,)))
|
||||
return True
|
||||
|
||||
def result_callback(self, key):
|
||||
|
@ -763,3 +1013,30 @@ def _argument_adapter(callback):
|
|||
else:
|
||||
callback(None)
|
||||
return wrapper
|
||||
|
||||
|
||||
def convert_yielded(yielded):
|
||||
"""Convert a yielded object into a `.Future`.
|
||||
|
||||
The default implementation accepts lists, dictionaries, and Futures.
|
||||
|
||||
If the `~functools.singledispatch` library is available, this function
|
||||
may be extended to support additional types. For example::
|
||||
|
||||
@convert_yielded.register(asyncio.Future)
|
||||
def _(asyncio_future):
|
||||
return tornado.platform.asyncio.to_tornado_future(asyncio_future)
|
||||
|
||||
.. versionadded:: 4.1
|
||||
"""
|
||||
# Lists and dicts containing YieldPoints were handled separately
|
||||
# via Multi().
|
||||
if isinstance(yielded, (list, dict)):
|
||||
return multi_future(yielded)
|
||||
elif is_future(yielded):
|
||||
return yielded
|
||||
else:
|
||||
raise BadYieldError("yielded unknown object %r" % (yielded,))
|
||||
|
||||
if singledispatch is not None:
|
||||
convert_yielded = singledispatch(convert_yielded)
|
||||
|
|
|
@ -37,6 +37,7 @@ class _QuietException(Exception):
|
|||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class _ExceptionLoggingContext(object):
|
||||
"""Used with the ``with`` statement when calling delegate methods to
|
||||
log any exceptions with the given logger. Any exceptions caught are
|
||||
|
@ -53,6 +54,7 @@ class _ExceptionLoggingContext(object):
|
|||
self.logger.error("Uncaught exception", exc_info=(typ, value, tb))
|
||||
raise _QuietException
|
||||
|
||||
|
||||
class HTTP1ConnectionParameters(object):
|
||||
"""Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`.
|
||||
"""
|
||||
|
@ -162,7 +164,8 @@ class HTTP1Connection(httputil.HTTPConnection):
|
|||
header_data = yield gen.with_timeout(
|
||||
self.stream.io_loop.time() + self.params.header_timeout,
|
||||
header_future,
|
||||
io_loop=self.stream.io_loop)
|
||||
io_loop=self.stream.io_loop,
|
||||
quiet_exceptions=iostream.StreamClosedError)
|
||||
except gen.TimeoutError:
|
||||
self.close()
|
||||
raise gen.Return(False)
|
||||
|
@ -221,7 +224,8 @@ class HTTP1Connection(httputil.HTTPConnection):
|
|||
try:
|
||||
yield gen.with_timeout(
|
||||
self.stream.io_loop.time() + self._body_timeout,
|
||||
body_future, self.stream.io_loop)
|
||||
body_future, self.stream.io_loop,
|
||||
quiet_exceptions=iostream.StreamClosedError)
|
||||
except gen.TimeoutError:
|
||||
gen_log.info("Timeout reading body from %s",
|
||||
self.context)
|
||||
|
@ -326,8 +330,10 @@ class HTTP1Connection(httputil.HTTPConnection):
|
|||
|
||||
def write_headers(self, start_line, headers, chunk=None, callback=None):
|
||||
"""Implements `.HTTPConnection.write_headers`."""
|
||||
lines = []
|
||||
if self.is_client:
|
||||
self._request_start_line = start_line
|
||||
lines.append(utf8('%s %s HTTP/1.1' % (start_line[0], start_line[1])))
|
||||
# Client requests with a non-empty body must have either a
|
||||
# Content-Length or a Transfer-Encoding.
|
||||
self._chunking_output = (
|
||||
|
@ -336,6 +342,7 @@ class HTTP1Connection(httputil.HTTPConnection):
|
|||
'Transfer-Encoding' not in headers)
|
||||
else:
|
||||
self._response_start_line = start_line
|
||||
lines.append(utf8('HTTP/1.1 %s %s' % (start_line[1], start_line[2])))
|
||||
self._chunking_output = (
|
||||
# TODO: should this use
|
||||
# self._request_start_line.version or
|
||||
|
@ -365,7 +372,6 @@ class HTTP1Connection(httputil.HTTPConnection):
|
|||
self._expected_content_remaining = int(headers['Content-Length'])
|
||||
else:
|
||||
self._expected_content_remaining = None
|
||||
lines = [utf8("%s %s %s" % start_line)]
|
||||
lines.extend([utf8(n) + b": " + utf8(v) for n, v in headers.get_all()])
|
||||
for line in lines:
|
||||
if b'\n' in line:
|
||||
|
@ -374,6 +380,7 @@ class HTTP1Connection(httputil.HTTPConnection):
|
|||
if self.stream.closed():
|
||||
future = self._write_future = Future()
|
||||
future.set_exception(iostream.StreamClosedError())
|
||||
future.exception()
|
||||
else:
|
||||
if callback is not None:
|
||||
self._write_callback = stack_context.wrap(callback)
|
||||
|
@ -412,6 +419,7 @@ class HTTP1Connection(httputil.HTTPConnection):
|
|||
if self.stream.closed():
|
||||
future = self._write_future = Future()
|
||||
self._write_future.set_exception(iostream.StreamClosedError())
|
||||
self._write_future.exception()
|
||||
else:
|
||||
if callback is not None:
|
||||
self._write_callback = stack_context.wrap(callback)
|
||||
|
@ -451,6 +459,9 @@ class HTTP1Connection(httputil.HTTPConnection):
|
|||
self._pending_write.add_done_callback(self._finish_request)
|
||||
|
||||
def _on_write_complete(self, future):
|
||||
exc = future.exception()
|
||||
if exc is not None and not isinstance(exc, iostream.StreamClosedError):
|
||||
future.result()
|
||||
if self._write_callback is not None:
|
||||
callback = self._write_callback
|
||||
self._write_callback = None
|
||||
|
@ -491,8 +502,9 @@ class HTTP1Connection(httputil.HTTPConnection):
|
|||
# we SHOULD ignore at least one empty line before the request.
|
||||
# http://tools.ietf.org/html/rfc7230#section-3.5
|
||||
data = native_str(data.decode('latin1')).lstrip("\r\n")
|
||||
eol = data.find("\r\n")
|
||||
start_line = data[:eol]
|
||||
# RFC 7230 section allows for both CRLF and bare LF.
|
||||
eol = data.find("\n")
|
||||
start_line = data[:eol].rstrip("\r")
|
||||
try:
|
||||
headers = httputil.HTTPHeaders.parse(data[eol:])
|
||||
except ValueError:
|
||||
|
@ -686,8 +698,7 @@ class HTTP1ServerConnection(object):
|
|||
# This exception was already logged.
|
||||
conn.close()
|
||||
return
|
||||
except Exception as e:
|
||||
if 1 != e.errno:
|
||||
except Exception:
|
||||
gen_log.error("Uncaught exception", exc_info=True)
|
||||
conn.close()
|
||||
return
|
||||
|
|
|
@ -72,7 +72,7 @@ class HTTPClient(object):
|
|||
http_client.close()
|
||||
"""
|
||||
def __init__(self, async_client_class=None, **kwargs):
|
||||
self._io_loop = IOLoop()
|
||||
self._io_loop = IOLoop(make_current=False)
|
||||
if async_client_class is None:
|
||||
async_client_class = AsyncHTTPClient
|
||||
self._async_client = async_client_class(self._io_loop, **kwargs)
|
||||
|
@ -95,11 +95,11 @@ class HTTPClient(object):
|
|||
If it is a string, we construct an `HTTPRequest` using any additional
|
||||
kwargs: ``HTTPRequest(request, **kwargs)``
|
||||
|
||||
If an error occurs during the fetch, we raise an `HTTPError`.
|
||||
If an error occurs during the fetch, we raise an `HTTPError` unless
|
||||
the ``raise_error`` keyword argument is set to False.
|
||||
"""
|
||||
response = self._io_loop.run_sync(functools.partial(
|
||||
self._async_client.fetch, request, **kwargs))
|
||||
response.rethrow()
|
||||
return response
|
||||
|
||||
|
||||
|
@ -136,6 +136,9 @@ class AsyncHTTPClient(Configurable):
|
|||
# or with force_instance:
|
||||
client = AsyncHTTPClient(force_instance=True,
|
||||
defaults=dict(user_agent="MyUserAgent"))
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
The ``io_loop`` argument is deprecated.
|
||||
"""
|
||||
@classmethod
|
||||
def configurable_base(cls):
|
||||
|
@ -200,7 +203,7 @@ class AsyncHTTPClient(Configurable):
|
|||
raise RuntimeError("inconsistent AsyncHTTPClient cache")
|
||||
del self._instance_cache[self.io_loop]
|
||||
|
||||
def fetch(self, request, callback=None, **kwargs):
|
||||
def fetch(self, request, callback=None, raise_error=True, **kwargs):
|
||||
"""Executes a request, asynchronously returning an `HTTPResponse`.
|
||||
|
||||
The request may be either a string URL or an `HTTPRequest` object.
|
||||
|
@ -208,8 +211,10 @@ class AsyncHTTPClient(Configurable):
|
|||
kwargs: ``HTTPRequest(request, **kwargs)``
|
||||
|
||||
This method returns a `.Future` whose result is an
|
||||
`HTTPResponse`. The ``Future`` will raise an `HTTPError` if
|
||||
the request returned a non-200 response code.
|
||||
`HTTPResponse`. By default, the ``Future`` will raise an `HTTPError`
|
||||
if the request returned a non-200 response code. Instead, if
|
||||
``raise_error`` is set to False, the response will always be
|
||||
returned regardless of the response code.
|
||||
|
||||
If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
|
||||
In the callback interface, `HTTPError` is not automatically raised.
|
||||
|
@ -243,7 +248,7 @@ class AsyncHTTPClient(Configurable):
|
|||
future.add_done_callback(handle_future)
|
||||
|
||||
def handle_response(response):
|
||||
if response.error:
|
||||
if raise_error and response.error:
|
||||
future.set_exception(response.error)
|
||||
else:
|
||||
future.set_result(response)
|
||||
|
@ -304,7 +309,8 @@ class HTTPRequest(object):
|
|||
validate_cert=None, ca_certs=None,
|
||||
allow_ipv6=None,
|
||||
client_key=None, client_cert=None, body_producer=None,
|
||||
expect_100_continue=False, decompress_response=None):
|
||||
expect_100_continue=False, decompress_response=None,
|
||||
ssl_options=None):
|
||||
r"""All parameters except ``url`` are optional.
|
||||
|
||||
:arg string url: URL to fetch
|
||||
|
@ -374,12 +380,15 @@ class HTTPRequest(object):
|
|||
:arg string ca_certs: filename of CA certificates in PEM format,
|
||||
or None to use defaults. See note below when used with
|
||||
``curl_httpclient``.
|
||||
:arg bool allow_ipv6: Use IPv6 when available? Default is false in
|
||||
``simple_httpclient`` and true in ``curl_httpclient``
|
||||
:arg string client_key: Filename for client SSL key, if any. See
|
||||
note below when used with ``curl_httpclient``.
|
||||
:arg string client_cert: Filename for client SSL certificate, if any.
|
||||
See note below when used with ``curl_httpclient``.
|
||||
:arg ssl.SSLContext ssl_options: `ssl.SSLContext` object for use in
|
||||
``simple_httpclient`` (unsupported by ``curl_httpclient``).
|
||||
Overrides ``validate_cert``, ``ca_certs``, ``client_key``,
|
||||
and ``client_cert``.
|
||||
:arg bool allow_ipv6: Use IPv6 when available? Default is true.
|
||||
:arg bool expect_100_continue: If true, send the
|
||||
``Expect: 100-continue`` header and wait for a continue response
|
||||
before sending the request body. Only supported with
|
||||
|
@ -402,6 +411,9 @@ class HTTPRequest(object):
|
|||
|
||||
.. versionadded:: 4.0
|
||||
The ``body_producer`` and ``expect_100_continue`` arguments.
|
||||
|
||||
.. versionadded:: 4.2
|
||||
The ``ssl_options`` argument.
|
||||
"""
|
||||
# Note that some of these attributes go through property setters
|
||||
# defined below.
|
||||
|
@ -439,6 +451,7 @@ class HTTPRequest(object):
|
|||
self.allow_ipv6 = allow_ipv6
|
||||
self.client_key = client_key
|
||||
self.client_cert = client_cert
|
||||
self.ssl_options = ssl_options
|
||||
self.expect_100_continue = expect_100_continue
|
||||
self.start_time = time.time()
|
||||
|
||||
|
|
|
@ -37,35 +37,17 @@ from tornado import httputil
|
|||
from tornado import iostream
|
||||
from tornado import netutil
|
||||
from tornado.tcpserver import TCPServer
|
||||
from tornado.util import Configurable
|
||||
|
||||
|
||||
class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
|
||||
class HTTPServer(TCPServer, Configurable,
|
||||
httputil.HTTPServerConnectionDelegate):
|
||||
r"""A non-blocking, single-threaded HTTP server.
|
||||
|
||||
A server is defined by either a request callback that takes a
|
||||
`.HTTPServerRequest` as an argument or a `.HTTPServerConnectionDelegate`
|
||||
instance.
|
||||
|
||||
A simple example server that echoes back the URI you requested::
|
||||
|
||||
import tornado.httpserver
|
||||
import tornado.ioloop
|
||||
from tornado import httputil
|
||||
|
||||
def handle_request(request):
|
||||
message = "You requested %s\n" % request.uri
|
||||
request.connection.write_headers(
|
||||
httputil.ResponseStartLine('HTTP/1.1', 200, 'OK'),
|
||||
{"Content-Length": str(len(message))})
|
||||
request.connection.write(message)
|
||||
request.connection.finish()
|
||||
|
||||
http_server = tornado.httpserver.HTTPServer(handle_request)
|
||||
http_server.listen(8888)
|
||||
tornado.ioloop.IOLoop.instance().start()
|
||||
|
||||
Applications should use the methods of `.HTTPConnection` to write
|
||||
their response.
|
||||
A server is defined by a subclass of `.HTTPServerConnectionDelegate`,
|
||||
or, for backwards compatibility, a callback that takes an
|
||||
`.HTTPServerRequest` as an argument. The delegate is usually a
|
||||
`tornado.web.Application`.
|
||||
|
||||
`HTTPServer` supports keep-alive connections by default
|
||||
(automatically for HTTP/1.1, or for HTTP/1.0 when the client
|
||||
|
@ -80,15 +62,15 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
|
|||
if Tornado is run behind an SSL-decoding proxy that does not set one of
|
||||
the supported ``xheaders``.
|
||||
|
||||
To make this server serve SSL traffic, send the ``ssl_options`` dictionary
|
||||
argument with the arguments required for the `ssl.wrap_socket` method,
|
||||
including ``certfile`` and ``keyfile``. (In Python 3.2+ you can pass
|
||||
an `ssl.SSLContext` object instead of a dict)::
|
||||
To make this server serve SSL traffic, send the ``ssl_options`` keyword
|
||||
argument with an `ssl.SSLContext` object. For compatibility with older
|
||||
versions of Python ``ssl_options`` may also be a dictionary of keyword
|
||||
arguments for the `ssl.wrap_socket` method.::
|
||||
|
||||
HTTPServer(applicaton, ssl_options={
|
||||
"certfile": os.path.join(data_dir, "mydomain.crt"),
|
||||
"keyfile": os.path.join(data_dir, "mydomain.key"),
|
||||
})
|
||||
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"),
|
||||
os.path.join(data_dir, "mydomain.key"))
|
||||
HTTPServer(applicaton, ssl_options=ssl_ctx)
|
||||
|
||||
`HTTPServer` initialization follows one of three patterns (the
|
||||
initialization methods are defined on `tornado.tcpserver.TCPServer`):
|
||||
|
@ -97,7 +79,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
|
|||
|
||||
server = HTTPServer(app)
|
||||
server.listen(8888)
|
||||
IOLoop.instance().start()
|
||||
IOLoop.current().start()
|
||||
|
||||
In many cases, `tornado.web.Application.listen` can be used to avoid
|
||||
the need to explicitly create the `HTTPServer`.
|
||||
|
@ -108,7 +90,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
|
|||
server = HTTPServer(app)
|
||||
server.bind(8888)
|
||||
server.start(0) # Forks multiple sub-processes
|
||||
IOLoop.instance().start()
|
||||
IOLoop.current().start()
|
||||
|
||||
When using this interface, an `.IOLoop` must *not* be passed
|
||||
to the `HTTPServer` constructor. `~.TCPServer.start` will always start
|
||||
|
@ -120,7 +102,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
|
|||
tornado.process.fork_processes(0)
|
||||
server = HTTPServer(app)
|
||||
server.add_sockets(sockets)
|
||||
IOLoop.instance().start()
|
||||
IOLoop.current().start()
|
||||
|
||||
The `~.TCPServer.add_sockets` interface is more complicated,
|
||||
but it can be used with `tornado.process.fork_processes` to
|
||||
|
@ -134,8 +116,24 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
|
|||
``idle_connection_timeout``, ``body_timeout``, ``max_body_size``
|
||||
arguments. Added support for `.HTTPServerConnectionDelegate`
|
||||
instances as ``request_callback``.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
`.HTTPServerConnectionDelegate.start_request` is now called with
|
||||
two arguments ``(server_conn, request_conn)`` (in accordance with the
|
||||
documentation) instead of one ``(request_conn)``.
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
`HTTPServer` is now a subclass of `tornado.util.Configurable`.
|
||||
"""
|
||||
def __init__(self, request_callback, no_keep_alive=False, io_loop=None,
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Ignore args to __init__; real initialization belongs in
|
||||
# initialize since we're Configurable. (there's something
|
||||
# weird in initialization order between this class,
|
||||
# Configurable, and TCPServer so we can't leave __init__ out
|
||||
# completely)
|
||||
pass
|
||||
|
||||
def initialize(self, request_callback, no_keep_alive=False, io_loop=None,
|
||||
xheaders=False, ssl_options=None, protocol=None,
|
||||
decompress_request=False,
|
||||
chunk_size=None, max_header_size=None,
|
||||
|
@ -157,6 +155,14 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
|
|||
read_chunk_size=chunk_size)
|
||||
self._connections = set()
|
||||
|
||||
@classmethod
|
||||
def configurable_base(cls):
|
||||
return HTTPServer
|
||||
|
||||
@classmethod
|
||||
def configurable_default(cls):
|
||||
return HTTPServer
|
||||
|
||||
@gen.coroutine
|
||||
def close_all_connections(self):
|
||||
while self._connections:
|
||||
|
@ -173,7 +179,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
|
|||
conn.start_serving(self)
|
||||
|
||||
def start_request(self, server_conn, request_conn):
|
||||
return _ServerRequestAdapter(self, request_conn)
|
||||
return _ServerRequestAdapter(self, server_conn, request_conn)
|
||||
|
||||
def on_close(self, server_conn):
|
||||
self._connections.remove(server_conn)
|
||||
|
@ -246,13 +252,14 @@ class _ServerRequestAdapter(httputil.HTTPMessageDelegate):
|
|||
"""Adapts the `HTTPMessageDelegate` interface to the interface expected
|
||||
by our clients.
|
||||
"""
|
||||
def __init__(self, server, connection):
|
||||
def __init__(self, server, server_conn, request_conn):
|
||||
self.server = server
|
||||
self.connection = connection
|
||||
self.connection = request_conn
|
||||
self.request = None
|
||||
if isinstance(server.request_callback,
|
||||
httputil.HTTPServerConnectionDelegate):
|
||||
self.delegate = server.request_callback.start_request(connection)
|
||||
self.delegate = server.request_callback.start_request(
|
||||
server_conn, request_conn)
|
||||
self._chunks = None
|
||||
else:
|
||||
self.delegate = None
|
||||
|
|
|
@ -62,6 +62,11 @@ except ImportError:
|
|||
pass
|
||||
|
||||
|
||||
# RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line
|
||||
# terminator and ignore any preceding CR.
|
||||
_CRLF_RE = re.compile(r'\r?\n')
|
||||
|
||||
|
||||
class _NormalizedHeaderCache(dict):
|
||||
"""Dynamic cached mapping of header names to Http-Header-Case.
|
||||
|
||||
|
@ -193,7 +198,7 @@ class HTTPHeaders(dict):
|
|||
[('Content-Length', '42'), ('Content-Type', 'text/html')]
|
||||
"""
|
||||
h = cls()
|
||||
for line in headers.splitlines():
|
||||
for line in _CRLF_RE.split(headers):
|
||||
if line:
|
||||
h.parse_line(line)
|
||||
return h
|
||||
|
@ -229,6 +234,14 @@ class HTTPHeaders(dict):
|
|||
# default implementation returns dict(self), not the subclass
|
||||
return HTTPHeaders(self)
|
||||
|
||||
# Use our overridden copy method for the copy.copy module.
|
||||
__copy__ = copy
|
||||
|
||||
def __deepcopy__(self, memo_dict):
|
||||
# Our values are immutable strings, so our standard copy is
|
||||
# effectively a deep copy.
|
||||
return self.copy()
|
||||
|
||||
|
||||
class HTTPServerRequest(object):
|
||||
"""A single HTTP request.
|
||||
|
@ -331,7 +344,7 @@ class HTTPServerRequest(object):
|
|||
self.uri = uri
|
||||
self.version = version
|
||||
self.headers = headers or HTTPHeaders()
|
||||
self.body = body or ""
|
||||
self.body = body or b""
|
||||
|
||||
# set remote IP and protocol
|
||||
context = getattr(connection, 'context', None)
|
||||
|
@ -380,6 +393,8 @@ class HTTPServerRequest(object):
|
|||
to write the response.
|
||||
"""
|
||||
assert isinstance(chunk, bytes)
|
||||
assert self.version.startswith("HTTP/1."), \
|
||||
"deprecated interface ony supported in HTTP/1.x"
|
||||
self.connection.write(chunk, callback=callback)
|
||||
|
||||
def finish(self):
|
||||
|
@ -406,15 +421,14 @@ class HTTPServerRequest(object):
|
|||
def get_ssl_certificate(self, binary_form=False):
|
||||
"""Returns the client's SSL certificate, if any.
|
||||
|
||||
To use client certificates, the HTTPServer must have been constructed
|
||||
with cert_reqs set in ssl_options, e.g.::
|
||||
To use client certificates, the HTTPServer's
|
||||
`ssl.SSLContext.verify_mode` field must be set, e.g.::
|
||||
|
||||
server = HTTPServer(app,
|
||||
ssl_options=dict(
|
||||
certfile="foo.crt",
|
||||
keyfile="foo.key",
|
||||
cert_reqs=ssl.CERT_REQUIRED,
|
||||
ca_certs="cacert.crt"))
|
||||
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_ctx.load_cert_chain("foo.crt", "foo.key")
|
||||
ssl_ctx.load_verify_locations("cacerts.pem")
|
||||
ssl_ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
server = HTTPServer(app, ssl_options=ssl_ctx)
|
||||
|
||||
By default, the return value is a dictionary (or None, if no
|
||||
client certificate is present). If ``binary_form`` is true, a
|
||||
|
@ -543,6 +557,8 @@ class HTTPConnection(object):
|
|||
headers.
|
||||
:arg callback: a callback to be run when the write is complete.
|
||||
|
||||
The ``version`` field of ``start_line`` is ignored.
|
||||
|
||||
Returns a `.Future` if no callback is given.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
@ -689,6 +705,7 @@ def parse_body_arguments(content_type, body, arguments, files, headers=None):
|
|||
if values:
|
||||
arguments.setdefault(name, []).extend(values)
|
||||
elif content_type.startswith("multipart/form-data"):
|
||||
try:
|
||||
fields = content_type.split(";")
|
||||
for field in fields:
|
||||
k, sep, v = field.strip().partition("=")
|
||||
|
@ -696,7 +713,9 @@ def parse_body_arguments(content_type, body, arguments, files, headers=None):
|
|||
parse_multipart_form_data(utf8(v), body, arguments, files)
|
||||
break
|
||||
else:
|
||||
gen_log.warning("Invalid multipart/form-data")
|
||||
raise ValueError("multipart boundary not found")
|
||||
except Exception as e:
|
||||
gen_log.warning("Invalid multipart/form-data: %s", e)
|
||||
|
||||
|
||||
def parse_multipart_form_data(boundary, data, arguments, files):
|
||||
|
@ -782,7 +801,7 @@ def parse_request_start_line(line):
|
|||
method, path, version = line.split(" ")
|
||||
except ValueError:
|
||||
raise HTTPInputError("Malformed HTTP request line")
|
||||
if not version.startswith("HTTP/"):
|
||||
if not re.match(r"^HTTP/1\.[0-9]$", version):
|
||||
raise HTTPInputError(
|
||||
"Malformed HTTP version in HTTP Request-Line: %r" % version)
|
||||
return RequestStartLine(method, path, version)
|
||||
|
@ -801,7 +820,7 @@ def parse_response_start_line(line):
|
|||
ResponseStartLine(version='HTTP/1.1', code=200, reason='OK')
|
||||
"""
|
||||
line = native_str(line)
|
||||
match = re.match("(HTTP/1.[01]) ([0-9]+) ([^\r]*)", line)
|
||||
match = re.match("(HTTP/1.[0-9]) ([0-9]+) ([^\r]*)", line)
|
||||
if not match:
|
||||
raise HTTPInputError("Error parsing response start line")
|
||||
return ResponseStartLine(match.group(1), int(match.group(2)),
|
||||
|
@ -873,3 +892,20 @@ def _encode_header(key, pdict):
|
|||
def doctests():
|
||||
import doctest
|
||||
return doctest.DocTestSuite()
|
||||
|
||||
|
||||
def split_host_and_port(netloc):
|
||||
"""Returns ``(host, port)`` tuple from ``netloc``.
|
||||
|
||||
Returned ``port`` will be ``None`` if not present.
|
||||
|
||||
.. versionadded:: 4.1
|
||||
"""
|
||||
match = re.match(r'^(.+):(\d+)$', netloc)
|
||||
if match:
|
||||
host = match.group(1)
|
||||
port = int(match.group(2))
|
||||
else:
|
||||
host = netloc
|
||||
port = None
|
||||
return (host, port)
|
||||
|
|
|
@ -41,6 +41,7 @@ import sys
|
|||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import math
|
||||
|
||||
from tornado.concurrent import TracebackFuture, is_future
|
||||
from tornado.log import app_log, gen_log
|
||||
|
@ -76,35 +77,52 @@ class IOLoop(Configurable):
|
|||
simultaneous connections, you should use a system that supports
|
||||
either ``epoll`` or ``kqueue``.
|
||||
|
||||
Example usage for a simple TCP server::
|
||||
Example usage for a simple TCP server:
|
||||
|
||||
.. testcode::
|
||||
|
||||
import errno
|
||||
import functools
|
||||
import ioloop
|
||||
import tornado.ioloop
|
||||
import socket
|
||||
|
||||
def connection_ready(sock, fd, events):
|
||||
while True:
|
||||
try:
|
||||
connection, address = sock.accept()
|
||||
except socket.error, e:
|
||||
except socket.error as e:
|
||||
if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN):
|
||||
raise
|
||||
return
|
||||
connection.setblocking(0)
|
||||
handle_connection(connection, address)
|
||||
|
||||
if __name__ == '__main__':
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.setblocking(0)
|
||||
sock.bind(("", port))
|
||||
sock.listen(128)
|
||||
|
||||
io_loop = ioloop.IOLoop.instance()
|
||||
io_loop = tornado.ioloop.IOLoop.current()
|
||||
callback = functools.partial(connection_ready, sock)
|
||||
io_loop.add_handler(sock.fileno(), callback, io_loop.READ)
|
||||
io_loop.start()
|
||||
|
||||
.. testoutput::
|
||||
:hide:
|
||||
|
||||
By default, a newly-constructed `IOLoop` becomes the thread's current
|
||||
`IOLoop`, unless there already is a current `IOLoop`. This behavior
|
||||
can be controlled with the ``make_current`` argument to the `IOLoop`
|
||||
constructor: if ``make_current=True``, the new `IOLoop` will always
|
||||
try to become current and it raises an error if there is already a
|
||||
current instance. If ``make_current=False``, the new `IOLoop` will
|
||||
not try to become current.
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
Added the ``make_current`` keyword argument to the `IOLoop`
|
||||
constructor.
|
||||
"""
|
||||
# Constants from the epoll module
|
||||
_EPOLLIN = 0x001
|
||||
|
@ -133,7 +151,8 @@ class IOLoop(Configurable):
|
|||
|
||||
Most applications have a single, global `IOLoop` running on the
|
||||
main thread. Use this method to get this instance from
|
||||
another thread. To get the current thread's `IOLoop`, use `current()`.
|
||||
another thread. In most other cases, it is better to use `current()`
|
||||
to get the current thread's `IOLoop`.
|
||||
"""
|
||||
if not hasattr(IOLoop, "_instance"):
|
||||
with IOLoop._instance_lock:
|
||||
|
@ -167,28 +186,26 @@ class IOLoop(Configurable):
|
|||
del IOLoop._instance
|
||||
|
||||
@staticmethod
|
||||
def current():
|
||||
def current(instance=True):
|
||||
"""Returns the current thread's `IOLoop`.
|
||||
|
||||
If an `IOLoop` is currently running or has been marked as current
|
||||
by `make_current`, returns that instance. Otherwise returns
|
||||
`IOLoop.instance()`, i.e. the main thread's `IOLoop`.
|
||||
|
||||
A common pattern for classes that depend on ``IOLoops`` is to use
|
||||
a default argument to enable programs with multiple ``IOLoops``
|
||||
but not require the argument for simpler applications::
|
||||
|
||||
class MyClass(object):
|
||||
def __init__(self, io_loop=None):
|
||||
self.io_loop = io_loop or IOLoop.current()
|
||||
If an `IOLoop` is currently running or has been marked as
|
||||
current by `make_current`, returns that instance. If there is
|
||||
no current `IOLoop`, returns `IOLoop.instance()` (i.e. the
|
||||
main thread's `IOLoop`, creating one if necessary) if ``instance``
|
||||
is true.
|
||||
|
||||
In general you should use `IOLoop.current` as the default when
|
||||
constructing an asynchronous object, and use `IOLoop.instance`
|
||||
when you mean to communicate to the main thread from a different
|
||||
one.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
Added ``instance`` argument to control the fallback to
|
||||
`IOLoop.instance()`.
|
||||
"""
|
||||
current = getattr(IOLoop._current, "instance", None)
|
||||
if current is None:
|
||||
if current is None and instance:
|
||||
return IOLoop.instance()
|
||||
return current
|
||||
|
||||
|
@ -200,6 +217,10 @@ class IOLoop(Configurable):
|
|||
`make_current` explicitly before starting the `IOLoop`,
|
||||
so that code run at startup time can find the right
|
||||
instance.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
An `IOLoop` created while there is no current `IOLoop`
|
||||
will automatically become current.
|
||||
"""
|
||||
IOLoop._current.instance = self
|
||||
|
||||
|
@ -223,8 +244,14 @@ class IOLoop(Configurable):
|
|||
from tornado.platform.select import SelectIOLoop
|
||||
return SelectIOLoop
|
||||
|
||||
def initialize(self):
|
||||
pass
|
||||
def initialize(self, make_current=None):
|
||||
if make_current is None:
|
||||
if IOLoop.current(instance=False) is None:
|
||||
self.make_current()
|
||||
elif make_current:
|
||||
if IOLoop.current(instance=False) is None:
|
||||
raise RuntimeError("current IOLoop already exists")
|
||||
self.make_current()
|
||||
|
||||
def close(self, all_fds=False):
|
||||
"""Closes the `IOLoop`, freeing any resources used.
|
||||
|
@ -390,7 +417,7 @@ class IOLoop(Configurable):
|
|||
# do stuff...
|
||||
|
||||
if __name__ == '__main__':
|
||||
IOLoop.instance().run_sync(main)
|
||||
IOLoop.current().run_sync(main)
|
||||
"""
|
||||
future_cell = [None]
|
||||
|
||||
|
@ -633,8 +660,8 @@ class PollIOLoop(IOLoop):
|
|||
(Linux), `tornado.platform.kqueue.KQueueIOLoop` (BSD and Mac), or
|
||||
`tornado.platform.select.SelectIOLoop` (all platforms).
|
||||
"""
|
||||
def initialize(self, impl, time_func=None):
|
||||
super(PollIOLoop, self).initialize()
|
||||
def initialize(self, impl, time_func=None, **kwargs):
|
||||
super(PollIOLoop, self).initialize(**kwargs)
|
||||
self._impl = impl
|
||||
if hasattr(self._impl, 'fileno'):
|
||||
set_close_exec(self._impl.fileno())
|
||||
|
@ -739,8 +766,10 @@ class PollIOLoop(IOLoop):
|
|||
# IOLoop is just started once at the beginning.
|
||||
signal.set_wakeup_fd(old_wakeup_fd)
|
||||
old_wakeup_fd = None
|
||||
except ValueError: # non-main thread
|
||||
pass
|
||||
except ValueError:
|
||||
# Non-main thread, or the previous value of wakeup_fd
|
||||
# is no longer valid.
|
||||
old_wakeup_fd = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
|
@ -944,8 +973,16 @@ class PeriodicCallback(object):
|
|||
"""Schedules the given callback to be called periodically.
|
||||
|
||||
The callback is called every ``callback_time`` milliseconds.
|
||||
Note that the timeout is given in milliseconds, while most other
|
||||
time-related functions in Tornado use seconds.
|
||||
|
||||
If the callback runs for longer than ``callback_time`` milliseconds,
|
||||
subsequent invocations will be skipped to get back on schedule.
|
||||
|
||||
`start` must be called after the `PeriodicCallback` is created.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
The ``io_loop`` argument is deprecated.
|
||||
"""
|
||||
def __init__(self, callback, callback_time, io_loop=None):
|
||||
self.callback = callback
|
||||
|
@ -969,6 +1006,13 @@ class PeriodicCallback(object):
|
|||
self.io_loop.remove_timeout(self._timeout)
|
||||
self._timeout = None
|
||||
|
||||
def is_running(self):
|
||||
"""Return True if this `.PeriodicCallback` has been started.
|
||||
|
||||
.. versionadded:: 4.1
|
||||
"""
|
||||
return self._running
|
||||
|
||||
def _run(self):
|
||||
if not self._running:
|
||||
return
|
||||
|
@ -982,6 +1026,9 @@ class PeriodicCallback(object):
|
|||
def _schedule_next(self):
|
||||
if self._running:
|
||||
current_time = self.io_loop.time()
|
||||
while self._next_timeout <= current_time:
|
||||
self._next_timeout += self.callback_time / 1000.0
|
||||
|
||||
if self._next_timeout <= current_time:
|
||||
callback_time_sec = self.callback_time / 1000.0
|
||||
self._next_timeout += (math.floor((current_time - self._next_timeout) / callback_time_sec) + 1) * callback_time_sec
|
||||
|
||||
self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run)
|
||||
|
|
|
@ -37,7 +37,7 @@ import re
|
|||
from tornado.concurrent import TracebackFuture
|
||||
from tornado import ioloop
|
||||
from tornado.log import gen_log, app_log
|
||||
from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError
|
||||
from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError, _client_ssl_defaults, _server_ssl_defaults
|
||||
from tornado import stack_context
|
||||
from tornado.util import errno_from_exception
|
||||
|
||||
|
@ -68,13 +68,21 @@ _ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE,
|
|||
if hasattr(errno, "WSAECONNRESET"):
|
||||
_ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT)
|
||||
|
||||
if sys.platform == 'darwin':
|
||||
# OSX appears to have a race condition that causes send(2) to return
|
||||
# EPROTOTYPE if called while a socket is being torn down:
|
||||
# http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/
|
||||
# Since the socket is being closed anyway, treat this as an ECONNRESET
|
||||
# instead of an unexpected error.
|
||||
_ERRNO_CONNRESET += (errno.EPROTOTYPE,)
|
||||
|
||||
# More non-portable errnos:
|
||||
_ERRNO_INPROGRESS = (errno.EINPROGRESS,)
|
||||
|
||||
if hasattr(errno, "WSAEINPROGRESS"):
|
||||
_ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,)
|
||||
|
||||
#######################################################
|
||||
|
||||
class StreamClosedError(IOError):
|
||||
"""Exception raised by `IOStream` methods when the stream is closed.
|
||||
|
||||
|
@ -122,6 +130,7 @@ class BaseIOStream(object):
|
|||
"""`BaseIOStream` constructor.
|
||||
|
||||
:arg io_loop: The `.IOLoop` to use; defaults to `.IOLoop.current`.
|
||||
Deprecated since Tornado 4.1.
|
||||
:arg max_buffer_size: Maximum amount of incoming data to buffer;
|
||||
defaults to 100MB.
|
||||
:arg read_chunk_size: Amount of data to read at one time from the
|
||||
|
@ -160,6 +169,11 @@ class BaseIOStream(object):
|
|||
self._close_callback = None
|
||||
self._connect_callback = None
|
||||
self._connect_future = None
|
||||
# _ssl_connect_future should be defined in SSLIOStream
|
||||
# but it's here so we can clean it up in maybe_run_close_callback.
|
||||
# TODO: refactor that so subclasses can add additional futures
|
||||
# to be cancelled.
|
||||
self._ssl_connect_future = None
|
||||
self._connecting = False
|
||||
self._state = None
|
||||
self._pending_callbacks = 0
|
||||
|
@ -230,6 +244,12 @@ class BaseIOStream(object):
|
|||
gen_log.info("Unsatisfiable read, closing connection: %s" % e)
|
||||
self.close(exc_info=True)
|
||||
return future
|
||||
except:
|
||||
if future is not None:
|
||||
# Ensure that the future doesn't log an error because its
|
||||
# failure was never examined.
|
||||
future.add_done_callback(lambda f: f.exception())
|
||||
raise
|
||||
return future
|
||||
|
||||
def read_until(self, delimiter, callback=None, max_bytes=None):
|
||||
|
@ -257,6 +277,10 @@ class BaseIOStream(object):
|
|||
gen_log.info("Unsatisfiable read, closing connection: %s" % e)
|
||||
self.close(exc_info=True)
|
||||
return future
|
||||
except:
|
||||
if future is not None:
|
||||
future.add_done_callback(lambda f: f.exception())
|
||||
raise
|
||||
return future
|
||||
|
||||
def read_bytes(self, num_bytes, callback=None, streaming_callback=None,
|
||||
|
@ -281,7 +305,12 @@ class BaseIOStream(object):
|
|||
self._read_bytes = num_bytes
|
||||
self._read_partial = partial
|
||||
self._streaming_callback = stack_context.wrap(streaming_callback)
|
||||
try:
|
||||
self._try_inline_read()
|
||||
except:
|
||||
if future is not None:
|
||||
future.add_done_callback(lambda f: f.exception())
|
||||
raise
|
||||
return future
|
||||
|
||||
def read_until_close(self, callback=None, streaming_callback=None):
|
||||
|
@ -293,9 +322,16 @@ class BaseIOStream(object):
|
|||
If a callback is given, it will be run with the data as an argument;
|
||||
if not, this method returns a `.Future`.
|
||||
|
||||
Note that if a ``streaming_callback`` is used, data will be
|
||||
read from the socket as quickly as it becomes available; there
|
||||
is no way to apply backpressure or cancel the reads. If flow
|
||||
control or cancellation are desired, use a loop with
|
||||
`read_bytes(partial=True) <.read_bytes>` instead.
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
The callback argument is now optional and a `.Future` will
|
||||
be returned if it is omitted.
|
||||
|
||||
"""
|
||||
future = self._set_read_callback(callback)
|
||||
self._streaming_callback = stack_context.wrap(streaming_callback)
|
||||
|
@ -305,7 +341,11 @@ class BaseIOStream(object):
|
|||
self._run_read_callback(self._read_buffer_size, False)
|
||||
return future
|
||||
self._read_until_close = True
|
||||
try:
|
||||
self._try_inline_read()
|
||||
except:
|
||||
future.add_done_callback(lambda f: f.exception())
|
||||
raise
|
||||
return future
|
||||
|
||||
def write(self, data, callback=None):
|
||||
|
@ -331,7 +371,7 @@ class BaseIOStream(object):
|
|||
if data:
|
||||
if (self.max_write_buffer_size is not None and
|
||||
self._write_buffer_size + len(data) > self.max_write_buffer_size):
|
||||
raise StreamBufferFullError("Reached maximum read buffer size")
|
||||
raise StreamBufferFullError("Reached maximum write buffer size")
|
||||
# Break up large contiguous strings before inserting them in the
|
||||
# write buffer, so we don't have to recopy the entire thing
|
||||
# as we slice off pieces to send to the socket.
|
||||
|
@ -344,6 +384,7 @@ class BaseIOStream(object):
|
|||
future = None
|
||||
else:
|
||||
future = self._write_future = TracebackFuture()
|
||||
future.add_done_callback(lambda f: f.exception())
|
||||
if not self._connecting:
|
||||
self._handle_write()
|
||||
if self._write_buffer:
|
||||
|
@ -401,9 +442,11 @@ class BaseIOStream(object):
|
|||
if self._connect_future is not None:
|
||||
futures.append(self._connect_future)
|
||||
self._connect_future = None
|
||||
if self._ssl_connect_future is not None:
|
||||
futures.append(self._ssl_connect_future)
|
||||
self._ssl_connect_future = None
|
||||
for future in futures:
|
||||
if (isinstance(self.error, (socket.error, IOError)) and
|
||||
errno_from_exception(self.error) in _ERRNO_CONNRESET):
|
||||
if self._is_connreset(self.error):
|
||||
# Treat connection resets as closed connections so
|
||||
# clients only have to catch one kind of exception
|
||||
# to avoid logging.
|
||||
|
@ -601,8 +644,7 @@ class BaseIOStream(object):
|
|||
pos = self._read_to_buffer_loop()
|
||||
except UnsatisfiableReadError:
|
||||
raise
|
||||
except Exception as e:
|
||||
if 1 != e.errno:
|
||||
except Exception:
|
||||
gen_log.warning("error on read", exc_info=True)
|
||||
self.close(exc_info=True)
|
||||
return
|
||||
|
@ -633,7 +675,7 @@ class BaseIOStream(object):
|
|||
self._read_future = None
|
||||
future.set_result(self._consume(size))
|
||||
if callback is not None:
|
||||
assert self._read_future is None
|
||||
assert (self._read_future is None) or streaming
|
||||
self._run_callback(callback, self._consume(size))
|
||||
else:
|
||||
# If we scheduled a callback, we will add the error listener
|
||||
|
@ -684,7 +726,7 @@ class BaseIOStream(object):
|
|||
chunk = self.read_from_fd()
|
||||
except (socket.error, IOError, OSError) as e:
|
||||
# ssl.SSLError is a subclass of socket.error
|
||||
if e.args[0] in _ERRNO_CONNRESET:
|
||||
if self._is_connreset(e):
|
||||
# Treat ECONNRESET as a connection close rather than
|
||||
# an error to minimize log spam (the exception will
|
||||
# be available on self.error for apps that care).
|
||||
|
@ -806,7 +848,7 @@ class BaseIOStream(object):
|
|||
self._write_buffer_frozen = True
|
||||
break
|
||||
else:
|
||||
if e.args[0] not in _ERRNO_CONNRESET:
|
||||
if not self._is_connreset(e):
|
||||
# Broken pipe errors are usually caused by connection
|
||||
# reset, and its better to not log EPIPE errors to
|
||||
# minimize log spam
|
||||
|
@ -884,6 +926,14 @@ class BaseIOStream(object):
|
|||
self._state = self._state | state
|
||||
self.io_loop.update_handler(self.fileno(), self._state)
|
||||
|
||||
def _is_connreset(self, exc):
|
||||
"""Return true if exc is ECONNRESET or equivalent.
|
||||
|
||||
May be overridden in subclasses.
|
||||
"""
|
||||
return (isinstance(exc, (socket.error, IOError)) and
|
||||
errno_from_exception(exc) in _ERRNO_CONNRESET)
|
||||
|
||||
|
||||
class IOStream(BaseIOStream):
|
||||
r"""Socket-based `IOStream` implementation.
|
||||
|
@ -898,7 +948,9 @@ class IOStream(BaseIOStream):
|
|||
connected before passing it to the `IOStream` or connected with
|
||||
`IOStream.connect`.
|
||||
|
||||
A very simple (and broken) HTTP client using this class::
|
||||
A very simple (and broken) HTTP client using this class:
|
||||
|
||||
.. testcode::
|
||||
|
||||
import tornado.ioloop
|
||||
import tornado.iostream
|
||||
|
@ -917,14 +969,19 @@ class IOStream(BaseIOStream):
|
|||
stream.read_bytes(int(headers[b"Content-Length"]), on_body)
|
||||
|
||||
def on_body(data):
|
||||
print data
|
||||
print(data)
|
||||
stream.close()
|
||||
tornado.ioloop.IOLoop.instance().stop()
|
||||
tornado.ioloop.IOLoop.current().stop()
|
||||
|
||||
if __name__ == '__main__':
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
|
||||
stream = tornado.iostream.IOStream(s)
|
||||
stream.connect(("friendfeed.com", 80), send_request)
|
||||
tornado.ioloop.IOLoop.instance().start()
|
||||
tornado.ioloop.IOLoop.current().start()
|
||||
|
||||
.. testoutput::
|
||||
:hide:
|
||||
|
||||
"""
|
||||
def __init__(self, socket, *args, **kwargs):
|
||||
self.socket = socket
|
||||
|
@ -978,10 +1035,10 @@ class IOStream(BaseIOStream):
|
|||
returns a `.Future` (whose result after a successful
|
||||
connection will be the stream itself).
|
||||
|
||||
If specified, the ``server_hostname`` parameter will be used
|
||||
in SSL connections for certificate validation (if requested in
|
||||
the ``ssl_options``) and SNI (if supported; requires
|
||||
Python 3.2+).
|
||||
In SSL mode, the ``server_hostname`` parameter will be used
|
||||
for certificate validation (unless disabled in the
|
||||
``ssl_options``) and SNI (if supported; requires Python
|
||||
2.7.9+).
|
||||
|
||||
Note that it is safe to call `IOStream.write
|
||||
<BaseIOStream.write>` while the connection is pending, in
|
||||
|
@ -992,6 +1049,11 @@ class IOStream(BaseIOStream):
|
|||
.. versionchanged:: 4.0
|
||||
If no callback is given, returns a `.Future`.
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
SSL certificates are validated by default; pass
|
||||
``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a
|
||||
suitably-configured `ssl.SSLContext` to the
|
||||
`SSLIOStream` constructor to disable.
|
||||
"""
|
||||
self._connecting = True
|
||||
if callback is not None:
|
||||
|
@ -1011,6 +1073,7 @@ class IOStream(BaseIOStream):
|
|||
# reported later in _handle_connect.
|
||||
if (errno_from_exception(e) not in _ERRNO_INPROGRESS and
|
||||
errno_from_exception(e) not in _ERRNO_WOULDBLOCK):
|
||||
if future is None:
|
||||
gen_log.warning("Connect error on fd %s: %s",
|
||||
self.socket.fileno(), e)
|
||||
self.close(exc_info=True)
|
||||
|
@ -1033,10 +1096,11 @@ class IOStream(BaseIOStream):
|
|||
data. It can also be used immediately after connecting,
|
||||
before any reads or writes.
|
||||
|
||||
The ``ssl_options`` argument may be either a dictionary
|
||||
of options or an `ssl.SSLContext`. If a ``server_hostname``
|
||||
is given, it will be used for certificate verification
|
||||
(as configured in the ``ssl_options``).
|
||||
The ``ssl_options`` argument may be either an `ssl.SSLContext`
|
||||
object or a dictionary of keyword arguments for the
|
||||
`ssl.wrap_socket` function. The ``server_hostname`` argument
|
||||
will be used for certificate validation unless disabled
|
||||
in the ``ssl_options``.
|
||||
|
||||
This method returns a `.Future` whose result is the new
|
||||
`SSLIOStream`. After this method has been called,
|
||||
|
@ -1046,6 +1110,11 @@ class IOStream(BaseIOStream):
|
|||
transferred to the new stream.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
SSL certificates are validated by default; pass
|
||||
``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a
|
||||
suitably-configured `ssl.SSLContext` to disable.
|
||||
"""
|
||||
if (self._read_callback or self._read_future or
|
||||
self._write_callback or self._write_future or
|
||||
|
@ -1054,12 +1123,17 @@ class IOStream(BaseIOStream):
|
|||
self._read_buffer or self._write_buffer):
|
||||
raise ValueError("IOStream is not idle; cannot convert to SSL")
|
||||
if ssl_options is None:
|
||||
ssl_options = {}
|
||||
if server_side:
|
||||
ssl_options = _server_ssl_defaults
|
||||
else:
|
||||
ssl_options = _client_ssl_defaults
|
||||
|
||||
socket = self.socket
|
||||
self.io_loop.remove_handler(socket)
|
||||
self.socket = None
|
||||
socket = ssl_wrap_socket(socket, ssl_options, server_side=server_side,
|
||||
socket = ssl_wrap_socket(socket, ssl_options,
|
||||
server_hostname=server_hostname,
|
||||
server_side=server_side,
|
||||
do_handshake_on_connect=False)
|
||||
orig_close_callback = self._close_callback
|
||||
self._close_callback = None
|
||||
|
@ -1071,6 +1145,7 @@ class IOStream(BaseIOStream):
|
|||
# If we had an "unwrap" counterpart to this method we would need
|
||||
# to restore the original callback after our Future resolves
|
||||
# so that repeated wrap/unwrap calls don't build up layers.
|
||||
|
||||
def close_callback():
|
||||
if not future.done():
|
||||
future.set_exception(ssl_stream.error or StreamClosedError())
|
||||
|
@ -1115,7 +1190,7 @@ class IOStream(BaseIOStream):
|
|||
# Sometimes setsockopt will fail if the socket is closed
|
||||
# at the wrong time. This can happen with HTTPServer
|
||||
# resetting the value to false between requests.
|
||||
if e.errno not in (errno.EINVAL, errno.ECONNRESET):
|
||||
if e.errno != errno.EINVAL and not self._is_connreset(e):
|
||||
raise
|
||||
|
||||
|
||||
|
@ -1131,11 +1206,11 @@ class SSLIOStream(IOStream):
|
|||
wrapped when `IOStream.connect` is finished.
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""The ``ssl_options`` keyword argument may either be a dictionary
|
||||
of keywords arguments for `ssl.wrap_socket`, or an `ssl.SSLContext`
|
||||
object.
|
||||
"""The ``ssl_options`` keyword argument may either be an
|
||||
`ssl.SSLContext` object or a dictionary of keywords arguments
|
||||
for `ssl.wrap_socket`
|
||||
"""
|
||||
self._ssl_options = kwargs.pop('ssl_options', {})
|
||||
self._ssl_options = kwargs.pop('ssl_options', _client_ssl_defaults)
|
||||
super(SSLIOStream, self).__init__(*args, **kwargs)
|
||||
self._ssl_accepting = True
|
||||
self._handshake_reading = False
|
||||
|
@ -1190,8 +1265,7 @@ class SSLIOStream(IOStream):
|
|||
# to cause do_handshake to raise EBADF, so make that error
|
||||
# quiet as well.
|
||||
# https://groups.google.com/forum/?fromgroups#!topic/python-tornado/ApucKJat1_0
|
||||
if (err.args[0] in _ERRNO_CONNRESET or
|
||||
err.args[0] == errno.EBADF):
|
||||
if self._is_connreset(err) or err.args[0] == errno.EBADF:
|
||||
return self.close(exc_info=True)
|
||||
raise
|
||||
except AttributeError:
|
||||
|
@ -1204,10 +1278,17 @@ class SSLIOStream(IOStream):
|
|||
if not self._verify_cert(self.socket.getpeercert()):
|
||||
self.close()
|
||||
return
|
||||
self._run_ssl_connect_callback()
|
||||
|
||||
def _run_ssl_connect_callback(self):
|
||||
if self._ssl_connect_callback is not None:
|
||||
callback = self._ssl_connect_callback
|
||||
self._ssl_connect_callback = None
|
||||
self._run_callback(callback)
|
||||
if self._ssl_connect_future is not None:
|
||||
future = self._ssl_connect_future
|
||||
self._ssl_connect_future = None
|
||||
future.set_result(self)
|
||||
|
||||
def _verify_cert(self, peercert):
|
||||
"""Returns True if peercert is valid according to the configured
|
||||
|
@ -1249,14 +1330,11 @@ class SSLIOStream(IOStream):
|
|||
super(SSLIOStream, self)._handle_write()
|
||||
|
||||
def connect(self, address, callback=None, server_hostname=None):
|
||||
# Save the user's callback and run it after the ssl handshake
|
||||
# has completed.
|
||||
self._ssl_connect_callback = stack_context.wrap(callback)
|
||||
self._server_hostname = server_hostname
|
||||
# Note: Since we don't pass our callback argument along to
|
||||
# super.connect(), this will always return a Future.
|
||||
# This is harmless, but a bit less efficient than it could be.
|
||||
return super(SSLIOStream, self).connect(address, callback=None)
|
||||
# Pass a dummy callback to super.connect(), which is slightly
|
||||
# more efficient than letting it return a Future we ignore.
|
||||
super(SSLIOStream, self).connect(address, callback=lambda: None)
|
||||
return self.wait_for_handshake(callback)
|
||||
|
||||
def _handle_connect(self):
|
||||
# Call the superclass method to check for errors.
|
||||
|
@ -1281,6 +1359,51 @@ class SSLIOStream(IOStream):
|
|||
do_handshake_on_connect=False)
|
||||
self._add_io_state(old_state)
|
||||
|
||||
def wait_for_handshake(self, callback=None):
|
||||
"""Wait for the initial SSL handshake to complete.
|
||||
|
||||
If a ``callback`` is given, it will be called with no
|
||||
arguments once the handshake is complete; otherwise this
|
||||
method returns a `.Future` which will resolve to the
|
||||
stream itself after the handshake is complete.
|
||||
|
||||
Once the handshake is complete, information such as
|
||||
the peer's certificate and NPN/ALPN selections may be
|
||||
accessed on ``self.socket``.
|
||||
|
||||
This method is intended for use on server-side streams
|
||||
or after using `IOStream.start_tls`; it should not be used
|
||||
with `IOStream.connect` (which already waits for the
|
||||
handshake to complete). It may only be called once per stream.
|
||||
|
||||
.. versionadded:: 4.2
|
||||
"""
|
||||
if (self._ssl_connect_callback is not None or
|
||||
self._ssl_connect_future is not None):
|
||||
raise RuntimeError("Already waiting")
|
||||
if callback is not None:
|
||||
self._ssl_connect_callback = stack_context.wrap(callback)
|
||||
future = None
|
||||
else:
|
||||
future = self._ssl_connect_future = TracebackFuture()
|
||||
if not self._ssl_accepting:
|
||||
self._run_ssl_connect_callback()
|
||||
return future
|
||||
|
||||
def write_to_fd(self, data):
|
||||
try:
|
||||
return self.socket.send(data)
|
||||
except ssl.SSLError as e:
|
||||
if e.args[0] == ssl.SSL_ERROR_WANT_WRITE:
|
||||
# In Python 3.5+, SSLSocket.send raises a WANT_WRITE error if
|
||||
# the socket is not writeable; we need to transform this into
|
||||
# an EWOULDBLOCK socket.error or a zero return value,
|
||||
# either of which will be recognized by the caller of this
|
||||
# method. Prior to Python 3.5, an unwriteable socket would
|
||||
# simply return 0 bytes written.
|
||||
return 0
|
||||
raise
|
||||
|
||||
def read_from_fd(self):
|
||||
if self._ssl_accepting:
|
||||
# If the handshake hasn't finished yet, there can't be anything
|
||||
|
@ -1311,6 +1434,11 @@ class SSLIOStream(IOStream):
|
|||
return None
|
||||
return chunk
|
||||
|
||||
def _is_connreset(self, e):
|
||||
if isinstance(e, ssl.SSLError) and e.args[0] == ssl.SSL_ERROR_EOF:
|
||||
return True
|
||||
return super(SSLIOStream, self)._is_connreset(e)
|
||||
|
||||
|
||||
class PipeIOStream(BaseIOStream):
|
||||
"""Pipe-based `IOStream` implementation.
|
||||
|
|
|
@ -55,6 +55,7 @@ _default_locale = "en_US"
|
|||
_translations = {}
|
||||
_supported_locales = frozenset([_default_locale])
|
||||
_use_gettext = False
|
||||
CONTEXT_SEPARATOR = "\x04"
|
||||
|
||||
|
||||
def get(*locale_codes):
|
||||
|
@ -273,6 +274,9 @@ class Locale(object):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def pgettext(self, context, message, plural_message=None, count=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def format_date(self, date, gmt_offset=0, relative=True, shorter=False,
|
||||
full_format=False):
|
||||
"""Formats the given date (which should be GMT).
|
||||
|
@ -422,6 +426,11 @@ class CSVLocale(Locale):
|
|||
message_dict = self.translations.get("unknown", {})
|
||||
return message_dict.get(message, message)
|
||||
|
||||
def pgettext(self, context, message, plural_message=None, count=None):
|
||||
if self.translations:
|
||||
gen_log.warning('pgettext is not supported by CSVLocale')
|
||||
return self.translate(message, plural_message, count)
|
||||
|
||||
|
||||
class GettextLocale(Locale):
|
||||
"""Locale implementation using the `gettext` module."""
|
||||
|
@ -445,6 +454,44 @@ class GettextLocale(Locale):
|
|||
else:
|
||||
return self.gettext(message)
|
||||
|
||||
def pgettext(self, context, message, plural_message=None, count=None):
|
||||
"""Allows to set context for translation, accepts plural forms.
|
||||
|
||||
Usage example::
|
||||
|
||||
pgettext("law", "right")
|
||||
pgettext("good", "right")
|
||||
|
||||
Plural message example::
|
||||
|
||||
pgettext("organization", "club", "clubs", len(clubs))
|
||||
pgettext("stick", "club", "clubs", len(clubs))
|
||||
|
||||
To generate POT file with context, add following options to step 1
|
||||
of `load_gettext_translations` sequence::
|
||||
|
||||
xgettext [basic options] --keyword=pgettext:1c,2 --keyword=pgettext:1c,2,3
|
||||
|
||||
.. versionadded:: 4.2
|
||||
"""
|
||||
if plural_message is not None:
|
||||
assert count is not None
|
||||
msgs_with_ctxt = ("%s%s%s" % (context, CONTEXT_SEPARATOR, message),
|
||||
"%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message),
|
||||
count)
|
||||
result = self.ngettext(*msgs_with_ctxt)
|
||||
if CONTEXT_SEPARATOR in result:
|
||||
# Translation not found
|
||||
result = self.ngettext(message, plural_message, count)
|
||||
return result
|
||||
else:
|
||||
msg_with_ctxt = "%s%s%s" % (context, CONTEXT_SEPARATOR, message)
|
||||
result = self.gettext(msg_with_ctxt)
|
||||
if CONTEXT_SEPARATOR in result:
|
||||
# Translation not found
|
||||
result = message
|
||||
return result
|
||||
|
||||
LOCALE_NAMES = {
|
||||
"af_ZA": {"name_en": u("Afrikaans"), "name": u("Afrikaans")},
|
||||
"am_ET": {"name_en": u("Amharic"), "name": u('\u12a0\u121b\u122d\u129b')},
|
||||
|
|
460
tornado/locks.py
Normal file
460
tornado/locks.py
Normal file
|
@ -0,0 +1,460 @@
|
|||
# Copyright 2015 The Tornado Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
.. testsetup:: *
|
||||
|
||||
from tornado import ioloop, gen, locks
|
||||
io_loop = ioloop.IOLoop.current()
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
|
||||
__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock']
|
||||
|
||||
import collections
|
||||
|
||||
from tornado import gen, ioloop
|
||||
from tornado.concurrent import Future
|
||||
|
||||
|
||||
class _TimeoutGarbageCollector(object):
|
||||
"""Base class for objects that periodically clean up timed-out waiters.
|
||||
|
||||
Avoids memory leak in a common pattern like:
|
||||
|
||||
while True:
|
||||
yield condition.wait(short_timeout)
|
||||
print('looping....')
|
||||
"""
|
||||
def __init__(self):
|
||||
self._waiters = collections.deque() # Futures.
|
||||
self._timeouts = 0
|
||||
|
||||
def _garbage_collect(self):
|
||||
# Occasionally clear timed-out waiters.
|
||||
self._timeouts += 1
|
||||
if self._timeouts > 100:
|
||||
self._timeouts = 0
|
||||
self._waiters = collections.deque(
|
||||
w for w in self._waiters if not w.done())
|
||||
|
||||
|
||||
class Condition(_TimeoutGarbageCollector):
|
||||
"""A condition allows one or more coroutines to wait until notified.
|
||||
|
||||
Like a standard `threading.Condition`, but does not need an underlying lock
|
||||
that is acquired and released.
|
||||
|
||||
With a `Condition`, coroutines can wait to be notified by other coroutines:
|
||||
|
||||
.. testcode::
|
||||
|
||||
condition = locks.Condition()
|
||||
|
||||
@gen.coroutine
|
||||
def waiter():
|
||||
print("I'll wait right here")
|
||||
yield condition.wait() # Yield a Future.
|
||||
print("I'm done waiting")
|
||||
|
||||
@gen.coroutine
|
||||
def notifier():
|
||||
print("About to notify")
|
||||
condition.notify()
|
||||
print("Done notifying")
|
||||
|
||||
@gen.coroutine
|
||||
def runner():
|
||||
# Yield two Futures; wait for waiter() and notifier() to finish.
|
||||
yield [waiter(), notifier()]
|
||||
|
||||
io_loop.run_sync(runner)
|
||||
|
||||
.. testoutput::
|
||||
|
||||
I'll wait right here
|
||||
About to notify
|
||||
Done notifying
|
||||
I'm done waiting
|
||||
|
||||
`wait` takes an optional ``timeout`` argument, which is either an absolute
|
||||
timestamp::
|
||||
|
||||
io_loop = ioloop.IOLoop.current()
|
||||
|
||||
# Wait up to 1 second for a notification.
|
||||
yield condition.wait(timeout=io_loop.time() + 1)
|
||||
|
||||
...or a `datetime.timedelta` for a timeout relative to the current time::
|
||||
|
||||
# Wait up to 1 second.
|
||||
yield condition.wait(timeout=datetime.timedelta(seconds=1))
|
||||
|
||||
The method raises `tornado.gen.TimeoutError` if there's no notification
|
||||
before the deadline.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(Condition, self).__init__()
|
||||
self.io_loop = ioloop.IOLoop.current()
|
||||
|
||||
def __repr__(self):
|
||||
result = '<%s' % (self.__class__.__name__, )
|
||||
if self._waiters:
|
||||
result += ' waiters[%s]' % len(self._waiters)
|
||||
return result + '>'
|
||||
|
||||
def wait(self, timeout=None):
|
||||
"""Wait for `.notify`.
|
||||
|
||||
Returns a `.Future` that resolves ``True`` if the condition is notified,
|
||||
or ``False`` after a timeout.
|
||||
"""
|
||||
waiter = Future()
|
||||
self._waiters.append(waiter)
|
||||
if timeout:
|
||||
def on_timeout():
|
||||
waiter.set_result(False)
|
||||
self._garbage_collect()
|
||||
io_loop = ioloop.IOLoop.current()
|
||||
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
|
||||
waiter.add_done_callback(
|
||||
lambda _: io_loop.remove_timeout(timeout_handle))
|
||||
return waiter
|
||||
|
||||
def notify(self, n=1):
|
||||
"""Wake ``n`` waiters."""
|
||||
waiters = [] # Waiters we plan to run right now.
|
||||
while n and self._waiters:
|
||||
waiter = self._waiters.popleft()
|
||||
if not waiter.done(): # Might have timed out.
|
||||
n -= 1
|
||||
waiters.append(waiter)
|
||||
|
||||
for waiter in waiters:
|
||||
waiter.set_result(True)
|
||||
|
||||
def notify_all(self):
|
||||
"""Wake all waiters."""
|
||||
self.notify(len(self._waiters))
|
||||
|
||||
|
||||
class Event(object):
|
||||
"""An event blocks coroutines until its internal flag is set to True.
|
||||
|
||||
Similar to `threading.Event`.
|
||||
|
||||
A coroutine can wait for an event to be set. Once it is set, calls to
|
||||
``yield event.wait()`` will not block unless the event has been cleared:
|
||||
|
||||
.. testcode::
|
||||
|
||||
event = locks.Event()
|
||||
|
||||
@gen.coroutine
|
||||
def waiter():
|
||||
print("Waiting for event")
|
||||
yield event.wait()
|
||||
print("Not waiting this time")
|
||||
yield event.wait()
|
||||
print("Done")
|
||||
|
||||
@gen.coroutine
|
||||
def setter():
|
||||
print("About to set the event")
|
||||
event.set()
|
||||
|
||||
@gen.coroutine
|
||||
def runner():
|
||||
yield [waiter(), setter()]
|
||||
|
||||
io_loop.run_sync(runner)
|
||||
|
||||
.. testoutput::
|
||||
|
||||
Waiting for event
|
||||
About to set the event
|
||||
Not waiting this time
|
||||
Done
|
||||
"""
|
||||
def __init__(self):
|
||||
self._future = Future()
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s %s>' % (
|
||||
self.__class__.__name__, 'set' if self.is_set() else 'clear')
|
||||
|
||||
def is_set(self):
|
||||
"""Return ``True`` if the internal flag is true."""
|
||||
return self._future.done()
|
||||
|
||||
def set(self):
|
||||
"""Set the internal flag to ``True``. All waiters are awakened.
|
||||
|
||||
Calling `.wait` once the flag is set will not block.
|
||||
"""
|
||||
if not self._future.done():
|
||||
self._future.set_result(None)
|
||||
|
||||
def clear(self):
|
||||
"""Reset the internal flag to ``False``.
|
||||
|
||||
Calls to `.wait` will block until `.set` is called.
|
||||
"""
|
||||
if self._future.done():
|
||||
self._future = Future()
|
||||
|
||||
def wait(self, timeout=None):
|
||||
"""Block until the internal flag is true.
|
||||
|
||||
Returns a Future, which raises `tornado.gen.TimeoutError` after a
|
||||
timeout.
|
||||
"""
|
||||
if timeout is None:
|
||||
return self._future
|
||||
else:
|
||||
return gen.with_timeout(timeout, self._future)
|
||||
|
||||
|
||||
class _ReleasingContextManager(object):
|
||||
"""Releases a Lock or Semaphore at the end of a "with" statement.
|
||||
|
||||
with (yield semaphore.acquire()):
|
||||
pass
|
||||
|
||||
# Now semaphore.release() has been called.
|
||||
"""
|
||||
def __init__(self, obj):
|
||||
self._obj = obj
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._obj.release()
|
||||
|
||||
|
||||
class Semaphore(_TimeoutGarbageCollector):
|
||||
"""A lock that can be acquired a fixed number of times before blocking.
|
||||
|
||||
A Semaphore manages a counter representing the number of `.release` calls
|
||||
minus the number of `.acquire` calls, plus an initial value. The `.acquire`
|
||||
method blocks if necessary until it can return without making the counter
|
||||
negative.
|
||||
|
||||
Semaphores limit access to a shared resource. To allow access for two
|
||||
workers at a time:
|
||||
|
||||
.. testsetup:: semaphore
|
||||
|
||||
from collections import deque
|
||||
|
||||
from tornado import gen, ioloop
|
||||
from tornado.concurrent import Future
|
||||
|
||||
# Ensure reliable doctest output: resolve Futures one at a time.
|
||||
futures_q = deque([Future() for _ in range(3)])
|
||||
|
||||
@gen.coroutine
|
||||
def simulator(futures):
|
||||
for f in futures:
|
||||
yield gen.moment
|
||||
f.set_result(None)
|
||||
|
||||
ioloop.IOLoop.current().add_callback(simulator, list(futures_q))
|
||||
|
||||
def use_some_resource():
|
||||
return futures_q.popleft()
|
||||
|
||||
.. testcode:: semaphore
|
||||
|
||||
sem = locks.Semaphore(2)
|
||||
|
||||
@gen.coroutine
|
||||
def worker(worker_id):
|
||||
yield sem.acquire()
|
||||
try:
|
||||
print("Worker %d is working" % worker_id)
|
||||
yield use_some_resource()
|
||||
finally:
|
||||
print("Worker %d is done" % worker_id)
|
||||
sem.release()
|
||||
|
||||
@gen.coroutine
|
||||
def runner():
|
||||
# Join all workers.
|
||||
yield [worker(i) for i in range(3)]
|
||||
|
||||
io_loop.run_sync(runner)
|
||||
|
||||
.. testoutput:: semaphore
|
||||
|
||||
Worker 0 is working
|
||||
Worker 1 is working
|
||||
Worker 0 is done
|
||||
Worker 2 is working
|
||||
Worker 1 is done
|
||||
Worker 2 is done
|
||||
|
||||
Workers 0 and 1 are allowed to run concurrently, but worker 2 waits until
|
||||
the semaphore has been released once, by worker 0.
|
||||
|
||||
`.acquire` is a context manager, so ``worker`` could be written as::
|
||||
|
||||
@gen.coroutine
|
||||
def worker(worker_id):
|
||||
with (yield sem.acquire()):
|
||||
print("Worker %d is working" % worker_id)
|
||||
yield use_some_resource()
|
||||
|
||||
# Now the semaphore has been released.
|
||||
print("Worker %d is done" % worker_id)
|
||||
"""
|
||||
def __init__(self, value=1):
|
||||
super(Semaphore, self).__init__()
|
||||
if value < 0:
|
||||
raise ValueError('semaphore initial value must be >= 0')
|
||||
|
||||
self._value = value
|
||||
|
||||
def __repr__(self):
|
||||
res = super(Semaphore, self).__repr__()
|
||||
extra = 'locked' if self._value == 0 else 'unlocked,value:{0}'.format(
|
||||
self._value)
|
||||
if self._waiters:
|
||||
extra = '{0},waiters:{1}'.format(extra, len(self._waiters))
|
||||
return '<{0} [{1}]>'.format(res[1:-1], extra)
|
||||
|
||||
def release(self):
|
||||
"""Increment the counter and wake one waiter."""
|
||||
self._value += 1
|
||||
while self._waiters:
|
||||
waiter = self._waiters.popleft()
|
||||
if not waiter.done():
|
||||
self._value -= 1
|
||||
|
||||
# If the waiter is a coroutine paused at
|
||||
#
|
||||
# with (yield semaphore.acquire()):
|
||||
#
|
||||
# then the context manager's __exit__ calls release() at the end
|
||||
# of the "with" block.
|
||||
waiter.set_result(_ReleasingContextManager(self))
|
||||
break
|
||||
|
||||
def acquire(self, timeout=None):
|
||||
"""Decrement the counter. Returns a Future.
|
||||
|
||||
Block if the counter is zero and wait for a `.release`. The Future
|
||||
raises `.TimeoutError` after the deadline.
|
||||
"""
|
||||
waiter = Future()
|
||||
if self._value > 0:
|
||||
self._value -= 1
|
||||
waiter.set_result(_ReleasingContextManager(self))
|
||||
else:
|
||||
self._waiters.append(waiter)
|
||||
if timeout:
|
||||
def on_timeout():
|
||||
waiter.set_exception(gen.TimeoutError())
|
||||
self._garbage_collect()
|
||||
io_loop = ioloop.IOLoop.current()
|
||||
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
|
||||
waiter.add_done_callback(
|
||||
lambda _: io_loop.remove_timeout(timeout_handle))
|
||||
return waiter
|
||||
|
||||
def __enter__(self):
|
||||
raise RuntimeError(
|
||||
"Use Semaphore like 'with (yield semaphore.acquire())', not like"
|
||||
" 'with semaphore'")
|
||||
|
||||
__exit__ = __enter__
|
||||
|
||||
|
||||
class BoundedSemaphore(Semaphore):
|
||||
"""A semaphore that prevents release() being called too many times.
|
||||
|
||||
If `.release` would increment the semaphore's value past the initial
|
||||
value, it raises `ValueError`. Semaphores are mostly used to guard
|
||||
resources with limited capacity, so a semaphore released too many times
|
||||
is a sign of a bug.
|
||||
"""
|
||||
def __init__(self, value=1):
|
||||
super(BoundedSemaphore, self).__init__(value=value)
|
||||
self._initial_value = value
|
||||
|
||||
def release(self):
|
||||
"""Increment the counter and wake one waiter."""
|
||||
if self._value >= self._initial_value:
|
||||
raise ValueError("Semaphore released too many times")
|
||||
super(BoundedSemaphore, self).release()
|
||||
|
||||
|
||||
class Lock(object):
|
||||
"""A lock for coroutines.
|
||||
|
||||
A Lock begins unlocked, and `acquire` locks it immediately. While it is
|
||||
locked, a coroutine that yields `acquire` waits until another coroutine
|
||||
calls `release`.
|
||||
|
||||
Releasing an unlocked lock raises `RuntimeError`.
|
||||
|
||||
`acquire` supports the context manager protocol:
|
||||
|
||||
>>> from tornado import gen, locks
|
||||
>>> lock = locks.Lock()
|
||||
>>>
|
||||
>>> @gen.coroutine
|
||||
... def f():
|
||||
... with (yield lock.acquire()):
|
||||
... # Do something holding the lock.
|
||||
... pass
|
||||
...
|
||||
... # Now the lock is released.
|
||||
"""
|
||||
def __init__(self):
|
||||
self._block = BoundedSemaphore(value=1)
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s _block=%s>" % (
|
||||
self.__class__.__name__,
|
||||
self._block)
|
||||
|
||||
def acquire(self, timeout=None):
|
||||
"""Attempt to lock. Returns a Future.
|
||||
|
||||
Returns a Future, which raises `tornado.gen.TimeoutError` after a
|
||||
timeout.
|
||||
"""
|
||||
return self._block.acquire(timeout)
|
||||
|
||||
def release(self):
|
||||
"""Unlock.
|
||||
|
||||
The first coroutine in line waiting for `acquire` gets the lock.
|
||||
|
||||
If not locked, raise a `RuntimeError`.
|
||||
"""
|
||||
try:
|
||||
self._block.release()
|
||||
except ValueError:
|
||||
raise RuntimeError('release unlocked lock')
|
||||
|
||||
def __enter__(self):
|
||||
raise RuntimeError(
|
||||
"Use Lock like 'with (yield lock)', not like 'with lock'")
|
||||
|
||||
__exit__ = __enter__
|
|
@ -206,6 +206,14 @@ def enable_pretty_logging(options=None, logger=None):
|
|||
|
||||
|
||||
def define_logging_options(options=None):
|
||||
"""Add logging-related flags to ``options``.
|
||||
|
||||
These options are present automatically on the default options instance;
|
||||
this method is only necessary if you have created your own `.OptionParser`.
|
||||
|
||||
.. versionadded:: 4.2
|
||||
This function existed in prior versions but was broken and undocumented until 4.2.
|
||||
"""
|
||||
if options is None:
|
||||
# late import to prevent cycle
|
||||
from tornado.options import options
|
||||
|
@ -227,4 +235,4 @@ def define_logging_options(options=None):
|
|||
options.define("log_file_num_backups", type=int, default=10,
|
||||
help="number of log files to keep")
|
||||
|
||||
options.add_parse_callback(enable_pretty_logging)
|
||||
options.add_parse_callback(lambda: enable_pretty_logging(options))
|
||||
|
|
|
@ -20,7 +20,7 @@ from __future__ import absolute_import, division, print_function, with_statement
|
|||
|
||||
import errno
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import socket
|
||||
import stat
|
||||
|
||||
|
@ -35,6 +35,15 @@ except ImportError:
|
|||
# ssl is not available on Google App Engine
|
||||
ssl = None
|
||||
|
||||
try:
|
||||
import certifi
|
||||
except ImportError:
|
||||
# certifi is optional as long as we have ssl.create_default_context.
|
||||
if ssl is None or hasattr(ssl, 'create_default_context'):
|
||||
certifi = None
|
||||
else:
|
||||
raise
|
||||
|
||||
try:
|
||||
xrange # py2
|
||||
except NameError:
|
||||
|
@ -50,6 +59,38 @@ else:
|
|||
ssl_match_hostname = backports.ssl_match_hostname.match_hostname
|
||||
SSLCertificateError = backports.ssl_match_hostname.CertificateError
|
||||
|
||||
if hasattr(ssl, 'SSLContext'):
|
||||
if hasattr(ssl, 'create_default_context'):
|
||||
# Python 2.7.9+, 3.4+
|
||||
# Note that the naming of ssl.Purpose is confusing; the purpose
|
||||
# of a context is to authentiate the opposite side of the connection.
|
||||
_client_ssl_defaults = ssl.create_default_context(
|
||||
ssl.Purpose.SERVER_AUTH)
|
||||
_server_ssl_defaults = ssl.create_default_context(
|
||||
ssl.Purpose.CLIENT_AUTH)
|
||||
else:
|
||||
# Python 3.2-3.3
|
||||
_client_ssl_defaults = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
_client_ssl_defaults.verify_mode = ssl.CERT_REQUIRED
|
||||
_client_ssl_defaults.load_verify_locations(certifi.where())
|
||||
_server_ssl_defaults = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
if hasattr(ssl, 'OP_NO_COMPRESSION'):
|
||||
# Disable TLS compression to avoid CRIME and related attacks.
|
||||
# This constant wasn't added until python 3.3.
|
||||
_client_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
|
||||
_server_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
|
||||
|
||||
elif ssl:
|
||||
# Python 2.6-2.7.8
|
||||
_client_ssl_defaults = dict(cert_reqs=ssl.CERT_REQUIRED,
|
||||
ca_certs=certifi.where())
|
||||
_server_ssl_defaults = {}
|
||||
else:
|
||||
# Google App Engine
|
||||
_client_ssl_defaults = dict(cert_reqs=None,
|
||||
ca_certs=None)
|
||||
_server_ssl_defaults = {}
|
||||
|
||||
# ThreadedResolver runs getaddrinfo on a thread. If the hostname is unicode,
|
||||
# getaddrinfo attempts to import encodings.idna. If this is done at
|
||||
# module-import time, the import lock is already held by the main thread,
|
||||
|
@ -68,6 +109,7 @@ if hasattr(errno, "WSAEWOULDBLOCK"):
|
|||
# Default backlog used when calling sock.listen()
|
||||
_DEFAULT_BACKLOG = 128
|
||||
|
||||
|
||||
def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
|
||||
backlog=_DEFAULT_BACKLOG, flags=None):
|
||||
"""Creates listening sockets bound to the given port and address.
|
||||
|
@ -105,7 +147,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
|
|||
for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM,
|
||||
0, flags)):
|
||||
af, socktype, proto, canonname, sockaddr = res
|
||||
if (platform.system() == 'Darwin' and address == 'localhost' and
|
||||
if (sys.platform == 'darwin' and address == 'localhost' and
|
||||
af == socket.AF_INET6 and sockaddr[3] != 0):
|
||||
# Mac OS X includes a link-local address fe80::1%lo0 in the
|
||||
# getaddrinfo results for 'localhost'. However, the firewall
|
||||
|
@ -187,6 +229,9 @@ def add_accept_handler(sock, callback, io_loop=None):
|
|||
address of the other end of the connection). Note that this signature
|
||||
is different from the ``callback(fd, events)`` signature used for
|
||||
`.IOLoop` handlers.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
The ``io_loop`` argument is deprecated.
|
||||
"""
|
||||
if io_loop is None:
|
||||
io_loop = IOLoop.current()
|
||||
|
@ -301,6 +346,9 @@ class ExecutorResolver(Resolver):
|
|||
The executor will be shut down when the resolver is closed unless
|
||||
``close_resolver=False``; use this if you want to reuse the same
|
||||
executor elsewhere.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
The ``io_loop`` argument is deprecated.
|
||||
"""
|
||||
def initialize(self, io_loop=None, executor=None, close_executor=True):
|
||||
self.io_loop = io_loop or IOLoop.current()
|
||||
|
@ -413,7 +461,7 @@ def ssl_options_to_context(ssl_options):
|
|||
`~ssl.SSLContext` object.
|
||||
|
||||
The ``ssl_options`` dictionary contains keywords to be passed to
|
||||
`ssl.wrap_socket`. In Python 3.2+, `ssl.SSLContext` objects can
|
||||
`ssl.wrap_socket`. In Python 2.7.9+, `ssl.SSLContext` objects can
|
||||
be used instead. This function converts the dict form to its
|
||||
`~ssl.SSLContext` equivalent, and may be used when a component which
|
||||
accepts both forms needs to upgrade to the `~ssl.SSLContext` version
|
||||
|
@ -444,11 +492,11 @@ def ssl_options_to_context(ssl_options):
|
|||
def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs):
|
||||
"""Returns an ``ssl.SSLSocket`` wrapping the given socket.
|
||||
|
||||
``ssl_options`` may be either a dictionary (as accepted by
|
||||
`ssl_options_to_context`) or an `ssl.SSLContext` object.
|
||||
Additional keyword arguments are passed to ``wrap_socket``
|
||||
(either the `~ssl.SSLContext` method or the `ssl` module function
|
||||
as appropriate).
|
||||
``ssl_options`` may be either an `ssl.SSLContext` object or a
|
||||
dictionary (as accepted by `ssl_options_to_context`). Additional
|
||||
keyword arguments are passed to ``wrap_socket`` (either the
|
||||
`~ssl.SSLContext` method or the `ssl` module function as
|
||||
appropriate).
|
||||
"""
|
||||
context = ssl_options_to_context(ssl_options)
|
||||
if hasattr(ssl, 'SSLContext') and isinstance(context, ssl.SSLContext):
|
||||
|
|
|
@ -204,6 +204,13 @@ class OptionParser(object):
|
|||
(name, self._options[name].file_name))
|
||||
frame = sys._getframe(0)
|
||||
options_file = frame.f_code.co_filename
|
||||
|
||||
# Can be called directly, or through top level define() fn, in which
|
||||
# case, step up above that frame to look for real caller.
|
||||
if (frame.f_back.f_code.co_filename == options_file and
|
||||
frame.f_back.f_code.co_name == 'define'):
|
||||
frame = frame.f_back
|
||||
|
||||
file_name = frame.f_back.f_code.co_filename
|
||||
if file_name == options_file:
|
||||
file_name = ""
|
||||
|
@ -249,7 +256,7 @@ class OptionParser(object):
|
|||
arg = args[i].lstrip("-")
|
||||
name, equals, value = arg.partition("=")
|
||||
name = name.replace('-', '_')
|
||||
if not name in self._options:
|
||||
if name not in self._options:
|
||||
self.print_help()
|
||||
raise Error('Unrecognized command line option: %r' % name)
|
||||
option = self._options[name]
|
||||
|
|
|
@ -12,6 +12,8 @@ unfinished callbacks on the event loop that fail when it resumes)
|
|||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
import functools
|
||||
|
||||
import tornado.concurrent
|
||||
from tornado.gen import convert_yielded
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado import stack_context
|
||||
|
||||
|
@ -27,8 +29,10 @@ except ImportError as e:
|
|||
# Re-raise the original asyncio error, not the trollius one.
|
||||
raise e
|
||||
|
||||
|
||||
class BaseAsyncIOLoop(IOLoop):
|
||||
def initialize(self, asyncio_loop, close_loop=False):
|
||||
def initialize(self, asyncio_loop, close_loop=False, **kwargs):
|
||||
super(BaseAsyncIOLoop, self).initialize(**kwargs)
|
||||
self.asyncio_loop = asyncio_loop
|
||||
self.close_loop = close_loop
|
||||
self.asyncio_loop.call_soon(self.make_current)
|
||||
|
@ -129,12 +133,29 @@ class BaseAsyncIOLoop(IOLoop):
|
|||
|
||||
|
||||
class AsyncIOMainLoop(BaseAsyncIOLoop):
|
||||
def initialize(self):
|
||||
def initialize(self, **kwargs):
|
||||
super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(),
|
||||
close_loop=False)
|
||||
close_loop=False, **kwargs)
|
||||
|
||||
|
||||
class AsyncIOLoop(BaseAsyncIOLoop):
|
||||
def initialize(self):
|
||||
def initialize(self, **kwargs):
|
||||
super(AsyncIOLoop, self).initialize(asyncio.new_event_loop(),
|
||||
close_loop=True)
|
||||
close_loop=True, **kwargs)
|
||||
|
||||
|
||||
def to_tornado_future(asyncio_future):
|
||||
"""Convert an ``asyncio.Future`` to a `tornado.concurrent.Future`."""
|
||||
tf = tornado.concurrent.Future()
|
||||
tornado.concurrent.chain_future(asyncio_future, tf)
|
||||
return tf
|
||||
|
||||
|
||||
def to_asyncio_future(tornado_future):
|
||||
"""Convert a `tornado.concurrent.Future` to an ``asyncio.Future``."""
|
||||
af = asyncio.Future()
|
||||
tornado.concurrent.chain_future(tornado_future, af)
|
||||
return af
|
||||
|
||||
if hasattr(convert_yielded, 'register'):
|
||||
convert_yielded.register(asyncio.Future, to_tornado_future)
|
||||
|
|
|
@ -27,13 +27,14 @@ from __future__ import absolute_import, division, print_function, with_statement
|
|||
|
||||
import os
|
||||
|
||||
if os.name == 'nt':
|
||||
from tornado.platform.common import Waker
|
||||
from tornado.platform.windows import set_close_exec
|
||||
elif 'APPENGINE_RUNTIME' in os.environ:
|
||||
if 'APPENGINE_RUNTIME' in os.environ:
|
||||
from tornado.platform.common import Waker
|
||||
|
||||
def set_close_exec(fd):
|
||||
pass
|
||||
elif os.name == 'nt':
|
||||
from tornado.platform.common import Waker
|
||||
from tornado.platform.windows import set_close_exec
|
||||
else:
|
||||
from tornado.platform.posix import set_close_exec, Waker
|
||||
|
||||
|
@ -41,9 +42,13 @@ try:
|
|||
# monotime monkey-patches the time module to have a monotonic function
|
||||
# in versions of python before 3.3.
|
||||
import monotime
|
||||
# Silence pyflakes warning about this unused import
|
||||
monotime
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from time import monotonic as monotonic_time
|
||||
except ImportError:
|
||||
monotonic_time = None
|
||||
|
||||
__all__ = ['Waker', 'set_close_exec', 'monotonic_time']
|
||||
|
|
|
@ -18,6 +18,9 @@ class CaresResolver(Resolver):
|
|||
so it is only recommended for use in ``AF_INET`` (i.e. IPv4). This is
|
||||
the default for ``tornado.simple_httpclient``, but other libraries
|
||||
may default to ``AF_UNSPEC``.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
The ``io_loop`` argument is deprecated.
|
||||
"""
|
||||
def initialize(self, io_loop=None):
|
||||
self.io_loop = io_loop or IOLoop.current()
|
||||
|
|
|
@ -54,8 +54,7 @@ class _KQueue(object):
|
|||
if events & IOLoop.WRITE:
|
||||
kevents.append(select.kevent(
|
||||
fd, filter=select.KQ_FILTER_WRITE, flags=flags))
|
||||
if events & IOLoop.READ or not kevents:
|
||||
# Always read when there is not a write
|
||||
if events & IOLoop.READ:
|
||||
kevents.append(select.kevent(
|
||||
fd, filter=select.KQ_FILTER_READ, flags=flags))
|
||||
# Even though control() takes a list, it seems to return EINVAL
|
||||
|
|
|
@ -47,7 +47,7 @@ class _Select(object):
|
|||
# Closed connections are reported as errors by epoll and kqueue,
|
||||
# but as zero-byte reads by select, so when errors are requested
|
||||
# we need to listen for both read and error.
|
||||
self.read_fds.add(fd)
|
||||
# self.read_fds.add(fd)
|
||||
|
||||
def modify(self, fd, events):
|
||||
self.unregister(fd)
|
||||
|
|
|
@ -35,7 +35,7 @@ of the application::
|
|||
tornado.platform.twisted.install()
|
||||
from twisted.internet import reactor
|
||||
|
||||
When the app is ready to start, call `IOLoop.instance().start()`
|
||||
When the app is ready to start, call `IOLoop.current().start()`
|
||||
instead of `reactor.run()`.
|
||||
|
||||
It is also possible to create a non-global reactor by calling
|
||||
|
@ -70,8 +70,10 @@ import datetime
|
|||
import functools
|
||||
import numbers
|
||||
import socket
|
||||
import sys
|
||||
|
||||
import twisted.internet.abstract
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.posixbase import PosixReactorBase
|
||||
from twisted.internet.interfaces import \
|
||||
IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor
|
||||
|
@ -84,6 +86,7 @@ import twisted.names.resolve
|
|||
|
||||
from zope.interface import implementer
|
||||
|
||||
from tornado.concurrent import Future
|
||||
from tornado.escape import utf8
|
||||
from tornado import gen
|
||||
import tornado.ioloop
|
||||
|
@ -147,6 +150,9 @@ class TornadoReactor(PosixReactorBase):
|
|||
We override `mainLoop` instead of `doIteration` and must implement
|
||||
timed call functionality on top of `IOLoop.add_timeout` rather than
|
||||
using the implementation in `PosixReactorBase`.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
The ``io_loop`` argument is deprecated.
|
||||
"""
|
||||
def __init__(self, io_loop=None):
|
||||
if not io_loop:
|
||||
|
@ -356,7 +362,11 @@ class _TestReactor(TornadoReactor):
|
|||
|
||||
|
||||
def install(io_loop=None):
|
||||
"""Install this package as the default Twisted reactor."""
|
||||
"""Install this package as the default Twisted reactor.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
The ``io_loop`` argument is deprecated.
|
||||
"""
|
||||
if not io_loop:
|
||||
io_loop = tornado.ioloop.IOLoop.current()
|
||||
reactor = TornadoReactor(io_loop)
|
||||
|
@ -406,7 +416,8 @@ class TwistedIOLoop(tornado.ioloop.IOLoop):
|
|||
because the ``SIGCHLD`` handlers used by Tornado and Twisted conflict
|
||||
with each other.
|
||||
"""
|
||||
def initialize(self, reactor=None):
|
||||
def initialize(self, reactor=None, **kwargs):
|
||||
super(TwistedIOLoop, self).initialize(**kwargs)
|
||||
if reactor is None:
|
||||
import twisted.internet.reactor
|
||||
reactor = twisted.internet.reactor
|
||||
|
@ -512,6 +523,9 @@ class TwistedResolver(Resolver):
|
|||
``socket.AF_UNSPEC``.
|
||||
|
||||
Requires Twisted 12.1 or newer.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
The ``io_loop`` argument is deprecated.
|
||||
"""
|
||||
def initialize(self, io_loop=None):
|
||||
self.io_loop = io_loop or IOLoop.current()
|
||||
|
@ -554,3 +568,18 @@ class TwistedResolver(Resolver):
|
|||
(resolved_family, (resolved, port)),
|
||||
]
|
||||
raise gen.Return(result)
|
||||
|
||||
if hasattr(gen.convert_yielded, 'register'):
|
||||
@gen.convert_yielded.register(Deferred)
|
||||
def _(d):
|
||||
f = Future()
|
||||
|
||||
def errback(failure):
|
||||
try:
|
||||
failure.raiseException()
|
||||
# Should never happen, but just in case
|
||||
raise Exception("errback called without error")
|
||||
except:
|
||||
f.set_exc_info(sys.exc_info())
|
||||
d.addCallbacks(f.set_result, errback)
|
||||
return f
|
||||
|
|
|
@ -29,6 +29,7 @@ import time
|
|||
|
||||
from binascii import hexlify
|
||||
|
||||
from tornado.concurrent import Future
|
||||
from tornado import ioloop
|
||||
from tornado.iostream import PipeIOStream
|
||||
from tornado.log import gen_log
|
||||
|
@ -48,6 +49,10 @@ except NameError:
|
|||
long = int # py3
|
||||
|
||||
|
||||
# Re-export this exception for convenience.
|
||||
CalledProcessError = subprocess.CalledProcessError
|
||||
|
||||
|
||||
def cpu_count():
|
||||
"""Returns the number of processors on this machine."""
|
||||
if multiprocessing is None:
|
||||
|
@ -191,6 +196,9 @@ class Subprocess(object):
|
|||
``tornado.process.Subprocess.STREAM``, which will make the corresponding
|
||||
attribute of the resulting Subprocess a `.PipeIOStream`.
|
||||
* A new keyword argument ``io_loop`` may be used to pass in an IOLoop.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
The ``io_loop`` argument is deprecated.
|
||||
"""
|
||||
STREAM = object()
|
||||
|
||||
|
@ -255,6 +263,33 @@ class Subprocess(object):
|
|||
Subprocess._waiting[self.pid] = self
|
||||
Subprocess._try_cleanup_process(self.pid)
|
||||
|
||||
def wait_for_exit(self, raise_error=True):
|
||||
"""Returns a `.Future` which resolves when the process exits.
|
||||
|
||||
Usage::
|
||||
|
||||
ret = yield proc.wait_for_exit()
|
||||
|
||||
This is a coroutine-friendly alternative to `set_exit_callback`
|
||||
(and a replacement for the blocking `subprocess.Popen.wait`).
|
||||
|
||||
By default, raises `subprocess.CalledProcessError` if the process
|
||||
has a non-zero exit status. Use ``wait_for_exit(raise_error=False)``
|
||||
to suppress this behavior and return the exit status without raising.
|
||||
|
||||
.. versionadded:: 4.2
|
||||
"""
|
||||
future = Future()
|
||||
|
||||
def callback(ret):
|
||||
if ret != 0 and raise_error:
|
||||
# Unfortunately we don't have the original args any more.
|
||||
future.set_exception(CalledProcessError(ret, None))
|
||||
else:
|
||||
future.set_result(ret)
|
||||
self.set_exit_callback(callback)
|
||||
return future
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, io_loop=None):
|
||||
"""Initializes the ``SIGCHLD`` handler.
|
||||
|
@ -263,6 +298,9 @@ class Subprocess(object):
|
|||
Note that the `.IOLoop` used for signal handling need not be the
|
||||
same one used by individual Subprocess objects (as long as the
|
||||
``IOLoops`` are each running in separate threads).
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
The ``io_loop`` argument is deprecated.
|
||||
"""
|
||||
if cls._initialized:
|
||||
return
|
||||
|
|
321
tornado/queues.py
Normal file
321
tornado/queues.py
Normal file
|
@ -0,0 +1,321 @@
|
|||
# Copyright 2015 The Tornado Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
|
||||
__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
|
||||
|
||||
import collections
|
||||
import heapq
|
||||
|
||||
from tornado import gen, ioloop
|
||||
from tornado.concurrent import Future
|
||||
from tornado.locks import Event
|
||||
|
||||
|
||||
class QueueEmpty(Exception):
|
||||
"""Raised by `.Queue.get_nowait` when the queue has no items."""
|
||||
pass
|
||||
|
||||
|
||||
class QueueFull(Exception):
|
||||
"""Raised by `.Queue.put_nowait` when a queue is at its maximum size."""
|
||||
pass
|
||||
|
||||
|
||||
def _set_timeout(future, timeout):
|
||||
if timeout:
|
||||
def on_timeout():
|
||||
future.set_exception(gen.TimeoutError())
|
||||
io_loop = ioloop.IOLoop.current()
|
||||
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
|
||||
future.add_done_callback(
|
||||
lambda _: io_loop.remove_timeout(timeout_handle))
|
||||
|
||||
|
||||
class Queue(object):
|
||||
"""Coordinate producer and consumer coroutines.
|
||||
|
||||
If maxsize is 0 (the default) the queue size is unbounded.
|
||||
|
||||
.. testcode::
|
||||
|
||||
q = queues.Queue(maxsize=2)
|
||||
|
||||
@gen.coroutine
|
||||
def consumer():
|
||||
while True:
|
||||
item = yield q.get()
|
||||
try:
|
||||
print('Doing work on %s' % item)
|
||||
yield gen.sleep(0.01)
|
||||
finally:
|
||||
q.task_done()
|
||||
|
||||
@gen.coroutine
|
||||
def producer():
|
||||
for item in range(5):
|
||||
yield q.put(item)
|
||||
print('Put %s' % item)
|
||||
|
||||
@gen.coroutine
|
||||
def main():
|
||||
consumer() # Start consumer.
|
||||
yield producer() # Wait for producer to put all tasks.
|
||||
yield q.join() # Wait for consumer to finish all tasks.
|
||||
print('Done')
|
||||
|
||||
io_loop.run_sync(main)
|
||||
|
||||
.. testoutput::
|
||||
|
||||
Put 0
|
||||
Put 1
|
||||
Put 2
|
||||
Doing work on 0
|
||||
Doing work on 1
|
||||
Put 3
|
||||
Doing work on 2
|
||||
Put 4
|
||||
Doing work on 3
|
||||
Doing work on 4
|
||||
Done
|
||||
"""
|
||||
def __init__(self, maxsize=0):
|
||||
if maxsize is None:
|
||||
raise TypeError("maxsize can't be None")
|
||||
|
||||
if maxsize < 0:
|
||||
raise ValueError("maxsize can't be negative")
|
||||
|
||||
self._maxsize = maxsize
|
||||
self._init()
|
||||
self._getters = collections.deque([]) # Futures.
|
||||
self._putters = collections.deque([]) # Pairs of (item, Future).
|
||||
self._unfinished_tasks = 0
|
||||
self._finished = Event()
|
||||
self._finished.set()
|
||||
|
||||
@property
|
||||
def maxsize(self):
|
||||
"""Number of items allowed in the queue."""
|
||||
return self._maxsize
|
||||
|
||||
def qsize(self):
|
||||
"""Number of items in the queue."""
|
||||
return len(self._queue)
|
||||
|
||||
def empty(self):
|
||||
return not self._queue
|
||||
|
||||
def full(self):
|
||||
if self.maxsize == 0:
|
||||
return False
|
||||
else:
|
||||
return self.qsize() >= self.maxsize
|
||||
|
||||
def put(self, item, timeout=None):
|
||||
"""Put an item into the queue, perhaps waiting until there is room.
|
||||
|
||||
Returns a Future, which raises `tornado.gen.TimeoutError` after a
|
||||
timeout.
|
||||
"""
|
||||
try:
|
||||
self.put_nowait(item)
|
||||
except QueueFull:
|
||||
future = Future()
|
||||
self._putters.append((item, future))
|
||||
_set_timeout(future, timeout)
|
||||
return future
|
||||
else:
|
||||
return gen._null_future
|
||||
|
||||
def put_nowait(self, item):
|
||||
"""Put an item into the queue without blocking.
|
||||
|
||||
If no free slot is immediately available, raise `QueueFull`.
|
||||
"""
|
||||
self._consume_expired()
|
||||
if self._getters:
|
||||
assert self.empty(), "queue non-empty, why are getters waiting?"
|
||||
getter = self._getters.popleft()
|
||||
self.__put_internal(item)
|
||||
getter.set_result(self._get())
|
||||
elif self.full():
|
||||
raise QueueFull
|
||||
else:
|
||||
self.__put_internal(item)
|
||||
|
||||
def get(self, timeout=None):
|
||||
"""Remove and return an item from the queue.
|
||||
|
||||
Returns a Future which resolves once an item is available, or raises
|
||||
`tornado.gen.TimeoutError` after a timeout.
|
||||
"""
|
||||
future = Future()
|
||||
try:
|
||||
future.set_result(self.get_nowait())
|
||||
except QueueEmpty:
|
||||
self._getters.append(future)
|
||||
_set_timeout(future, timeout)
|
||||
return future
|
||||
|
||||
def get_nowait(self):
|
||||
"""Remove and return an item from the queue without blocking.
|
||||
|
||||
Return an item if one is immediately available, else raise
|
||||
`QueueEmpty`.
|
||||
"""
|
||||
self._consume_expired()
|
||||
if self._putters:
|
||||
assert self.full(), "queue not full, why are putters waiting?"
|
||||
item, putter = self._putters.popleft()
|
||||
self.__put_internal(item)
|
||||
putter.set_result(None)
|
||||
return self._get()
|
||||
elif self.qsize():
|
||||
return self._get()
|
||||
else:
|
||||
raise QueueEmpty
|
||||
|
||||
def task_done(self):
|
||||
"""Indicate that a formerly enqueued task is complete.
|
||||
|
||||
Used by queue consumers. For each `.get` used to fetch a task, a
|
||||
subsequent call to `.task_done` tells the queue that the processing
|
||||
on the task is complete.
|
||||
|
||||
If a `.join` is blocking, it resumes when all items have been
|
||||
processed; that is, when every `.put` is matched by a `.task_done`.
|
||||
|
||||
Raises `ValueError` if called more times than `.put`.
|
||||
"""
|
||||
if self._unfinished_tasks <= 0:
|
||||
raise ValueError('task_done() called too many times')
|
||||
self._unfinished_tasks -= 1
|
||||
if self._unfinished_tasks == 0:
|
||||
self._finished.set()
|
||||
|
||||
def join(self, timeout=None):
|
||||
"""Block until all items in the queue are processed.
|
||||
|
||||
Returns a Future, which raises `tornado.gen.TimeoutError` after a
|
||||
timeout.
|
||||
"""
|
||||
return self._finished.wait(timeout)
|
||||
|
||||
# These three are overridable in subclasses.
|
||||
def _init(self):
|
||||
self._queue = collections.deque()
|
||||
|
||||
def _get(self):
|
||||
return self._queue.popleft()
|
||||
|
||||
def _put(self, item):
|
||||
self._queue.append(item)
|
||||
# End of the overridable methods.
|
||||
|
||||
def __put_internal(self, item):
|
||||
self._unfinished_tasks += 1
|
||||
self._finished.clear()
|
||||
self._put(item)
|
||||
|
||||
def _consume_expired(self):
|
||||
# Remove timed-out waiters.
|
||||
while self._putters and self._putters[0][1].done():
|
||||
self._putters.popleft()
|
||||
|
||||
while self._getters and self._getters[0].done():
|
||||
self._getters.popleft()
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s at %s %s>' % (
|
||||
type(self).__name__, hex(id(self)), self._format())
|
||||
|
||||
def __str__(self):
|
||||
return '<%s %s>' % (type(self).__name__, self._format())
|
||||
|
||||
def _format(self):
|
||||
result = 'maxsize=%r' % (self.maxsize, )
|
||||
if getattr(self, '_queue', None):
|
||||
result += ' queue=%r' % self._queue
|
||||
if self._getters:
|
||||
result += ' getters[%s]' % len(self._getters)
|
||||
if self._putters:
|
||||
result += ' putters[%s]' % len(self._putters)
|
||||
if self._unfinished_tasks:
|
||||
result += ' tasks=%s' % self._unfinished_tasks
|
||||
return result
|
||||
|
||||
|
||||
class PriorityQueue(Queue):
|
||||
"""A `.Queue` that retrieves entries in priority order, lowest first.
|
||||
|
||||
Entries are typically tuples like ``(priority number, data)``.
|
||||
|
||||
.. testcode::
|
||||
|
||||
q = queues.PriorityQueue()
|
||||
q.put((1, 'medium-priority item'))
|
||||
q.put((0, 'high-priority item'))
|
||||
q.put((10, 'low-priority item'))
|
||||
|
||||
print(q.get_nowait())
|
||||
print(q.get_nowait())
|
||||
print(q.get_nowait())
|
||||
|
||||
.. testoutput::
|
||||
|
||||
(0, 'high-priority item')
|
||||
(1, 'medium-priority item')
|
||||
(10, 'low-priority item')
|
||||
"""
|
||||
def _init(self):
|
||||
self._queue = []
|
||||
|
||||
def _put(self, item):
|
||||
heapq.heappush(self._queue, item)
|
||||
|
||||
def _get(self):
|
||||
return heapq.heappop(self._queue)
|
||||
|
||||
|
||||
class LifoQueue(Queue):
|
||||
"""A `.Queue` that retrieves the most recently put items first.
|
||||
|
||||
.. testcode::
|
||||
|
||||
q = queues.LifoQueue()
|
||||
q.put(3)
|
||||
q.put(2)
|
||||
q.put(1)
|
||||
|
||||
print(q.get_nowait())
|
||||
print(q.get_nowait())
|
||||
print(q.get_nowait())
|
||||
|
||||
.. testoutput::
|
||||
|
||||
1
|
||||
2
|
||||
3
|
||||
"""
|
||||
def _init(self):
|
||||
self._queue = []
|
||||
|
||||
def _put(self, item):
|
||||
self._queue.append(item)
|
||||
|
||||
def _get(self):
|
||||
return self._queue.pop()
|
|
@ -7,7 +7,7 @@ from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _
|
|||
from tornado import httputil
|
||||
from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters
|
||||
from tornado.iostream import StreamClosedError
|
||||
from tornado.netutil import Resolver, OverrideResolver
|
||||
from tornado.netutil import Resolver, OverrideResolver, _client_ssl_defaults
|
||||
from tornado.log import gen_log
|
||||
from tornado import stack_context
|
||||
from tornado.tcpclient import TCPClient
|
||||
|
@ -34,7 +34,7 @@ except ImportError:
|
|||
ssl = None
|
||||
|
||||
try:
|
||||
import lib.certifi
|
||||
import certifi
|
||||
except ImportError:
|
||||
certifi = None
|
||||
|
||||
|
@ -50,9 +50,6 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
|
|||
"""Non-blocking HTTP client with no external dependencies.
|
||||
|
||||
This class implements an HTTP 1.1 client on top of Tornado's IOStreams.
|
||||
It does not currently implement all applicable parts of the HTTP
|
||||
specification, but it does enough to work with major web service APIs.
|
||||
|
||||
Some features found in the curl-based AsyncHTTPClient are not yet
|
||||
supported. In particular, proxies are not supported, connections
|
||||
are not reused, and callers cannot select the network interface to be
|
||||
|
@ -60,25 +57,39 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
|
|||
"""
|
||||
def initialize(self, io_loop, max_clients=10,
|
||||
hostname_mapping=None, max_buffer_size=104857600,
|
||||
resolver=None, defaults=None, max_header_size=None):
|
||||
resolver=None, defaults=None, max_header_size=None,
|
||||
max_body_size=None):
|
||||
"""Creates a AsyncHTTPClient.
|
||||
|
||||
Only a single AsyncHTTPClient instance exists per IOLoop
|
||||
in order to provide limitations on the number of pending connections.
|
||||
force_instance=True may be used to suppress this behavior.
|
||||
``force_instance=True`` may be used to suppress this behavior.
|
||||
|
||||
max_clients is the number of concurrent requests that can be
|
||||
in progress. Note that this arguments are only used when the
|
||||
client is first created, and will be ignored when an existing
|
||||
client is reused.
|
||||
Note that because of this implicit reuse, unless ``force_instance``
|
||||
is used, only the first call to the constructor actually uses
|
||||
its arguments. It is recommended to use the ``configure`` method
|
||||
instead of the constructor to ensure that arguments take effect.
|
||||
|
||||
hostname_mapping is a dictionary mapping hostnames to IP addresses.
|
||||
``max_clients`` is the number of concurrent requests that can be
|
||||
in progress; when this limit is reached additional requests will be
|
||||
queued. Note that time spent waiting in this queue still counts
|
||||
against the ``request_timeout``.
|
||||
|
||||
``hostname_mapping`` is a dictionary mapping hostnames to IP addresses.
|
||||
It can be used to make local DNS changes when modifying system-wide
|
||||
settings like /etc/hosts is not possible or desirable (e.g. in
|
||||
settings like ``/etc/hosts`` is not possible or desirable (e.g. in
|
||||
unittests).
|
||||
|
||||
max_buffer_size is the number of bytes that can be read by IOStream. It
|
||||
defaults to 100mb.
|
||||
``max_buffer_size`` (default 100MB) is the number of bytes
|
||||
that can be read into memory at once. ``max_body_size``
|
||||
(defaults to ``max_buffer_size``) is the largest response body
|
||||
that the client will accept. Without a
|
||||
``streaming_callback``, the smaller of these two limits
|
||||
applies; with a ``streaming_callback`` only ``max_body_size``
|
||||
does.
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
Added the ``max_body_size`` argument.
|
||||
"""
|
||||
super(SimpleAsyncHTTPClient, self).initialize(io_loop,
|
||||
defaults=defaults)
|
||||
|
@ -88,6 +99,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
|
|||
self.waiting = {}
|
||||
self.max_buffer_size = max_buffer_size
|
||||
self.max_header_size = max_header_size
|
||||
self.max_body_size = max_body_size
|
||||
# TCPClient could create a Resolver for us, but we have to do it
|
||||
# ourselves to support hostname_mapping.
|
||||
if resolver:
|
||||
|
@ -135,10 +147,14 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
|
|||
release_callback = functools.partial(self._release_fetch, key)
|
||||
self._handle_request(request, release_callback, callback)
|
||||
|
||||
def _connection_class(self):
|
||||
return _HTTPConnection
|
||||
|
||||
def _handle_request(self, request, release_callback, final_callback):
|
||||
_HTTPConnection(self.io_loop, self, request, release_callback,
|
||||
self._connection_class()(
|
||||
self.io_loop, self, request, release_callback,
|
||||
final_callback, self.max_buffer_size, self.tcp_client,
|
||||
self.max_header_size)
|
||||
self.max_header_size, self.max_body_size)
|
||||
|
||||
def _release_fetch(self, key):
|
||||
del self.active[key]
|
||||
|
@ -166,7 +182,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
|
|||
|
||||
def __init__(self, io_loop, client, request, release_callback,
|
||||
final_callback, max_buffer_size, tcp_client,
|
||||
max_header_size):
|
||||
max_header_size, max_body_size):
|
||||
self.start_time = io_loop.time()
|
||||
self.io_loop = io_loop
|
||||
self.client = client
|
||||
|
@ -176,6 +192,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
|
|||
self.max_buffer_size = max_buffer_size
|
||||
self.tcp_client = tcp_client
|
||||
self.max_header_size = max_header_size
|
||||
self.max_body_size = max_body_size
|
||||
self.code = None
|
||||
self.headers = None
|
||||
self.chunks = []
|
||||
|
@ -193,12 +210,8 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
|
|||
netloc = self.parsed.netloc
|
||||
if "@" in netloc:
|
||||
userpass, _, netloc = netloc.rpartition("@")
|
||||
match = re.match(r'^(.+):(\d+)$', netloc)
|
||||
if match:
|
||||
host = match.group(1)
|
||||
port = int(match.group(2))
|
||||
else:
|
||||
host = netloc
|
||||
host, port = httputil.split_host_and_port(netloc)
|
||||
if port is None:
|
||||
port = 443 if self.parsed.scheme == "https" else 80
|
||||
if re.match(r'^\[.*\]$', host):
|
||||
# raw ipv6 addresses in urls are enclosed in brackets
|
||||
|
@ -224,12 +237,24 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
|
|||
|
||||
def _get_ssl_options(self, scheme):
|
||||
if scheme == "https":
|
||||
if self.request.ssl_options is not None:
|
||||
return self.request.ssl_options
|
||||
# If we are using the defaults, don't construct a
|
||||
# new SSLContext.
|
||||
if (self.request.validate_cert and
|
||||
self.request.ca_certs is None and
|
||||
self.request.client_cert is None and
|
||||
self.request.client_key is None):
|
||||
return _client_ssl_defaults
|
||||
ssl_options = {}
|
||||
if self.request.validate_cert:
|
||||
ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
|
||||
if self.request.ca_certs is not None:
|
||||
ssl_options["ca_certs"] = self.request.ca_certs
|
||||
else:
|
||||
elif not hasattr(ssl, 'create_default_context'):
|
||||
# When create_default_context is present,
|
||||
# we can omit the "ca_certs" parameter entirely,
|
||||
# which avoids the dependency on "certifi" for py34.
|
||||
ssl_options["ca_certs"] = _default_ca_certs()
|
||||
if self.request.client_key is not None:
|
||||
ssl_options["keyfile"] = self.request.client_key
|
||||
|
@ -323,7 +348,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
|
|||
if ((body_expected and not body_present) or
|
||||
(body_present and not body_expected)):
|
||||
raise ValueError(
|
||||
'Body must %sbe None for method %s (unelss '
|
||||
'Body must %sbe None for method %s (unless '
|
||||
'allow_nonstandard_methods is true)' %
|
||||
('not ' if body_expected else '', self.request.method))
|
||||
if self.request.expect_100_continue:
|
||||
|
@ -340,26 +365,30 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
|
|||
self.request.headers["Accept-Encoding"] = "gzip"
|
||||
req_path = ((self.parsed.path or '/') +
|
||||
(('?' + self.parsed.query) if self.parsed.query else ''))
|
||||
self.stream.set_nodelay(True)
|
||||
self.connection = HTTP1Connection(
|
||||
self.stream, True,
|
||||
HTTP1ConnectionParameters(
|
||||
no_keep_alive=True,
|
||||
max_header_size=self.max_header_size,
|
||||
decompress=self.request.decompress_response),
|
||||
self._sockaddr)
|
||||
self.connection = self._create_connection(stream)
|
||||
start_line = httputil.RequestStartLine(self.request.method,
|
||||
req_path, 'HTTP/1.1')
|
||||
req_path, '')
|
||||
self.connection.write_headers(start_line, self.request.headers)
|
||||
if self.request.expect_100_continue:
|
||||
self._read_response()
|
||||
else:
|
||||
self._write_body(True)
|
||||
|
||||
def _create_connection(self, stream):
|
||||
stream.set_nodelay(True)
|
||||
connection = HTTP1Connection(
|
||||
stream, True,
|
||||
HTTP1ConnectionParameters(
|
||||
no_keep_alive=True,
|
||||
max_header_size=self.max_header_size,
|
||||
max_body_size=self.max_body_size,
|
||||
decompress=self.request.decompress_response),
|
||||
self._sockaddr)
|
||||
return connection
|
||||
|
||||
def _write_body(self, start_read):
|
||||
if self.request.body is not None:
|
||||
self.connection.write(self.request.body)
|
||||
self.connection.finish()
|
||||
elif self.request.body_producer is not None:
|
||||
fut = self.request.body_producer(self.connection.write)
|
||||
if is_future(fut):
|
||||
|
|
|
@ -111,6 +111,7 @@ class _Connector(object):
|
|||
if self.timeout is not None:
|
||||
# If the first attempt failed, don't wait for the
|
||||
# timeout to try an address from the secondary queue.
|
||||
self.io_loop.remove_timeout(self.timeout)
|
||||
self.on_timeout()
|
||||
return
|
||||
self.clear_timeout()
|
||||
|
@ -135,6 +136,9 @@ class _Connector(object):
|
|||
|
||||
class TCPClient(object):
|
||||
"""A non-blocking TCP connection factory.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
The ``io_loop`` argument is deprecated.
|
||||
"""
|
||||
def __init__(self, resolver=None, io_loop=None):
|
||||
self.io_loop = io_loop or IOLoop.current()
|
||||
|
|
|
@ -41,14 +41,15 @@ class TCPServer(object):
|
|||
To use `TCPServer`, define a subclass which overrides the `handle_stream`
|
||||
method.
|
||||
|
||||
To make this server serve SSL traffic, send the ssl_options dictionary
|
||||
argument with the arguments required for the `ssl.wrap_socket` method,
|
||||
including "certfile" and "keyfile"::
|
||||
To make this server serve SSL traffic, send the ``ssl_options`` keyword
|
||||
argument with an `ssl.SSLContext` object. For compatibility with older
|
||||
versions of Python ``ssl_options`` may also be a dictionary of keyword
|
||||
arguments for the `ssl.wrap_socket` method.::
|
||||
|
||||
TCPServer(ssl_options={
|
||||
"certfile": os.path.join(data_dir, "mydomain.crt"),
|
||||
"keyfile": os.path.join(data_dir, "mydomain.key"),
|
||||
})
|
||||
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"),
|
||||
os.path.join(data_dir, "mydomain.key"))
|
||||
TCPServer(ssl_options=ssl_ctx)
|
||||
|
||||
`TCPServer` initialization follows one of three patterns:
|
||||
|
||||
|
@ -56,14 +57,14 @@ class TCPServer(object):
|
|||
|
||||
server = TCPServer()
|
||||
server.listen(8888)
|
||||
IOLoop.instance().start()
|
||||
IOLoop.current().start()
|
||||
|
||||
2. `bind`/`start`: simple multi-process::
|
||||
|
||||
server = TCPServer()
|
||||
server.bind(8888)
|
||||
server.start(0) # Forks multiple sub-processes
|
||||
IOLoop.instance().start()
|
||||
IOLoop.current().start()
|
||||
|
||||
When using this interface, an `.IOLoop` must *not* be passed
|
||||
to the `TCPServer` constructor. `start` will always start
|
||||
|
@ -75,7 +76,7 @@ class TCPServer(object):
|
|||
tornado.process.fork_processes(0)
|
||||
server = TCPServer()
|
||||
server.add_sockets(sockets)
|
||||
IOLoop.instance().start()
|
||||
IOLoop.current().start()
|
||||
|
||||
The `add_sockets` interface is more complicated, but it can be
|
||||
used with `tornado.process.fork_processes` to give you more
|
||||
|
@ -95,7 +96,7 @@ class TCPServer(object):
|
|||
self._pending_sockets = []
|
||||
self._started = False
|
||||
self.max_buffer_size = max_buffer_size
|
||||
self.read_chunk_size = None
|
||||
self.read_chunk_size = read_chunk_size
|
||||
|
||||
# Verify the SSL options. Otherwise we don't get errors until clients
|
||||
# connect. This doesn't verify that the keys are legitimate, but
|
||||
|
@ -212,7 +213,20 @@ class TCPServer(object):
|
|||
sock.close()
|
||||
|
||||
def handle_stream(self, stream, address):
|
||||
"""Override to handle a new `.IOStream` from an incoming connection."""
|
||||
"""Override to handle a new `.IOStream` from an incoming connection.
|
||||
|
||||
This method may be a coroutine; if so any exceptions it raises
|
||||
asynchronously will be logged. Accepting of incoming connections
|
||||
will not be blocked by this coroutine.
|
||||
|
||||
If this `TCPServer` is configured for SSL, ``handle_stream``
|
||||
may be called before the SSL handshake has completed. Use
|
||||
`.SSLIOStream.wait_for_handshake` if you need to verify the client's
|
||||
certificate or use NPN/ALPN.
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
Added the option for this method to be a coroutine.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _handle_connection(self, connection, address):
|
||||
|
@ -252,6 +266,8 @@ class TCPServer(object):
|
|||
stream = IOStream(connection, io_loop=self.io_loop,
|
||||
max_buffer_size=self.max_buffer_size,
|
||||
read_chunk_size=self.read_chunk_size)
|
||||
self.handle_stream(stream, address)
|
||||
future = self.handle_stream(stream, address)
|
||||
if future is not None:
|
||||
self.io_loop.add_future(future, lambda f: f.result())
|
||||
except Exception:
|
||||
app_log.error("Error in connection callback", exc_info=True)
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
Test coverage is almost non-existent, but it's a start. Be sure to
|
||||
set PYTHONPATH appropriately (generally to the root directory of your
|
||||
tornado checkout) when running tests to make sure you're getting the
|
||||
version of the tornado package that you expect.
|
69
tornado/test/asyncio_test.py
Normal file
69
tornado/test/asyncio_test.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
|
||||
import sys
|
||||
import textwrap
|
||||
|
||||
from tornado import gen
|
||||
from tornado.testing import AsyncTestCase, gen_test
|
||||
from tornado.test.util import unittest
|
||||
|
||||
try:
|
||||
from tornado.platform.asyncio import asyncio, AsyncIOLoop
|
||||
except ImportError:
|
||||
asyncio = None
|
||||
|
||||
skipIfNoSingleDispatch = unittest.skipIf(
|
||||
gen.singledispatch is None, "singledispatch module not present")
|
||||
|
||||
|
||||
@unittest.skipIf(asyncio is None, "asyncio module not present")
|
||||
class AsyncIOLoopTest(AsyncTestCase):
|
||||
def get_new_ioloop(self):
|
||||
io_loop = AsyncIOLoop()
|
||||
asyncio.set_event_loop(io_loop.asyncio_loop)
|
||||
return io_loop
|
||||
|
||||
def test_asyncio_callback(self):
|
||||
# Basic test that the asyncio loop is set up correctly.
|
||||
asyncio.get_event_loop().call_soon(self.stop)
|
||||
self.wait()
|
||||
|
||||
@skipIfNoSingleDispatch
|
||||
@gen_test
|
||||
def test_asyncio_future(self):
|
||||
# Test that we can yield an asyncio future from a tornado coroutine.
|
||||
# Without 'yield from', we must wrap coroutines in asyncio.async.
|
||||
x = yield asyncio.async(
|
||||
asyncio.get_event_loop().run_in_executor(None, lambda: 42))
|
||||
self.assertEqual(x, 42)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 3),
|
||||
'PEP 380 not available')
|
||||
@skipIfNoSingleDispatch
|
||||
@gen_test
|
||||
def test_asyncio_yield_from(self):
|
||||
# Test that we can use asyncio coroutines with 'yield from'
|
||||
# instead of asyncio.async(). This requires python 3.3 syntax.
|
||||
global_namespace = dict(globals(), **locals())
|
||||
local_namespace = {}
|
||||
exec(textwrap.dedent("""
|
||||
@gen.coroutine
|
||||
def f():
|
||||
event_loop = asyncio.get_event_loop()
|
||||
x = yield from event_loop.run_in_executor(None, lambda: 42)
|
||||
return x
|
||||
"""), global_namespace, local_namespace)
|
||||
result = yield local_namespace['f']()
|
||||
self.assertEqual(result, 42)
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, GoogleMixin, AuthError
|
||||
from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, AuthError
|
||||
from tornado.concurrent import Future
|
||||
from tornado.escape import json_decode
|
||||
from tornado import gen
|
||||
|
@ -238,28 +238,6 @@ class TwitterServerVerifyCredentialsHandler(RequestHandler):
|
|||
self.write(dict(screen_name='foo', name='Foo'))
|
||||
|
||||
|
||||
class GoogleOpenIdClientLoginHandler(RequestHandler, GoogleMixin):
|
||||
def initialize(self, test):
|
||||
self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
|
||||
|
||||
@asynchronous
|
||||
def get(self):
|
||||
if self.get_argument("openid.mode", None):
|
||||
self.get_authenticated_user(self.on_user)
|
||||
return
|
||||
res = self.authenticate_redirect()
|
||||
assert isinstance(res, Future)
|
||||
assert res.done()
|
||||
|
||||
def on_user(self, user):
|
||||
if user is None:
|
||||
raise Exception("user is None")
|
||||
self.finish(user)
|
||||
|
||||
def get_auth_http_client(self):
|
||||
return self.settings['http_client']
|
||||
|
||||
|
||||
class AuthTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
return Application(
|
||||
|
@ -286,7 +264,6 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
('/twitter/client/login_gen_coroutine', TwitterClientLoginGenCoroutineHandler, dict(test=self)),
|
||||
('/twitter/client/show_user', TwitterClientShowUserHandler, dict(test=self)),
|
||||
('/twitter/client/show_user_future', TwitterClientShowUserFutureHandler, dict(test=self)),
|
||||
('/google/client/openid_login', GoogleOpenIdClientLoginHandler, dict(test=self)),
|
||||
|
||||
# simulated servers
|
||||
('/openid/server/authenticate', OpenIdServerAuthenticateHandler),
|
||||
|
@ -436,16 +413,3 @@ class AuthTest(AsyncHTTPTestCase):
|
|||
response = self.fetch('/twitter/client/show_user_future?name=error')
|
||||
self.assertEqual(response.code, 500)
|
||||
self.assertIn(b'Error response HTTP 500', response.body)
|
||||
|
||||
def test_google_redirect(self):
|
||||
# same as test_openid_redirect
|
||||
response = self.fetch('/google/client/openid_login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertTrue(
|
||||
'/openid/server/authenticate?' in response.headers['Location'])
|
||||
|
||||
def test_google_get_user(self):
|
||||
response = self.fetch('/google/client/openid_login?openid.mode=blah&openid.ns.ax=http://openid.net/srv/ax/1.0&openid.ax.type.email=http://axschema.org/contact/email&openid.ax.value.email=foo@example.com', follow_redirects=False)
|
||||
response.rethrow()
|
||||
parsed = json_decode(response.body)
|
||||
self.assertEqual(parsed["email"], "foo@example.com")
|
||||
|
|
|
@ -21,13 +21,14 @@ import socket
|
|||
import sys
|
||||
import traceback
|
||||
|
||||
from tornado.concurrent import Future, return_future, ReturnValueIgnoredError
|
||||
from tornado.concurrent import Future, return_future, ReturnValueIgnoredError, run_on_executor
|
||||
from tornado.escape import utf8, to_unicode
|
||||
from tornado import gen
|
||||
from tornado.iostream import IOStream
|
||||
from tornado import stack_context
|
||||
from tornado.tcpserver import TCPServer
|
||||
from tornado.testing import AsyncTestCase, LogTrapTestCase, bind_unused_port, gen_test
|
||||
from tornado.test.util import unittest
|
||||
|
||||
|
||||
try:
|
||||
|
@ -334,3 +335,81 @@ class DecoratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
|
|||
|
||||
class GeneratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
|
||||
client_class = GeneratorCapClient
|
||||
|
||||
|
||||
@unittest.skipIf(futures is None, "concurrent.futures module not present")
|
||||
class RunOnExecutorTest(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_no_calling(self):
|
||||
class Object(object):
|
||||
def __init__(self, io_loop):
|
||||
self.io_loop = io_loop
|
||||
self.executor = futures.thread.ThreadPoolExecutor(1)
|
||||
|
||||
@run_on_executor
|
||||
def f(self):
|
||||
return 42
|
||||
|
||||
o = Object(io_loop=self.io_loop)
|
||||
answer = yield o.f()
|
||||
self.assertEqual(answer, 42)
|
||||
|
||||
@gen_test
|
||||
def test_call_with_no_args(self):
|
||||
class Object(object):
|
||||
def __init__(self, io_loop):
|
||||
self.io_loop = io_loop
|
||||
self.executor = futures.thread.ThreadPoolExecutor(1)
|
||||
|
||||
@run_on_executor()
|
||||
def f(self):
|
||||
return 42
|
||||
|
||||
o = Object(io_loop=self.io_loop)
|
||||
answer = yield o.f()
|
||||
self.assertEqual(answer, 42)
|
||||
|
||||
@gen_test
|
||||
def test_call_with_io_loop(self):
|
||||
class Object(object):
|
||||
def __init__(self, io_loop):
|
||||
self._io_loop = io_loop
|
||||
self.executor = futures.thread.ThreadPoolExecutor(1)
|
||||
|
||||
@run_on_executor(io_loop='_io_loop')
|
||||
def f(self):
|
||||
return 42
|
||||
|
||||
o = Object(io_loop=self.io_loop)
|
||||
answer = yield o.f()
|
||||
self.assertEqual(answer, 42)
|
||||
|
||||
@gen_test
|
||||
def test_call_with_executor(self):
|
||||
class Object(object):
|
||||
def __init__(self, io_loop):
|
||||
self.io_loop = io_loop
|
||||
self.__executor = futures.thread.ThreadPoolExecutor(1)
|
||||
|
||||
@run_on_executor(executor='_Object__executor')
|
||||
def f(self):
|
||||
return 42
|
||||
|
||||
o = Object(io_loop=self.io_loop)
|
||||
answer = yield o.f()
|
||||
self.assertEqual(answer, 42)
|
||||
|
||||
@gen_test
|
||||
def test_call_with_both(self):
|
||||
class Object(object):
|
||||
def __init__(self, io_loop):
|
||||
self._io_loop = io_loop
|
||||
self.__executor = futures.thread.ThreadPoolExecutor(1)
|
||||
|
||||
@run_on_executor(io_loop='_io_loop', executor='_Object__executor')
|
||||
def f(self):
|
||||
return 42
|
||||
|
||||
o = Object(io_loop=self.io_loop)
|
||||
answer = yield o.f()
|
||||
self.assertEqual(answer, 42)
|
||||
|
|
|
@ -8,7 +8,7 @@ from tornado.stack_context import ExceptionStackContext
|
|||
from tornado.testing import AsyncHTTPTestCase
|
||||
from tornado.test import httpclient_test
|
||||
from tornado.test.util import unittest
|
||||
from tornado.web import Application, RequestHandler, URLSpec
|
||||
from tornado.web import Application, RequestHandler
|
||||
|
||||
|
||||
try:
|
||||
|
|
|
@ -217,8 +217,7 @@ class EscapeTestCase(unittest.TestCase):
|
|||
self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9")
|
||||
|
||||
def test_squeeze(self):
|
||||
self.assertEqual(squeeze(u('sequences of whitespace chars'))
|
||||
, u('sequences of whitespace chars'))
|
||||
self.assertEqual(squeeze(u('sequences of whitespace chars')), u('sequences of whitespace chars'))
|
||||
|
||||
def test_recursive_unicode(self):
|
||||
tests = {
|
||||
|
|
|
@ -62,6 +62,11 @@ class GenEngineTest(AsyncTestCase):
|
|||
def async_future(self, result, callback):
|
||||
self.io_loop.add_callback(callback, result)
|
||||
|
||||
@gen.coroutine
|
||||
def async_exception(self, e):
|
||||
yield gen.moment
|
||||
raise e
|
||||
|
||||
def test_no_yield(self):
|
||||
@gen.engine
|
||||
def f():
|
||||
|
@ -385,11 +390,56 @@ class GenEngineTest(AsyncTestCase):
|
|||
results = yield [self.async_future(1), self.async_future(2)]
|
||||
self.assertEqual(results, [1, 2])
|
||||
|
||||
@gen_test
|
||||
def test_multi_future_duplicate(self):
|
||||
f = self.async_future(2)
|
||||
results = yield [self.async_future(1), f, self.async_future(3), f]
|
||||
self.assertEqual(results, [1, 2, 3, 2])
|
||||
|
||||
@gen_test
|
||||
def test_multi_dict_future(self):
|
||||
results = yield dict(foo=self.async_future(1), bar=self.async_future(2))
|
||||
self.assertEqual(results, dict(foo=1, bar=2))
|
||||
|
||||
@gen_test
|
||||
def test_multi_exceptions(self):
|
||||
with ExpectLog(app_log, "Multiple exceptions in yield list"):
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
yield gen.Multi([self.async_exception(RuntimeError("error 1")),
|
||||
self.async_exception(RuntimeError("error 2"))])
|
||||
self.assertEqual(str(cm.exception), "error 1")
|
||||
|
||||
# With only one exception, no error is logged.
|
||||
with self.assertRaises(RuntimeError):
|
||||
yield gen.Multi([self.async_exception(RuntimeError("error 1")),
|
||||
self.async_future(2)])
|
||||
|
||||
# Exception logging may be explicitly quieted.
|
||||
with self.assertRaises(RuntimeError):
|
||||
yield gen.Multi([self.async_exception(RuntimeError("error 1")),
|
||||
self.async_exception(RuntimeError("error 2"))],
|
||||
quiet_exceptions=RuntimeError)
|
||||
|
||||
@gen_test
|
||||
def test_multi_future_exceptions(self):
|
||||
with ExpectLog(app_log, "Multiple exceptions in yield list"):
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
yield [self.async_exception(RuntimeError("error 1")),
|
||||
self.async_exception(RuntimeError("error 2"))]
|
||||
self.assertEqual(str(cm.exception), "error 1")
|
||||
|
||||
# With only one exception, no error is logged.
|
||||
with self.assertRaises(RuntimeError):
|
||||
yield [self.async_exception(RuntimeError("error 1")),
|
||||
self.async_future(2)]
|
||||
|
||||
# Exception logging may be explicitly quieted.
|
||||
with self.assertRaises(RuntimeError):
|
||||
yield gen.multi_future(
|
||||
[self.async_exception(RuntimeError("error 1")),
|
||||
self.async_exception(RuntimeError("error 2"))],
|
||||
quiet_exceptions=RuntimeError)
|
||||
|
||||
def test_arguments(self):
|
||||
@gen.engine
|
||||
def f():
|
||||
|
@ -816,6 +866,7 @@ class GenCoroutineTest(AsyncTestCase):
|
|||
@gen_test
|
||||
def test_moment(self):
|
||||
calls = []
|
||||
|
||||
@gen.coroutine
|
||||
def f(name, yieldable):
|
||||
for i in range(5):
|
||||
|
@ -838,6 +889,34 @@ class GenCoroutineTest(AsyncTestCase):
|
|||
yield [f('a', gen.moment), f('b', immediate)]
|
||||
self.assertEqual(''.join(calls), 'abbbbbaaaa')
|
||||
|
||||
@gen_test
|
||||
def test_sleep(self):
|
||||
yield gen.sleep(0.01)
|
||||
self.finished = True
|
||||
|
||||
@skipBefore33
|
||||
@gen_test
|
||||
def test_py3_leak_exception_context(self):
|
||||
class LeakedException(Exception):
|
||||
pass
|
||||
|
||||
@gen.coroutine
|
||||
def inner(iteration):
|
||||
raise LeakedException(iteration)
|
||||
|
||||
try:
|
||||
yield inner(1)
|
||||
except LeakedException as e:
|
||||
self.assertEqual(str(e), "1")
|
||||
self.assertIsNone(e.__context__)
|
||||
|
||||
try:
|
||||
yield inner(2)
|
||||
except LeakedException as e:
|
||||
self.assertEqual(str(e), "2")
|
||||
self.assertIsNone(e.__context__)
|
||||
|
||||
self.finished = True
|
||||
|
||||
class GenSequenceHandler(RequestHandler):
|
||||
@asynchronous
|
||||
|
@ -1031,7 +1110,7 @@ class WithTimeoutTest(AsyncTestCase):
|
|||
self.io_loop.add_timeout(datetime.timedelta(seconds=0.1),
|
||||
lambda: future.set_result('asdf'))
|
||||
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
|
||||
future)
|
||||
future, io_loop=self.io_loop)
|
||||
self.assertEqual(result, 'asdf')
|
||||
|
||||
@gen_test
|
||||
|
@ -1039,16 +1118,17 @@ class WithTimeoutTest(AsyncTestCase):
|
|||
future = Future()
|
||||
self.io_loop.add_timeout(
|
||||
datetime.timedelta(seconds=0.1),
|
||||
lambda: future.set_exception(ZeroDivisionError))
|
||||
lambda: future.set_exception(ZeroDivisionError()))
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
yield gen.with_timeout(datetime.timedelta(seconds=3600), future)
|
||||
yield gen.with_timeout(datetime.timedelta(seconds=3600),
|
||||
future, io_loop=self.io_loop)
|
||||
|
||||
@gen_test
|
||||
def test_already_resolved(self):
|
||||
future = Future()
|
||||
future.set_result('asdf')
|
||||
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
|
||||
future)
|
||||
future, io_loop=self.io_loop)
|
||||
self.assertEqual(result, 'asdf')
|
||||
|
||||
@unittest.skipIf(futures is None, 'futures module not present')
|
||||
|
@ -1067,5 +1147,117 @@ class WithTimeoutTest(AsyncTestCase):
|
|||
executor.submit(lambda: None))
|
||||
|
||||
|
||||
class WaitIteratorTest(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_empty_iterator(self):
|
||||
g = gen.WaitIterator()
|
||||
self.assertTrue(g.done(), 'empty generator iterated')
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
g = gen.WaitIterator(False, bar=False)
|
||||
|
||||
self.assertEqual(g.current_index, None, "bad nil current index")
|
||||
self.assertEqual(g.current_future, None, "bad nil current future")
|
||||
|
||||
@gen_test
|
||||
def test_already_done(self):
|
||||
f1 = Future()
|
||||
f2 = Future()
|
||||
f3 = Future()
|
||||
f1.set_result(24)
|
||||
f2.set_result(42)
|
||||
f3.set_result(84)
|
||||
|
||||
g = gen.WaitIterator(f1, f2, f3)
|
||||
i = 0
|
||||
while not g.done():
|
||||
r = yield g.next()
|
||||
# Order is not guaranteed, but the current implementation
|
||||
# preserves ordering of already-done Futures.
|
||||
if i == 0:
|
||||
self.assertEqual(g.current_index, 0)
|
||||
self.assertIs(g.current_future, f1)
|
||||
self.assertEqual(r, 24)
|
||||
elif i == 1:
|
||||
self.assertEqual(g.current_index, 1)
|
||||
self.assertIs(g.current_future, f2)
|
||||
self.assertEqual(r, 42)
|
||||
elif i == 2:
|
||||
self.assertEqual(g.current_index, 2)
|
||||
self.assertIs(g.current_future, f3)
|
||||
self.assertEqual(r, 84)
|
||||
i += 1
|
||||
|
||||
self.assertEqual(g.current_index, None, "bad nil current index")
|
||||
self.assertEqual(g.current_future, None, "bad nil current future")
|
||||
|
||||
dg = gen.WaitIterator(f1=f1, f2=f2)
|
||||
|
||||
while not dg.done():
|
||||
dr = yield dg.next()
|
||||
if dg.current_index == "f1":
|
||||
self.assertTrue(dg.current_future == f1 and dr == 24,
|
||||
"WaitIterator dict status incorrect")
|
||||
elif dg.current_index == "f2":
|
||||
self.assertTrue(dg.current_future == f2 and dr == 42,
|
||||
"WaitIterator dict status incorrect")
|
||||
else:
|
||||
self.fail("got bad WaitIterator index {}".format(
|
||||
dg.current_index))
|
||||
|
||||
i += 1
|
||||
|
||||
self.assertEqual(dg.current_index, None, "bad nil current index")
|
||||
self.assertEqual(dg.current_future, None, "bad nil current future")
|
||||
|
||||
def finish_coroutines(self, iteration, futures):
|
||||
if iteration == 3:
|
||||
futures[2].set_result(24)
|
||||
elif iteration == 5:
|
||||
futures[0].set_exception(ZeroDivisionError())
|
||||
elif iteration == 8:
|
||||
futures[1].set_result(42)
|
||||
futures[3].set_result(84)
|
||||
|
||||
if iteration < 8:
|
||||
self.io_loop.add_callback(self.finish_coroutines, iteration + 1, futures)
|
||||
|
||||
@gen_test
|
||||
def test_iterator(self):
|
||||
futures = [Future(), Future(), Future(), Future()]
|
||||
|
||||
self.finish_coroutines(0, futures)
|
||||
|
||||
g = gen.WaitIterator(*futures)
|
||||
|
||||
i = 0
|
||||
while not g.done():
|
||||
try:
|
||||
r = yield g.next()
|
||||
except ZeroDivisionError:
|
||||
self.assertIs(g.current_future, futures[0],
|
||||
'exception future invalid')
|
||||
else:
|
||||
if i == 0:
|
||||
self.assertEqual(r, 24, 'iterator value incorrect')
|
||||
self.assertEqual(g.current_index, 2, 'wrong index')
|
||||
elif i == 2:
|
||||
self.assertEqual(r, 42, 'iterator value incorrect')
|
||||
self.assertEqual(g.current_index, 1, 'wrong index')
|
||||
elif i == 3:
|
||||
self.assertEqual(r, 84, 'iterator value incorrect')
|
||||
self.assertEqual(g.current_index, 3, 'wrong index')
|
||||
i += 1
|
||||
|
||||
@gen_test
|
||||
def test_no_ref(self):
|
||||
# In this usage, there is no direct hard reference to the
|
||||
# WaitIterator itself, only the Future it returns. Since
|
||||
# WaitIterator uses weak references internally to improve GC
|
||||
# performance, this used to cause problems.
|
||||
yield gen.with_timeout(datetime.timedelta(seconds=0.1),
|
||||
gen.WaitIterator(gen.sleep(0)).next())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -1,11 +1,16 @@
|
|||
# flake8: noqa
|
||||
# Dummy source file to allow creation of the initial .po file in the
|
||||
# same way as a real project. I'm not entirely sure about the real
|
||||
# workflow here, but this seems to work.
|
||||
#
|
||||
# 1) xgettext --language=Python --keyword=_:1,2 -d tornado_test extract_me.py -o tornado_test.po
|
||||
# 2) Edit tornado_test.po, setting CHARSET and setting msgstr
|
||||
# 1) xgettext --language=Python --keyword=_:1,2 --keyword=pgettext:1c,2 --keyword=pgettext:1c,2,3 extract_me.py -o tornado_test.po
|
||||
# 2) Edit tornado_test.po, setting CHARSET, Plural-Forms and setting msgstr
|
||||
# 3) msgfmt tornado_test.po -o tornado_test.mo
|
||||
# 4) Put the file in the proper location: $LANG/LC_MESSAGES
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
_("school")
|
||||
pgettext("law", "right")
|
||||
pgettext("good", "right")
|
||||
pgettext("organization", "club", "clubs", 1)
|
||||
pgettext("stick", "club", "clubs", 1)
|
||||
|
|
Binary file not shown.
|
@ -8,7 +8,7 @@ msgid ""
|
|||
msgstr ""
|
||||
"Project-Id-Version: PACKAGE VERSION\n"
|
||||
"Report-Msgid-Bugs-To: \n"
|
||||
"POT-Creation-Date: 2012-06-14 01:10-0700\n"
|
||||
"POT-Creation-Date: 2015-01-27 11:05+0300\n"
|
||||
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||
"Language-Team: LANGUAGE <LL@li.org>\n"
|
||||
|
@ -16,7 +16,32 @@ msgstr ""
|
|||
"MIME-Version: 1.0\n"
|
||||
"Content-Type: text/plain; charset=utf-8\n"
|
||||
"Content-Transfer-Encoding: 8bit\n"
|
||||
"Plural-Forms: nplurals=2; plural=(n > 1);\n"
|
||||
|
||||
#: extract_me.py:1
|
||||
#: extract_me.py:11
|
||||
msgid "school"
|
||||
msgstr "école"
|
||||
|
||||
#: extract_me.py:12
|
||||
msgctxt "law"
|
||||
msgid "right"
|
||||
msgstr "le droit"
|
||||
|
||||
#: extract_me.py:13
|
||||
msgctxt "good"
|
||||
msgid "right"
|
||||
msgstr "le bien"
|
||||
|
||||
#: extract_me.py:14
|
||||
msgctxt "organization"
|
||||
msgid "club"
|
||||
msgid_plural "clubs"
|
||||
msgstr[0] "le club"
|
||||
msgstr[1] "les clubs"
|
||||
|
||||
#: extract_me.py:15
|
||||
msgctxt "stick"
|
||||
msgid "club"
|
||||
msgid_plural "clubs"
|
||||
msgstr[0] "le bâton"
|
||||
msgstr[1] "les bâtons"
|
||||
|
|
|
@ -12,6 +12,7 @@ import datetime
|
|||
from io import BytesIO
|
||||
|
||||
from tornado.escape import utf8
|
||||
from tornado import gen
|
||||
from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.ioloop import IOLoop
|
||||
|
@ -52,9 +53,12 @@ class RedirectHandler(RequestHandler):
|
|||
|
||||
|
||||
class ChunkHandler(RequestHandler):
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
self.write("asdf")
|
||||
self.flush()
|
||||
# Wait a bit to ensure the chunks are sent and received separately.
|
||||
yield gen.sleep(0.01)
|
||||
self.write("qwer")
|
||||
|
||||
|
||||
|
@ -178,6 +182,8 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
|
|||
sock, port = bind_unused_port()
|
||||
with closing(sock):
|
||||
def write_response(stream, request_data):
|
||||
if b"HTTP/1." not in request_data:
|
||||
self.skipTest("requires HTTP/1.x")
|
||||
stream.write(b"""\
|
||||
HTTP/1.1 200 OK
|
||||
Transfer-Encoding: chunked
|
||||
|
@ -300,23 +306,26 @@ Transfer-Encoding: chunked
|
|||
chunks = []
|
||||
|
||||
def header_callback(header_line):
|
||||
if header_line.startswith('HTTP/'):
|
||||
if header_line.startswith('HTTP/1.1 101'):
|
||||
# Upgrading to HTTP/2
|
||||
pass
|
||||
elif header_line.startswith('HTTP/'):
|
||||
first_line.append(header_line)
|
||||
elif header_line != '\r\n':
|
||||
k, v = header_line.split(':', 1)
|
||||
headers[k] = v.strip()
|
||||
headers[k.lower()] = v.strip()
|
||||
|
||||
def streaming_callback(chunk):
|
||||
# All header callbacks are run before any streaming callbacks,
|
||||
# so the header data is available to process the data as it
|
||||
# comes in.
|
||||
self.assertEqual(headers['Content-Type'], 'text/html; charset=UTF-8')
|
||||
self.assertEqual(headers['content-type'], 'text/html; charset=UTF-8')
|
||||
chunks.append(chunk)
|
||||
|
||||
self.fetch('/chunk', header_callback=header_callback,
|
||||
streaming_callback=streaming_callback)
|
||||
self.assertEqual(len(first_line), 1)
|
||||
self.assertRegexpMatches(first_line[0], 'HTTP/1.[01] 200 OK\r\n')
|
||||
self.assertEqual(len(first_line), 1, first_line)
|
||||
self.assertRegexpMatches(first_line[0], 'HTTP/[0-9]\\.[0-9] 200.*\r\n')
|
||||
self.assertEqual(chunks, [b'asdf', b'qwer'])
|
||||
|
||||
def test_header_callback_stack_context(self):
|
||||
|
@ -327,7 +336,7 @@ Transfer-Encoding: chunked
|
|||
return True
|
||||
|
||||
def header_callback(header_line):
|
||||
if header_line.startswith('Content-Type:'):
|
||||
if header_line.lower().startswith('content-type:'):
|
||||
1 / 0
|
||||
|
||||
with ExceptionStackContext(error_handler):
|
||||
|
@ -404,6 +413,11 @@ Transfer-Encoding: chunked
|
|||
self.assertEqual(context.exception.code, 404)
|
||||
self.assertEqual(context.exception.response.code, 404)
|
||||
|
||||
@gen_test
|
||||
def test_future_http_error_no_raise(self):
|
||||
response = yield self.http_client.fetch(self.get_url('/notfound'), raise_error=False)
|
||||
self.assertEqual(response.code, 404)
|
||||
|
||||
@gen_test
|
||||
def test_reuse_request_from_response(self):
|
||||
# The response.request attribute should be an HTTPRequest, not
|
||||
|
@ -454,7 +468,7 @@ Transfer-Encoding: chunked
|
|||
# Twisted's reactor does not. The removeReader call fails and so
|
||||
# do all future removeAll calls (which our tests do at cleanup).
|
||||
#
|
||||
#def test_post_307(self):
|
||||
# def test_post_307(self):
|
||||
# response = self.fetch("/redirect?status=307&url=/post",
|
||||
# method="POST", body=b"arg1=foo&arg2=bar")
|
||||
# self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
|
||||
|
@ -536,14 +550,19 @@ class SyncHTTPClientTest(unittest.TestCase):
|
|||
def tearDown(self):
|
||||
def stop_server():
|
||||
self.server.stop()
|
||||
self.server_ioloop.stop()
|
||||
# Delay the shutdown of the IOLoop by one iteration because
|
||||
# the server may still have some cleanup work left when
|
||||
# the client finishes with the response (this is noticable
|
||||
# with http/2, which leaves a Future with an unexamined
|
||||
# StreamClosedError on the loop).
|
||||
self.server_ioloop.add_callback(self.server_ioloop.stop)
|
||||
self.server_ioloop.add_callback(stop_server)
|
||||
self.server_thread.join()
|
||||
self.http_client.close()
|
||||
self.server_ioloop.close(all_fds=True)
|
||||
|
||||
def get_url(self, path):
|
||||
return 'http://localhost:%d%s' % (self.port, path)
|
||||
return 'http://127.0.0.1:%d%s' % (self.port, path)
|
||||
|
||||
def test_sync_client(self):
|
||||
response = self.http_client.fetch(self.get_url('/'))
|
||||
|
|
|
@ -32,6 +32,7 @@ def read_stream_body(stream, callback):
|
|||
"""Reads an HTTP response from `stream` and runs callback with its
|
||||
headers and body."""
|
||||
chunks = []
|
||||
|
||||
class Delegate(HTTPMessageDelegate):
|
||||
def headers_received(self, start_line, headers):
|
||||
self.headers = headers
|
||||
|
@ -161,11 +162,14 @@ class BadSSLOptionsTest(unittest.TestCase):
|
|||
application = Application()
|
||||
module_dir = os.path.dirname(__file__)
|
||||
existing_certificate = os.path.join(module_dir, 'test.crt')
|
||||
existing_key = os.path.join(module_dir, 'test.key')
|
||||
|
||||
self.assertRaises(ValueError, HTTPServer, application, ssl_options={
|
||||
self.assertRaises((ValueError, IOError),
|
||||
HTTPServer, application, ssl_options={
|
||||
"certfile": "/__mising__.crt",
|
||||
})
|
||||
self.assertRaises(ValueError, HTTPServer, application, ssl_options={
|
||||
self.assertRaises((ValueError, IOError),
|
||||
HTTPServer, application, ssl_options={
|
||||
"certfile": existing_certificate,
|
||||
"keyfile": "/__missing__.key"
|
||||
})
|
||||
|
@ -173,7 +177,7 @@ class BadSSLOptionsTest(unittest.TestCase):
|
|||
# This actually works because both files exist
|
||||
HTTPServer(application, ssl_options={
|
||||
"certfile": existing_certificate,
|
||||
"keyfile": existing_certificate
|
||||
"keyfile": existing_key,
|
||||
})
|
||||
|
||||
|
||||
|
@ -195,14 +199,14 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
|
|||
def get_app(self):
|
||||
return Application(self.get_handlers())
|
||||
|
||||
def raw_fetch(self, headers, body):
|
||||
def raw_fetch(self, headers, body, newline=b"\r\n"):
|
||||
with closing(IOStream(socket.socket())) as stream:
|
||||
stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
|
||||
self.wait()
|
||||
stream.write(
|
||||
b"\r\n".join(headers +
|
||||
[utf8("Content-Length: %d\r\n" % len(body))]) +
|
||||
b"\r\n" + body)
|
||||
newline.join(headers +
|
||||
[utf8("Content-Length: %d" % len(body))]) +
|
||||
newline + newline + body)
|
||||
read_stream_body(stream, self.stop)
|
||||
headers, body = self.wait()
|
||||
return body
|
||||
|
@ -232,12 +236,19 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
|
|||
self.assertEqual(u("\u00f3"), data["filename"])
|
||||
self.assertEqual(u("\u00fa"), data["filebody"])
|
||||
|
||||
def test_newlines(self):
|
||||
# We support both CRLF and bare LF as line separators.
|
||||
for newline in (b"\r\n", b"\n"):
|
||||
response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"",
|
||||
newline=newline)
|
||||
self.assertEqual(response, b'Hello world')
|
||||
|
||||
def test_100_continue(self):
|
||||
# Run through a 100-continue interaction by hand:
|
||||
# When given Expect: 100-continue, we get a 100 response after the
|
||||
# headers, and then the real response after the body.
|
||||
stream = IOStream(socket.socket(), io_loop=self.io_loop)
|
||||
stream.connect(("localhost", self.get_http_port()), callback=self.stop)
|
||||
stream.connect(("127.0.0.1", self.get_http_port()), callback=self.stop)
|
||||
self.wait()
|
||||
stream.write(b"\r\n".join([b"POST /hello HTTP/1.1",
|
||||
b"Content-Length: 1024",
|
||||
|
@ -374,7 +385,7 @@ class HTTPServerRawTest(AsyncHTTPTestCase):
|
|||
def setUp(self):
|
||||
super(HTTPServerRawTest, self).setUp()
|
||||
self.stream = IOStream(socket.socket())
|
||||
self.stream.connect(('localhost', self.get_http_port()), self.stop)
|
||||
self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
|
||||
self.wait()
|
||||
|
||||
def tearDown(self):
|
||||
|
@ -555,7 +566,7 @@ class UnixSocketTest(AsyncTestCase):
|
|||
self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
|
||||
self.stream.read_until(b"\r\n", self.stop)
|
||||
response = self.wait()
|
||||
self.assertEqual(response, b"HTTP/1.0 200 OK\r\n")
|
||||
self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
|
||||
self.stream.read_until(b"\r\n\r\n", self.stop)
|
||||
headers = HTTPHeaders.parse(self.wait().decode('latin1'))
|
||||
self.stream.read_bytes(int(headers["Content-Length"]), self.stop)
|
||||
|
@ -582,6 +593,7 @@ class KeepAliveTest(AsyncHTTPTestCase):
|
|||
class HelloHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.finish('Hello world')
|
||||
|
||||
def post(self):
|
||||
self.finish('Hello world')
|
||||
|
||||
|
@ -623,13 +635,13 @@ class KeepAliveTest(AsyncHTTPTestCase):
|
|||
# The next few methods are a crude manual http client
|
||||
def connect(self):
|
||||
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
|
||||
self.stream.connect(('localhost', self.get_http_port()), self.stop)
|
||||
self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
|
||||
self.wait()
|
||||
|
||||
def read_headers(self):
|
||||
self.stream.read_until(b'\r\n', self.stop)
|
||||
first_line = self.wait()
|
||||
self.assertTrue(first_line.startswith(self.http_version + b' 200'), first_line)
|
||||
self.assertTrue(first_line.startswith(b'HTTP/1.1 200'), first_line)
|
||||
self.stream.read_until(b'\r\n\r\n', self.stop)
|
||||
header_bytes = self.wait()
|
||||
headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
|
||||
|
@ -808,8 +820,8 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase):
|
|||
|
||||
def get_app(self):
|
||||
class App(HTTPServerConnectionDelegate):
|
||||
def start_request(self, connection):
|
||||
return StreamingChunkSizeTest.MessageDelegate(connection)
|
||||
def start_request(self, server_conn, request_conn):
|
||||
return StreamingChunkSizeTest.MessageDelegate(request_conn)
|
||||
return App()
|
||||
|
||||
def fetch_chunk_sizes(self, **kwargs):
|
||||
|
@ -856,6 +868,7 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase):
|
|||
def test_chunked_compressed(self):
|
||||
compressed = self.compress(self.BODY)
|
||||
self.assertGreater(len(compressed), 20)
|
||||
|
||||
def body_producer(write):
|
||||
write(compressed[:20])
|
||||
write(compressed[20:])
|
||||
|
@ -900,7 +913,7 @@ class IdleTimeoutTest(AsyncHTTPTestCase):
|
|||
|
||||
def connect(self):
|
||||
stream = IOStream(socket.socket())
|
||||
stream.connect(('localhost', self.get_http_port()), self.stop)
|
||||
stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
|
||||
self.wait()
|
||||
self.streams.append(stream)
|
||||
return stream
|
||||
|
@ -1045,6 +1058,15 @@ class LegacyInterfaceTest(AsyncHTTPTestCase):
|
|||
# delegate interface, and writes its response via request.write
|
||||
# instead of request.connection.write_headers.
|
||||
def handle_request(request):
|
||||
self.http1 = request.version.startswith("HTTP/1.")
|
||||
if not self.http1:
|
||||
# This test will be skipped if we're using HTTP/2,
|
||||
# so just close it out cleanly using the modern interface.
|
||||
request.connection.write_headers(
|
||||
ResponseStartLine('', 200, 'OK'),
|
||||
HTTPHeaders())
|
||||
request.connection.finish()
|
||||
return
|
||||
message = b"Hello world"
|
||||
request.write(utf8("HTTP/1.1 200 OK\r\n"
|
||||
"Content-Length: %d\r\n\r\n" % len(message)))
|
||||
|
@ -1054,4 +1076,6 @@ class LegacyInterfaceTest(AsyncHTTPTestCase):
|
|||
|
||||
def test_legacy_interface(self):
|
||||
response = self.fetch('/')
|
||||
if not self.http1:
|
||||
self.skipTest("requires HTTP/1.x")
|
||||
self.assertEqual(response.body, b"Hello world")
|
||||
|
|
|
@ -3,11 +3,13 @@
|
|||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from tornado.httputil import url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp, HTTPServerRequest, parse_request_start_line
|
||||
from tornado.escape import utf8
|
||||
from tornado.escape import utf8, native_str
|
||||
from tornado.log import gen_log
|
||||
from tornado.testing import ExpectLog
|
||||
from tornado.test.util import unittest
|
||||
from tornado.util import u
|
||||
|
||||
import copy
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
|
@ -228,6 +230,75 @@ Foo: even
|
|||
("Foo", "bar baz"),
|
||||
("Foo", "even more lines")])
|
||||
|
||||
def test_unicode_newlines(self):
|
||||
# Ensure that only \r\n is recognized as a header separator, and not
|
||||
# the other newline-like unicode characters.
|
||||
# Characters that are likely to be problematic can be found in
|
||||
# http://unicode.org/standard/reports/tr13/tr13-5.html
|
||||
# and cpython's unicodeobject.c (which defines the implementation
|
||||
# of unicode_type.splitlines(), and uses a different list than TR13).
|
||||
newlines = [
|
||||
u('\u001b'), # VERTICAL TAB
|
||||
u('\u001c'), # FILE SEPARATOR
|
||||
u('\u001d'), # GROUP SEPARATOR
|
||||
u('\u001e'), # RECORD SEPARATOR
|
||||
u('\u0085'), # NEXT LINE
|
||||
u('\u2028'), # LINE SEPARATOR
|
||||
u('\u2029'), # PARAGRAPH SEPARATOR
|
||||
]
|
||||
for newline in newlines:
|
||||
# Try the utf8 and latin1 representations of each newline
|
||||
for encoding in ['utf8', 'latin1']:
|
||||
try:
|
||||
try:
|
||||
encoded = newline.encode(encoding)
|
||||
except UnicodeEncodeError:
|
||||
# Some chars cannot be represented in latin1
|
||||
continue
|
||||
data = b'Cookie: foo=' + encoded + b'bar'
|
||||
# parse() wants a native_str, so decode through latin1
|
||||
# in the same way the real parser does.
|
||||
headers = HTTPHeaders.parse(
|
||||
native_str(data.decode('latin1')))
|
||||
expected = [('Cookie', 'foo=' +
|
||||
native_str(encoded.decode('latin1')) + 'bar')]
|
||||
self.assertEqual(
|
||||
expected, list(headers.get_all()))
|
||||
except Exception:
|
||||
gen_log.warning("failed while trying %r in %s",
|
||||
newline, encoding)
|
||||
raise
|
||||
|
||||
def test_optional_cr(self):
|
||||
# Both CRLF and LF should be accepted as separators. CR should not be
|
||||
# part of the data when followed by LF, but it is a normal char
|
||||
# otherwise (or should bare CR be an error?)
|
||||
headers = HTTPHeaders.parse(
|
||||
'CRLF: crlf\r\nLF: lf\nCR: cr\rMore: more\r\n')
|
||||
self.assertEqual(sorted(headers.get_all()),
|
||||
[('Cr', 'cr\rMore: more'),
|
||||
('Crlf', 'crlf'),
|
||||
('Lf', 'lf'),
|
||||
])
|
||||
|
||||
def test_copy(self):
|
||||
all_pairs = [('A', '1'), ('A', '2'), ('B', 'c')]
|
||||
h1 = HTTPHeaders()
|
||||
for k, v in all_pairs:
|
||||
h1.add(k, v)
|
||||
h2 = h1.copy()
|
||||
h3 = copy.copy(h1)
|
||||
h4 = copy.deepcopy(h1)
|
||||
for headers in [h1, h2, h3, h4]:
|
||||
# All the copies are identical, no matter how they were
|
||||
# constructed.
|
||||
self.assertEqual(list(sorted(headers.get_all())), all_pairs)
|
||||
for headers in [h2, h3, h4]:
|
||||
# Neither the dict or its member lists are reused.
|
||||
self.assertIsNot(headers, h1)
|
||||
self.assertIsNot(headers.get_list('A'), h1.get_list('A'))
|
||||
|
||||
|
||||
|
||||
class FormatTimestampTest(unittest.TestCase):
|
||||
# Make sure that all the input types are supported.
|
||||
|
@ -264,6 +335,10 @@ class HTTPServerRequestTest(unittest.TestCase):
|
|||
# more required parameters slip in.
|
||||
HTTPServerRequest(uri='/')
|
||||
|
||||
def test_body_is_a_byte_string(self):
|
||||
requets = HTTPServerRequest(uri='/')
|
||||
self.assertIsInstance(requets.body, bytes)
|
||||
|
||||
|
||||
class ParseRequestStartLineTest(unittest.TestCase):
|
||||
METHOD = "GET"
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
# flake8: noqa
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from tornado.test.util import unittest
|
||||
|
||||
|
|
|
@ -11,8 +11,9 @@ import threading
|
|||
import time
|
||||
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop, TimeoutError
|
||||
from tornado.ioloop import IOLoop, TimeoutError, PollIOLoop, PeriodicCallback
|
||||
from tornado.log import app_log
|
||||
from tornado.platform.select import _Select
|
||||
from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext
|
||||
from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog
|
||||
from tornado.test.util import unittest, skipIfNonUnix, skipOnTravis
|
||||
|
@ -23,6 +24,42 @@ except ImportError:
|
|||
futures = None
|
||||
|
||||
|
||||
class FakeTimeSelect(_Select):
|
||||
def __init__(self):
|
||||
self._time = 1000
|
||||
super(FakeTimeSelect, self).__init__()
|
||||
|
||||
def time(self):
|
||||
return self._time
|
||||
|
||||
def sleep(self, t):
|
||||
self._time += t
|
||||
|
||||
def poll(self, timeout):
|
||||
events = super(FakeTimeSelect, self).poll(0)
|
||||
if events:
|
||||
return events
|
||||
self._time += timeout
|
||||
return []
|
||||
|
||||
|
||||
class FakeTimeIOLoop(PollIOLoop):
|
||||
"""IOLoop implementation with a fake and deterministic clock.
|
||||
|
||||
The clock advances as needed to trigger timeouts immediately.
|
||||
For use when testing code that involves the passage of time
|
||||
and no external dependencies.
|
||||
"""
|
||||
def initialize(self):
|
||||
self.fts = FakeTimeSelect()
|
||||
super(FakeTimeIOLoop, self).initialize(impl=self.fts,
|
||||
time_func=self.fts.time)
|
||||
|
||||
def sleep(self, t):
|
||||
"""Simulate a blocking sleep by advancing the clock."""
|
||||
self.fts.sleep(t)
|
||||
|
||||
|
||||
class TestIOLoop(AsyncTestCase):
|
||||
@skipOnTravis
|
||||
def test_add_callback_wakeup(self):
|
||||
|
@ -180,10 +217,12 @@ class TestIOLoop(AsyncTestCase):
|
|||
# t2 should be cancelled by t1, even though it is already scheduled to
|
||||
# be run before the ioloop even looks at it.
|
||||
now = self.io_loop.time()
|
||||
|
||||
def t1():
|
||||
calls[0] = True
|
||||
self.io_loop.remove_timeout(t2_handle)
|
||||
self.io_loop.add_timeout(now + 0.01, t1)
|
||||
|
||||
def t2():
|
||||
calls[1] = True
|
||||
t2_handle = self.io_loop.add_timeout(now + 0.02, t2)
|
||||
|
@ -252,6 +291,7 @@ class TestIOLoop(AsyncTestCase):
|
|||
"""The handler callback receives the same fd object it passed in."""
|
||||
server_sock, port = bind_unused_port()
|
||||
fds = []
|
||||
|
||||
def handle_connection(fd, events):
|
||||
fds.append(fd)
|
||||
conn, addr = server_sock.accept()
|
||||
|
@ -274,6 +314,7 @@ class TestIOLoop(AsyncTestCase):
|
|||
|
||||
def test_mixed_fd_fileobj(self):
|
||||
server_sock, port = bind_unused_port()
|
||||
|
||||
def f(fd, events):
|
||||
pass
|
||||
self.io_loop.add_handler(server_sock, f, IOLoop.READ)
|
||||
|
@ -288,6 +329,7 @@ class TestIOLoop(AsyncTestCase):
|
|||
"""Calling start() twice should raise an error, not deadlock."""
|
||||
returned_from_start = [False]
|
||||
got_exception = [False]
|
||||
|
||||
def callback():
|
||||
try:
|
||||
self.io_loop.start()
|
||||
|
@ -305,7 +347,7 @@ class TestIOLoop(AsyncTestCase):
|
|||
# Use a NullContext to keep the exception from being caught by
|
||||
# AsyncTestCase.
|
||||
with NullContext():
|
||||
self.io_loop.add_callback(lambda: 1/0)
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
self.io_loop.add_callback(self.stop)
|
||||
with ExpectLog(app_log, "Exception in callback"):
|
||||
self.wait()
|
||||
|
@ -316,7 +358,7 @@ class TestIOLoop(AsyncTestCase):
|
|||
@gen.coroutine
|
||||
def callback():
|
||||
self.io_loop.add_callback(self.stop)
|
||||
1/0
|
||||
1 / 0
|
||||
self.io_loop.add_callback(callback)
|
||||
with ExpectLog(app_log, "Exception in callback"):
|
||||
self.wait()
|
||||
|
@ -324,12 +366,12 @@ class TestIOLoop(AsyncTestCase):
|
|||
def test_spawn_callback(self):
|
||||
# An added callback runs in the test's stack_context, so will be
|
||||
# re-arised in wait().
|
||||
self.io_loop.add_callback(lambda: 1/0)
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
self.wait()
|
||||
# A spawned callback is run directly on the IOLoop, so it will be
|
||||
# logged without stopping the test.
|
||||
self.io_loop.spawn_callback(lambda: 1/0)
|
||||
self.io_loop.spawn_callback(lambda: 1 / 0)
|
||||
self.io_loop.add_callback(self.stop)
|
||||
with ExpectLog(app_log, "Exception in callback"):
|
||||
self.wait()
|
||||
|
@ -344,6 +386,7 @@ class TestIOLoop(AsyncTestCase):
|
|||
|
||||
# After reading from one fd, remove the other from the IOLoop.
|
||||
chunks = []
|
||||
|
||||
def handle_read(fd, events):
|
||||
chunks.append(fd.recv(1024))
|
||||
if fd is client:
|
||||
|
@ -352,7 +395,7 @@ class TestIOLoop(AsyncTestCase):
|
|||
self.io_loop.remove_handler(client)
|
||||
self.io_loop.add_handler(client, handle_read, self.io_loop.READ)
|
||||
self.io_loop.add_handler(server, handle_read, self.io_loop.READ)
|
||||
self.io_loop.call_later(0.01, self.stop)
|
||||
self.io_loop.call_later(0.03, self.stop)
|
||||
self.wait()
|
||||
|
||||
# Only one fd was read; the other was cleanly removed.
|
||||
|
@ -520,5 +563,47 @@ class TestIOLoopRunSync(unittest.TestCase):
|
|||
self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01)
|
||||
|
||||
|
||||
class TestPeriodicCallback(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.io_loop = FakeTimeIOLoop()
|
||||
self.io_loop.make_current()
|
||||
|
||||
def tearDown(self):
|
||||
self.io_loop.close()
|
||||
|
||||
def test_basic(self):
|
||||
calls = []
|
||||
|
||||
def cb():
|
||||
calls.append(self.io_loop.time())
|
||||
pc = PeriodicCallback(cb, 10000)
|
||||
pc.start()
|
||||
self.io_loop.call_later(50, self.io_loop.stop)
|
||||
self.io_loop.start()
|
||||
self.assertEqual(calls, [1010, 1020, 1030, 1040, 1050])
|
||||
|
||||
def test_overrun(self):
|
||||
sleep_durations = [9, 9, 10, 11, 20, 20, 35, 35, 0, 0]
|
||||
expected = [
|
||||
1010, 1020, 1030, # first 3 calls on schedule
|
||||
1050, 1070, # next 2 delayed one cycle
|
||||
1100, 1130, # next 2 delayed 2 cycles
|
||||
1170, 1210, # next 2 delayed 3 cycles
|
||||
1220, 1230, # then back on schedule.
|
||||
]
|
||||
calls = []
|
||||
|
||||
def cb():
|
||||
calls.append(self.io_loop.time())
|
||||
if not sleep_durations:
|
||||
self.io_loop.stop()
|
||||
return
|
||||
self.io_loop.sleep(sleep_durations.pop(0))
|
||||
pc = PeriodicCallback(cb, 10000)
|
||||
pc.start()
|
||||
self.io_loop.start()
|
||||
self.assertEqual(calls, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -7,10 +7,10 @@ from tornado.httputil import HTTPHeaders
|
|||
from tornado.log import gen_log, app_log
|
||||
from tornado.netutil import ssl_wrap_socket
|
||||
from tornado.stack_context import NullContext
|
||||
from tornado.tcpserver import TCPServer
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
|
||||
from tornado.test.util import unittest, skipIfNonUnix
|
||||
from tornado.test.util import unittest, skipIfNonUnix, refusing_port
|
||||
from tornado.web import RequestHandler, Application
|
||||
import lib.certifi
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
|
@ -51,18 +51,18 @@ class TestIOStreamWebMixin(object):
|
|||
|
||||
def test_read_until_close(self):
|
||||
stream = self._make_client_iostream()
|
||||
stream.connect(('localhost', self.get_http_port()), callback=self.stop)
|
||||
stream.connect(('127.0.0.1', self.get_http_port()), callback=self.stop)
|
||||
self.wait()
|
||||
stream.write(b"GET / HTTP/1.0\r\n\r\n")
|
||||
|
||||
stream.read_until_close(self.stop)
|
||||
data = self.wait()
|
||||
self.assertTrue(data.startswith(b"HTTP/1.0 200"))
|
||||
self.assertTrue(data.startswith(b"HTTP/1.1 200"))
|
||||
self.assertTrue(data.endswith(b"Hello"))
|
||||
|
||||
def test_read_zero_bytes(self):
|
||||
self.stream = self._make_client_iostream()
|
||||
self.stream.connect(("localhost", self.get_http_port()),
|
||||
self.stream.connect(("127.0.0.1", self.get_http_port()),
|
||||
callback=self.stop)
|
||||
self.wait()
|
||||
self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
|
||||
|
@ -70,7 +70,7 @@ class TestIOStreamWebMixin(object):
|
|||
# normal read
|
||||
self.stream.read_bytes(9, self.stop)
|
||||
data = self.wait()
|
||||
self.assertEqual(data, b"HTTP/1.0 ")
|
||||
self.assertEqual(data, b"HTTP/1.1 ")
|
||||
|
||||
# zero bytes
|
||||
self.stream.read_bytes(0, self.stop)
|
||||
|
@ -91,7 +91,7 @@ class TestIOStreamWebMixin(object):
|
|||
def connected_callback():
|
||||
connected[0] = True
|
||||
self.stop()
|
||||
stream.connect(("localhost", self.get_http_port()),
|
||||
stream.connect(("127.0.0.1", self.get_http_port()),
|
||||
callback=connected_callback)
|
||||
# unlike the previous tests, try to write before the connection
|
||||
# is complete.
|
||||
|
@ -121,11 +121,11 @@ class TestIOStreamWebMixin(object):
|
|||
"""Basic test of IOStream's ability to return Futures."""
|
||||
stream = self._make_client_iostream()
|
||||
connect_result = yield stream.connect(
|
||||
("localhost", self.get_http_port()))
|
||||
("127.0.0.1", self.get_http_port()))
|
||||
self.assertIs(connect_result, stream)
|
||||
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
|
||||
first_line = yield stream.read_until(b"\r\n")
|
||||
self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n")
|
||||
self.assertEqual(first_line, b"HTTP/1.1 200 OK\r\n")
|
||||
# callback=None is equivalent to no callback.
|
||||
header_data = yield stream.read_until(b"\r\n\r\n", callback=None)
|
||||
headers = HTTPHeaders.parse(header_data.decode('latin1'))
|
||||
|
@ -137,7 +137,7 @@ class TestIOStreamWebMixin(object):
|
|||
@gen_test
|
||||
def test_future_close_while_reading(self):
|
||||
stream = self._make_client_iostream()
|
||||
yield stream.connect(("localhost", self.get_http_port()))
|
||||
yield stream.connect(("127.0.0.1", self.get_http_port()))
|
||||
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
|
||||
with self.assertRaises(StreamClosedError):
|
||||
yield stream.read_bytes(1024 * 1024)
|
||||
|
@ -147,7 +147,7 @@ class TestIOStreamWebMixin(object):
|
|||
def test_future_read_until_close(self):
|
||||
# Ensure that the data comes through before the StreamClosedError.
|
||||
stream = self._make_client_iostream()
|
||||
yield stream.connect(("localhost", self.get_http_port()))
|
||||
yield stream.connect(("127.0.0.1", self.get_http_port()))
|
||||
yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n")
|
||||
yield stream.read_until(b"\r\n\r\n")
|
||||
body = yield stream.read_until_close()
|
||||
|
@ -217,17 +217,18 @@ class TestIOStreamMixin(object):
|
|||
# When a connection is refused, the connect callback should not
|
||||
# be run. (The kqueue IOLoop used to behave differently from the
|
||||
# epoll IOLoop in this respect)
|
||||
server_socket, port = bind_unused_port()
|
||||
server_socket.close()
|
||||
cleanup_func, port = refusing_port()
|
||||
self.addCleanup(cleanup_func)
|
||||
stream = IOStream(socket.socket(), self.io_loop)
|
||||
self.connect_called = False
|
||||
|
||||
def connect_callback():
|
||||
self.connect_called = True
|
||||
self.stop()
|
||||
stream.set_close_callback(self.stop)
|
||||
# log messages vary by platform and ioloop implementation
|
||||
with ExpectLog(gen_log, ".*", required=False):
|
||||
stream.connect(("localhost", port), connect_callback)
|
||||
stream.connect(("127.0.0.1", port), connect_callback)
|
||||
self.wait()
|
||||
self.assertFalse(self.connect_called)
|
||||
self.assertTrue(isinstance(stream.error, socket.error), stream.error)
|
||||
|
@ -248,7 +249,8 @@ class TestIOStreamMixin(object):
|
|||
# opendns and some ISPs return bogus addresses for nonexistent
|
||||
# domains instead of the proper error codes).
|
||||
with ExpectLog(gen_log, "Connect error"):
|
||||
stream.connect(('an invalid domain', 54321))
|
||||
stream.connect(('an invalid domain', 54321), callback=self.stop)
|
||||
self.wait()
|
||||
self.assertTrue(isinstance(stream.error, socket.gaierror), stream.error)
|
||||
|
||||
def test_read_callback_error(self):
|
||||
|
@ -308,6 +310,7 @@ class TestIOStreamMixin(object):
|
|||
def streaming_callback(data):
|
||||
chunks.append(data)
|
||||
self.stop()
|
||||
|
||||
def close_callback(data):
|
||||
assert not data, data
|
||||
closed[0] = True
|
||||
|
@ -325,6 +328,31 @@ class TestIOStreamMixin(object):
|
|||
server.close()
|
||||
client.close()
|
||||
|
||||
def test_streaming_until_close_future(self):
|
||||
server, client = self.make_iostream_pair()
|
||||
try:
|
||||
chunks = []
|
||||
|
||||
@gen.coroutine
|
||||
def client_task():
|
||||
yield client.read_until_close(streaming_callback=chunks.append)
|
||||
|
||||
@gen.coroutine
|
||||
def server_task():
|
||||
yield server.write(b"1234")
|
||||
yield gen.sleep(0.01)
|
||||
yield server.write(b"5678")
|
||||
server.close()
|
||||
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield [client_task(), server_task()]
|
||||
self.io_loop.run_sync(f)
|
||||
self.assertEqual(chunks, [b"1234", b"5678"])
|
||||
finally:
|
||||
server.close()
|
||||
client.close()
|
||||
|
||||
def test_delayed_close_callback(self):
|
||||
# The scenario: Server closes the connection while there is a pending
|
||||
# read that can be served out of buffered data. The client does not
|
||||
|
@ -353,6 +381,7 @@ class TestIOStreamMixin(object):
|
|||
def test_future_delayed_close_callback(self):
|
||||
# Same as test_delayed_close_callback, but with the future interface.
|
||||
server, client = self.make_iostream_pair()
|
||||
|
||||
# We can't call make_iostream_pair inside a gen_test function
|
||||
# because the ioloop is not reentrant.
|
||||
@gen_test
|
||||
|
@ -532,6 +561,7 @@ class TestIOStreamMixin(object):
|
|||
# and IOStream._maybe_add_error_listener.
|
||||
server, client = self.make_iostream_pair()
|
||||
closed = [False]
|
||||
|
||||
def close_callback():
|
||||
closed[0] = True
|
||||
self.stop()
|
||||
|
@ -724,6 +754,26 @@ class TestIOStreamMixin(object):
|
|||
server.close()
|
||||
client.close()
|
||||
|
||||
def test_flow_control(self):
|
||||
MB = 1024 * 1024
|
||||
server, client = self.make_iostream_pair(max_buffer_size=5 * MB)
|
||||
try:
|
||||
# Client writes more than the server will accept.
|
||||
client.write(b"a" * 10 * MB)
|
||||
# The server pauses while reading.
|
||||
server.read_bytes(MB, self.stop)
|
||||
self.wait()
|
||||
self.io_loop.call_later(0.1, self.stop)
|
||||
self.wait()
|
||||
# The client's writes have been blocked; the server can
|
||||
# continue to read gradually.
|
||||
for i in range(9):
|
||||
server.read_bytes(MB, self.stop)
|
||||
self.wait()
|
||||
finally:
|
||||
server.close()
|
||||
client.close()
|
||||
|
||||
|
||||
class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
|
||||
def _make_client_iostream(self):
|
||||
|
@ -732,7 +782,8 @@ class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
|
|||
|
||||
class TestIOStreamWebHTTPS(TestIOStreamWebMixin, AsyncHTTPSTestCase):
|
||||
def _make_client_iostream(self):
|
||||
return SSLIOStream(socket.socket(), io_loop=self.io_loop)
|
||||
return SSLIOStream(socket.socket(), io_loop=self.io_loop,
|
||||
ssl_options=dict(cert_reqs=ssl.CERT_NONE))
|
||||
|
||||
|
||||
class TestIOStream(TestIOStreamMixin, AsyncTestCase):
|
||||
|
@ -752,7 +803,9 @@ class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase):
|
|||
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
|
||||
|
||||
def _make_client_iostream(self, connection, **kwargs):
|
||||
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
|
||||
return SSLIOStream(connection, io_loop=self.io_loop,
|
||||
ssl_options=dict(cert_reqs=ssl.CERT_NONE),
|
||||
**kwargs)
|
||||
|
||||
|
||||
# This will run some tests that are basically redundant but it's the
|
||||
|
@ -820,10 +873,10 @@ class TestIOStreamStartTLS(AsyncTestCase):
|
|||
recv_line = yield self.client_stream.read_until(b"\r\n")
|
||||
self.assertEqual(line, recv_line)
|
||||
|
||||
def client_start_tls(self, ssl_options=None):
|
||||
def client_start_tls(self, ssl_options=None, server_hostname=None):
|
||||
client_stream = self.client_stream
|
||||
self.client_stream = None
|
||||
return client_stream.start_tls(False, ssl_options)
|
||||
return client_stream.start_tls(False, ssl_options, server_hostname)
|
||||
|
||||
def server_start_tls(self, ssl_options=None):
|
||||
server_stream = self.server_stream
|
||||
|
@ -842,7 +895,7 @@ class TestIOStreamStartTLS(AsyncTestCase):
|
|||
yield self.server_send_line(b"250 STARTTLS\r\n")
|
||||
yield self.client_send_line(b"STARTTLS\r\n")
|
||||
yield self.server_send_line(b"220 Go ahead\r\n")
|
||||
client_future = self.client_start_tls()
|
||||
client_future = self.client_start_tls(dict(cert_reqs=ssl.CERT_NONE))
|
||||
server_future = self.server_start_tls(_server_ssl_options())
|
||||
self.client_stream = yield client_future
|
||||
self.server_stream = yield server_future
|
||||
|
@ -853,12 +906,123 @@ class TestIOStreamStartTLS(AsyncTestCase):
|
|||
|
||||
@gen_test
|
||||
def test_handshake_fail(self):
|
||||
self.server_start_tls(_server_ssl_options())
|
||||
client_future = self.client_start_tls(
|
||||
dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where()))
|
||||
server_future = self.server_start_tls(_server_ssl_options())
|
||||
# Certificates are verified with the default configuration.
|
||||
client_future = self.client_start_tls(server_hostname="localhost")
|
||||
with ExpectLog(gen_log, "SSL Error"):
|
||||
with self.assertRaises(ssl.SSLError):
|
||||
yield client_future
|
||||
with self.assertRaises((ssl.SSLError, socket.error)):
|
||||
yield server_future
|
||||
|
||||
@unittest.skipIf(not hasattr(ssl, 'create_default_context'),
|
||||
'ssl.create_default_context not present')
|
||||
@gen_test
|
||||
def test_check_hostname(self):
|
||||
# Test that server_hostname parameter to start_tls is being used.
|
||||
# The check_hostname functionality is only available in python 2.7 and
|
||||
# up and in python 3.4 and up.
|
||||
server_future = self.server_start_tls(_server_ssl_options())
|
||||
client_future = self.client_start_tls(
|
||||
ssl.create_default_context(),
|
||||
server_hostname=b'127.0.0.1')
|
||||
with ExpectLog(gen_log, "SSL Error"):
|
||||
with self.assertRaises(ssl.SSLError):
|
||||
yield client_future
|
||||
with self.assertRaises((ssl.SSLError, socket.error)):
|
||||
yield server_future
|
||||
|
||||
|
||||
class WaitForHandshakeTest(AsyncTestCase):
|
||||
@gen.coroutine
|
||||
def connect_to_server(self, server_cls):
|
||||
server = client = None
|
||||
try:
|
||||
sock, port = bind_unused_port()
|
||||
server = server_cls(ssl_options=_server_ssl_options())
|
||||
server.add_socket(sock)
|
||||
|
||||
client = SSLIOStream(socket.socket(),
|
||||
ssl_options=dict(cert_reqs=ssl.CERT_NONE))
|
||||
yield client.connect(('127.0.0.1', port))
|
||||
self.assertIsNotNone(client.socket.cipher())
|
||||
finally:
|
||||
if server is not None:
|
||||
server.stop()
|
||||
if client is not None:
|
||||
client.close()
|
||||
|
||||
@gen_test
|
||||
def test_wait_for_handshake_callback(self):
|
||||
test = self
|
||||
handshake_future = Future()
|
||||
|
||||
class TestServer(TCPServer):
|
||||
def handle_stream(self, stream, address):
|
||||
# The handshake has not yet completed.
|
||||
test.assertIsNone(stream.socket.cipher())
|
||||
self.stream = stream
|
||||
stream.wait_for_handshake(self.handshake_done)
|
||||
|
||||
def handshake_done(self):
|
||||
# Now the handshake is done and ssl information is available.
|
||||
test.assertIsNotNone(self.stream.socket.cipher())
|
||||
handshake_future.set_result(None)
|
||||
|
||||
yield self.connect_to_server(TestServer)
|
||||
yield handshake_future
|
||||
|
||||
@gen_test
|
||||
def test_wait_for_handshake_future(self):
|
||||
test = self
|
||||
handshake_future = Future()
|
||||
|
||||
class TestServer(TCPServer):
|
||||
def handle_stream(self, stream, address):
|
||||
test.assertIsNone(stream.socket.cipher())
|
||||
test.io_loop.spawn_callback(self.handle_connection, stream)
|
||||
|
||||
@gen.coroutine
|
||||
def handle_connection(self, stream):
|
||||
yield stream.wait_for_handshake()
|
||||
handshake_future.set_result(None)
|
||||
|
||||
yield self.connect_to_server(TestServer)
|
||||
yield handshake_future
|
||||
|
||||
@gen_test
|
||||
def test_wait_for_handshake_already_waiting_error(self):
|
||||
test = self
|
||||
handshake_future = Future()
|
||||
|
||||
class TestServer(TCPServer):
|
||||
def handle_stream(self, stream, address):
|
||||
stream.wait_for_handshake(self.handshake_done)
|
||||
test.assertRaises(RuntimeError, stream.wait_for_handshake)
|
||||
|
||||
def handshake_done(self):
|
||||
handshake_future.set_result(None)
|
||||
|
||||
yield self.connect_to_server(TestServer)
|
||||
yield handshake_future
|
||||
|
||||
@gen_test
|
||||
def test_wait_for_handshake_already_connected(self):
|
||||
handshake_future = Future()
|
||||
|
||||
class TestServer(TCPServer):
|
||||
def handle_stream(self, stream, address):
|
||||
self.stream = stream
|
||||
stream.wait_for_handshake(self.handshake_done)
|
||||
|
||||
def handshake_done(self):
|
||||
self.stream.wait_for_handshake(self.handshake2_done)
|
||||
|
||||
def handshake2_done(self):
|
||||
handshake_future.set_result(None)
|
||||
|
||||
yield self.connect_to_server(TestServer)
|
||||
yield handshake_future
|
||||
|
||||
|
||||
@skipIfNonUnix
|
||||
|
|
|
@ -41,6 +41,12 @@ class TranslationLoaderTest(unittest.TestCase):
|
|||
locale = tornado.locale.get("fr_FR")
|
||||
self.assertTrue(isinstance(locale, tornado.locale.GettextLocale))
|
||||
self.assertEqual(locale.translate("school"), u("\u00e9cole"))
|
||||
self.assertEqual(locale.pgettext("law", "right"), u("le droit"))
|
||||
self.assertEqual(locale.pgettext("good", "right"), u("le bien"))
|
||||
self.assertEqual(locale.pgettext("organization", "club", "clubs", 1), u("le club"))
|
||||
self.assertEqual(locale.pgettext("organization", "club", "clubs", 2), u("les clubs"))
|
||||
self.assertEqual(locale.pgettext("stick", "club", "clubs", 1), u("le b\xe2ton"))
|
||||
self.assertEqual(locale.pgettext("stick", "club", "clubs", 2), u("les b\xe2tons"))
|
||||
|
||||
|
||||
class LocaleDataTest(unittest.TestCase):
|
||||
|
|
480
tornado/test/locks_test.py
Normal file
480
tornado/test/locks_test.py
Normal file
|
@ -0,0 +1,480 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from tornado import gen, locks
|
||||
from tornado.gen import TimeoutError
|
||||
from tornado.testing import gen_test, AsyncTestCase
|
||||
from tornado.test.util import unittest
|
||||
|
||||
|
||||
class ConditionTest(AsyncTestCase):
|
||||
def setUp(self):
|
||||
super(ConditionTest, self).setUp()
|
||||
self.history = []
|
||||
|
||||
def record_done(self, future, key):
|
||||
"""Record the resolution of a Future returned by Condition.wait."""
|
||||
def callback(_):
|
||||
if not future.result():
|
||||
# wait() resolved to False, meaning it timed out.
|
||||
self.history.append('timeout')
|
||||
else:
|
||||
self.history.append(key)
|
||||
future.add_done_callback(callback)
|
||||
|
||||
def test_repr(self):
|
||||
c = locks.Condition()
|
||||
self.assertIn('Condition', repr(c))
|
||||
self.assertNotIn('waiters', repr(c))
|
||||
c.wait()
|
||||
self.assertIn('waiters', repr(c))
|
||||
|
||||
@gen_test
|
||||
def test_notify(self):
|
||||
c = locks.Condition()
|
||||
self.io_loop.call_later(0.01, c.notify)
|
||||
yield c.wait()
|
||||
|
||||
def test_notify_1(self):
|
||||
c = locks.Condition()
|
||||
self.record_done(c.wait(), 'wait1')
|
||||
self.record_done(c.wait(), 'wait2')
|
||||
c.notify(1)
|
||||
self.history.append('notify1')
|
||||
c.notify(1)
|
||||
self.history.append('notify2')
|
||||
self.assertEqual(['wait1', 'notify1', 'wait2', 'notify2'],
|
||||
self.history)
|
||||
|
||||
def test_notify_n(self):
|
||||
c = locks.Condition()
|
||||
for i in range(6):
|
||||
self.record_done(c.wait(), i)
|
||||
|
||||
c.notify(3)
|
||||
|
||||
# Callbacks execute in the order they were registered.
|
||||
self.assertEqual(list(range(3)), self.history)
|
||||
c.notify(1)
|
||||
self.assertEqual(list(range(4)), self.history)
|
||||
c.notify(2)
|
||||
self.assertEqual(list(range(6)), self.history)
|
||||
|
||||
def test_notify_all(self):
|
||||
c = locks.Condition()
|
||||
for i in range(4):
|
||||
self.record_done(c.wait(), i)
|
||||
|
||||
c.notify_all()
|
||||
self.history.append('notify_all')
|
||||
|
||||
# Callbacks execute in the order they were registered.
|
||||
self.assertEqual(
|
||||
list(range(4)) + ['notify_all'],
|
||||
self.history)
|
||||
|
||||
@gen_test
|
||||
def test_wait_timeout(self):
|
||||
c = locks.Condition()
|
||||
wait = c.wait(timedelta(seconds=0.01))
|
||||
self.io_loop.call_later(0.02, c.notify) # Too late.
|
||||
yield gen.sleep(0.03)
|
||||
self.assertFalse((yield wait))
|
||||
|
||||
@gen_test
|
||||
def test_wait_timeout_preempted(self):
|
||||
c = locks.Condition()
|
||||
|
||||
# This fires before the wait times out.
|
||||
self.io_loop.call_later(0.01, c.notify)
|
||||
wait = c.wait(timedelta(seconds=0.02))
|
||||
yield gen.sleep(0.03)
|
||||
yield wait # No TimeoutError.
|
||||
|
||||
@gen_test
|
||||
def test_notify_n_with_timeout(self):
|
||||
# Register callbacks 0, 1, 2, and 3. Callback 1 has a timeout.
|
||||
# Wait for that timeout to expire, then do notify(2) and make
|
||||
# sure everyone runs. Verifies that a timed-out callback does
|
||||
# not count against the 'n' argument to notify().
|
||||
c = locks.Condition()
|
||||
self.record_done(c.wait(), 0)
|
||||
self.record_done(c.wait(timedelta(seconds=0.01)), 1)
|
||||
self.record_done(c.wait(), 2)
|
||||
self.record_done(c.wait(), 3)
|
||||
|
||||
# Wait for callback 1 to time out.
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(['timeout'], self.history)
|
||||
|
||||
c.notify(2)
|
||||
yield gen.sleep(0.01)
|
||||
self.assertEqual(['timeout', 0, 2], self.history)
|
||||
self.assertEqual(['timeout', 0, 2], self.history)
|
||||
c.notify()
|
||||
self.assertEqual(['timeout', 0, 2, 3], self.history)
|
||||
|
||||
@gen_test
|
||||
def test_notify_all_with_timeout(self):
|
||||
c = locks.Condition()
|
||||
self.record_done(c.wait(), 0)
|
||||
self.record_done(c.wait(timedelta(seconds=0.01)), 1)
|
||||
self.record_done(c.wait(), 2)
|
||||
|
||||
# Wait for callback 1 to time out.
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(['timeout'], self.history)
|
||||
|
||||
c.notify_all()
|
||||
self.assertEqual(['timeout', 0, 2], self.history)
|
||||
|
||||
@gen_test
|
||||
def test_nested_notify(self):
|
||||
# Ensure no notifications lost, even if notify() is reentered by a
|
||||
# waiter calling notify().
|
||||
c = locks.Condition()
|
||||
|
||||
# Three waiters.
|
||||
futures = [c.wait() for _ in range(3)]
|
||||
|
||||
# First and second futures resolved. Second future reenters notify(),
|
||||
# resolving third future.
|
||||
futures[1].add_done_callback(lambda _: c.notify())
|
||||
c.notify(2)
|
||||
self.assertTrue(all(f.done() for f in futures))
|
||||
|
||||
@gen_test
|
||||
def test_garbage_collection(self):
|
||||
# Test that timed-out waiters are occasionally cleaned from the queue.
|
||||
c = locks.Condition()
|
||||
for _ in range(101):
|
||||
c.wait(timedelta(seconds=0.01))
|
||||
|
||||
future = c.wait()
|
||||
self.assertEqual(102, len(c._waiters))
|
||||
|
||||
# Let first 101 waiters time out, triggering a collection.
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(1, len(c._waiters))
|
||||
|
||||
# Final waiter is still active.
|
||||
self.assertFalse(future.done())
|
||||
c.notify()
|
||||
self.assertTrue(future.done())
|
||||
|
||||
|
||||
class EventTest(AsyncTestCase):
|
||||
def test_repr(self):
|
||||
event = locks.Event()
|
||||
self.assertTrue('clear' in str(event))
|
||||
self.assertFalse('set' in str(event))
|
||||
event.set()
|
||||
self.assertFalse('clear' in str(event))
|
||||
self.assertTrue('set' in str(event))
|
||||
|
||||
def test_event(self):
|
||||
e = locks.Event()
|
||||
future_0 = e.wait()
|
||||
e.set()
|
||||
future_1 = e.wait()
|
||||
e.clear()
|
||||
future_2 = e.wait()
|
||||
|
||||
self.assertTrue(future_0.done())
|
||||
self.assertTrue(future_1.done())
|
||||
self.assertFalse(future_2.done())
|
||||
|
||||
@gen_test
|
||||
def test_event_timeout(self):
|
||||
e = locks.Event()
|
||||
with self.assertRaises(TimeoutError):
|
||||
yield e.wait(timedelta(seconds=0.01))
|
||||
|
||||
# After a timed-out waiter, normal operation works.
|
||||
self.io_loop.add_timeout(timedelta(seconds=0.01), e.set)
|
||||
yield e.wait(timedelta(seconds=1))
|
||||
|
||||
def test_event_set_multiple(self):
|
||||
e = locks.Event()
|
||||
e.set()
|
||||
e.set()
|
||||
self.assertTrue(e.is_set())
|
||||
|
||||
def test_event_wait_clear(self):
|
||||
e = locks.Event()
|
||||
f0 = e.wait()
|
||||
e.clear()
|
||||
f1 = e.wait()
|
||||
e.set()
|
||||
self.assertTrue(f0.done())
|
||||
self.assertTrue(f1.done())
|
||||
|
||||
|
||||
class SemaphoreTest(AsyncTestCase):
|
||||
def test_negative_value(self):
|
||||
self.assertRaises(ValueError, locks.Semaphore, value=-1)
|
||||
|
||||
def test_repr(self):
|
||||
sem = locks.Semaphore()
|
||||
self.assertIn('Semaphore', repr(sem))
|
||||
self.assertIn('unlocked,value:1', repr(sem))
|
||||
sem.acquire()
|
||||
self.assertIn('locked', repr(sem))
|
||||
self.assertNotIn('waiters', repr(sem))
|
||||
sem.acquire()
|
||||
self.assertIn('waiters', repr(sem))
|
||||
|
||||
def test_acquire(self):
|
||||
sem = locks.Semaphore()
|
||||
f0 = sem.acquire()
|
||||
self.assertTrue(f0.done())
|
||||
|
||||
# Wait for release().
|
||||
f1 = sem.acquire()
|
||||
self.assertFalse(f1.done())
|
||||
f2 = sem.acquire()
|
||||
sem.release()
|
||||
self.assertTrue(f1.done())
|
||||
self.assertFalse(f2.done())
|
||||
sem.release()
|
||||
self.assertTrue(f2.done())
|
||||
|
||||
sem.release()
|
||||
# Now acquire() is instant.
|
||||
self.assertTrue(sem.acquire().done())
|
||||
self.assertEqual(0, len(sem._waiters))
|
||||
|
||||
@gen_test
|
||||
def test_acquire_timeout(self):
|
||||
sem = locks.Semaphore(2)
|
||||
yield sem.acquire()
|
||||
yield sem.acquire()
|
||||
acquire = sem.acquire(timedelta(seconds=0.01))
|
||||
self.io_loop.call_later(0.02, sem.release) # Too late.
|
||||
yield gen.sleep(0.3)
|
||||
with self.assertRaises(gen.TimeoutError):
|
||||
yield acquire
|
||||
|
||||
sem.acquire()
|
||||
f = sem.acquire()
|
||||
self.assertFalse(f.done())
|
||||
sem.release()
|
||||
self.assertTrue(f.done())
|
||||
|
||||
@gen_test
|
||||
def test_acquire_timeout_preempted(self):
|
||||
sem = locks.Semaphore(1)
|
||||
yield sem.acquire()
|
||||
|
||||
# This fires before the wait times out.
|
||||
self.io_loop.call_later(0.01, sem.release)
|
||||
acquire = sem.acquire(timedelta(seconds=0.02))
|
||||
yield gen.sleep(0.03)
|
||||
yield acquire # No TimeoutError.
|
||||
|
||||
def test_release_unacquired(self):
|
||||
# Unbounded releases are allowed, and increment the semaphore's value.
|
||||
sem = locks.Semaphore()
|
||||
sem.release()
|
||||
sem.release()
|
||||
|
||||
# Now the counter is 3. We can acquire three times before blocking.
|
||||
self.assertTrue(sem.acquire().done())
|
||||
self.assertTrue(sem.acquire().done())
|
||||
self.assertTrue(sem.acquire().done())
|
||||
self.assertFalse(sem.acquire().done())
|
||||
|
||||
@gen_test
|
||||
def test_garbage_collection(self):
|
||||
# Test that timed-out waiters are occasionally cleaned from the queue.
|
||||
sem = locks.Semaphore(value=0)
|
||||
futures = [sem.acquire(timedelta(seconds=0.01)) for _ in range(101)]
|
||||
|
||||
future = sem.acquire()
|
||||
self.assertEqual(102, len(sem._waiters))
|
||||
|
||||
# Let first 101 waiters time out, triggering a collection.
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(1, len(sem._waiters))
|
||||
|
||||
# Final waiter is still active.
|
||||
self.assertFalse(future.done())
|
||||
sem.release()
|
||||
self.assertTrue(future.done())
|
||||
|
||||
# Prevent "Future exception was never retrieved" messages.
|
||||
for future in futures:
|
||||
self.assertRaises(TimeoutError, future.result)
|
||||
|
||||
|
||||
class SemaphoreContextManagerTest(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_context_manager(self):
|
||||
sem = locks.Semaphore()
|
||||
with (yield sem.acquire()) as yielded:
|
||||
self.assertTrue(yielded is None)
|
||||
|
||||
# Semaphore was released and can be acquired again.
|
||||
self.assertTrue(sem.acquire().done())
|
||||
|
||||
@gen_test
|
||||
def test_context_manager_exception(self):
|
||||
sem = locks.Semaphore()
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
with (yield sem.acquire()):
|
||||
1 / 0
|
||||
|
||||
# Semaphore was released and can be acquired again.
|
||||
self.assertTrue(sem.acquire().done())
|
||||
|
||||
@gen_test
|
||||
def test_context_manager_timeout(self):
|
||||
sem = locks.Semaphore()
|
||||
with (yield sem.acquire(timedelta(seconds=0.01))):
|
||||
pass
|
||||
|
||||
# Semaphore was released and can be acquired again.
|
||||
self.assertTrue(sem.acquire().done())
|
||||
|
||||
@gen_test
|
||||
def test_context_manager_timeout_error(self):
|
||||
sem = locks.Semaphore(value=0)
|
||||
with self.assertRaises(gen.TimeoutError):
|
||||
with (yield sem.acquire(timedelta(seconds=0.01))):
|
||||
pass
|
||||
|
||||
# Counter is still 0.
|
||||
self.assertFalse(sem.acquire().done())
|
||||
|
||||
@gen_test
|
||||
def test_context_manager_contended(self):
|
||||
sem = locks.Semaphore()
|
||||
history = []
|
||||
|
||||
@gen.coroutine
|
||||
def f(index):
|
||||
with (yield sem.acquire()):
|
||||
history.append('acquired %d' % index)
|
||||
yield gen.sleep(0.01)
|
||||
history.append('release %d' % index)
|
||||
|
||||
yield [f(i) for i in range(2)]
|
||||
|
||||
expected_history = []
|
||||
for i in range(2):
|
||||
expected_history.extend(['acquired %d' % i, 'release %d' % i])
|
||||
|
||||
self.assertEqual(expected_history, history)
|
||||
|
||||
@gen_test
|
||||
def test_yield_sem(self):
|
||||
# Ensure we catch a "with (yield sem)", which should be
|
||||
# "with (yield sem.acquire())".
|
||||
with self.assertRaises(gen.BadYieldError):
|
||||
with (yield locks.Semaphore()):
|
||||
pass
|
||||
|
||||
def test_context_manager_misuse(self):
|
||||
# Ensure we catch a "with sem", which should be
|
||||
# "with (yield sem.acquire())".
|
||||
with self.assertRaises(RuntimeError):
|
||||
with locks.Semaphore():
|
||||
pass
|
||||
|
||||
|
||||
class BoundedSemaphoreTest(AsyncTestCase):
|
||||
def test_release_unacquired(self):
|
||||
sem = locks.BoundedSemaphore()
|
||||
self.assertRaises(ValueError, sem.release)
|
||||
# Value is 0.
|
||||
sem.acquire()
|
||||
# Block on acquire().
|
||||
future = sem.acquire()
|
||||
self.assertFalse(future.done())
|
||||
sem.release()
|
||||
self.assertTrue(future.done())
|
||||
# Value is 1.
|
||||
sem.release()
|
||||
self.assertRaises(ValueError, sem.release)
|
||||
|
||||
|
||||
class LockTests(AsyncTestCase):
|
||||
def test_repr(self):
|
||||
lock = locks.Lock()
|
||||
# No errors.
|
||||
repr(lock)
|
||||
lock.acquire()
|
||||
repr(lock)
|
||||
|
||||
def test_acquire_release(self):
|
||||
lock = locks.Lock()
|
||||
self.assertTrue(lock.acquire().done())
|
||||
future = lock.acquire()
|
||||
self.assertFalse(future.done())
|
||||
lock.release()
|
||||
self.assertTrue(future.done())
|
||||
|
||||
@gen_test
|
||||
def test_acquire_fifo(self):
|
||||
lock = locks.Lock()
|
||||
self.assertTrue(lock.acquire().done())
|
||||
N = 5
|
||||
history = []
|
||||
|
||||
@gen.coroutine
|
||||
def f(idx):
|
||||
with (yield lock.acquire()):
|
||||
history.append(idx)
|
||||
|
||||
futures = [f(i) for i in range(N)]
|
||||
self.assertFalse(any(future.done() for future in futures))
|
||||
lock.release()
|
||||
yield futures
|
||||
self.assertEqual(list(range(N)), history)
|
||||
|
||||
@gen_test
|
||||
def test_acquire_timeout(self):
|
||||
lock = locks.Lock()
|
||||
lock.acquire()
|
||||
with self.assertRaises(gen.TimeoutError):
|
||||
yield lock.acquire(timeout=timedelta(seconds=0.01))
|
||||
|
||||
# Still locked.
|
||||
self.assertFalse(lock.acquire().done())
|
||||
|
||||
def test_multi_release(self):
|
||||
lock = locks.Lock()
|
||||
self.assertRaises(RuntimeError, lock.release)
|
||||
lock.acquire()
|
||||
lock.release()
|
||||
self.assertRaises(RuntimeError, lock.release)
|
||||
|
||||
@gen_test
|
||||
def test_yield_lock(self):
|
||||
# Ensure we catch a "with (yield lock)", which should be
|
||||
# "with (yield lock.acquire())".
|
||||
with self.assertRaises(gen.BadYieldError):
|
||||
with (yield locks.Lock()):
|
||||
pass
|
||||
|
||||
def test_context_manager_misuse(self):
|
||||
# Ensure we catch a "with lock", which should be
|
||||
# "with (yield lock.acquire())".
|
||||
with self.assertRaises(RuntimeError):
|
||||
with locks.Lock():
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -67,10 +67,12 @@ class _ResolverErrorTestMixin(object):
|
|||
yield self.resolver.resolve('an invalid domain', 80,
|
||||
socket.AF_UNSPEC)
|
||||
|
||||
|
||||
def _failing_getaddrinfo(*args):
|
||||
"""Dummy implementation of getaddrinfo for use in mocks"""
|
||||
raise socket.gaierror("mock: lookup failed")
|
||||
|
||||
|
||||
@skipIfNoNetwork
|
||||
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
|
||||
def setUp(self):
|
||||
|
|
|
@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop
|
|||
from tornado.log import gen_log
|
||||
from tornado.process import fork_processes, task_id, Subprocess
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient
|
||||
from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase
|
||||
from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase, gen_test
|
||||
from tornado.test.util import unittest, skipIfNonUnix
|
||||
from tornado.web import RequestHandler, Application
|
||||
|
||||
|
@ -85,7 +85,7 @@ class ProcessTest(unittest.TestCase):
|
|||
self.assertEqual(id, task_id())
|
||||
server = HTTPServer(self.get_app())
|
||||
server.add_sockets([sock])
|
||||
IOLoop.instance().start()
|
||||
IOLoop.current().start()
|
||||
elif id == 2:
|
||||
self.assertEqual(id, task_id())
|
||||
sock.close()
|
||||
|
@ -200,6 +200,16 @@ class SubprocessTest(AsyncTestCase):
|
|||
self.assertEqual(ret, 0)
|
||||
self.assertEqual(subproc.returncode, ret)
|
||||
|
||||
@gen_test
|
||||
def test_sigchild_future(self):
|
||||
skip_if_twisted()
|
||||
Subprocess.initialize()
|
||||
self.addCleanup(Subprocess.uninitialize)
|
||||
subproc = Subprocess([sys.executable, '-c', 'pass'])
|
||||
ret = yield subproc.wait_for_exit()
|
||||
self.assertEqual(ret, 0)
|
||||
self.assertEqual(subproc.returncode, ret)
|
||||
|
||||
def test_sigchild_signal(self):
|
||||
skip_if_twisted()
|
||||
Subprocess.initialize(io_loop=self.io_loop)
|
||||
|
@ -212,3 +222,22 @@ class SubprocessTest(AsyncTestCase):
|
|||
ret = self.wait()
|
||||
self.assertEqual(subproc.returncode, ret)
|
||||
self.assertEqual(ret, -signal.SIGTERM)
|
||||
|
||||
@gen_test
|
||||
def test_wait_for_exit_raise(self):
|
||||
skip_if_twisted()
|
||||
Subprocess.initialize()
|
||||
self.addCleanup(Subprocess.uninitialize)
|
||||
subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)'])
|
||||
with self.assertRaises(subprocess.CalledProcessError) as cm:
|
||||
yield subproc.wait_for_exit()
|
||||
self.assertEqual(cm.exception.returncode, 1)
|
||||
|
||||
@gen_test
|
||||
def test_wait_for_exit_raise_disabled(self):
|
||||
skip_if_twisted()
|
||||
Subprocess.initialize()
|
||||
self.addCleanup(Subprocess.uninitialize)
|
||||
subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)'])
|
||||
ret = yield subproc.wait_for_exit(raise_error=False)
|
||||
self.assertEqual(ret, 1)
|
||||
|
|
403
tornado/test/queues_test.py
Normal file
403
tornado/test/queues_test.py
Normal file
|
@ -0,0 +1,403 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from datetime import timedelta
|
||||
from random import random
|
||||
|
||||
from tornado import gen, queues
|
||||
from tornado.gen import TimeoutError
|
||||
from tornado.testing import gen_test, AsyncTestCase
|
||||
from tornado.test.util import unittest
|
||||
|
||||
|
||||
class QueueBasicTest(AsyncTestCase):
|
||||
def test_repr_and_str(self):
|
||||
q = queues.Queue(maxsize=1)
|
||||
self.assertIn(hex(id(q)), repr(q))
|
||||
self.assertNotIn(hex(id(q)), str(q))
|
||||
q.get()
|
||||
|
||||
for q_str in repr(q), str(q):
|
||||
self.assertTrue(q_str.startswith('<Queue'))
|
||||
self.assertIn('maxsize=1', q_str)
|
||||
self.assertIn('getters[1]', q_str)
|
||||
self.assertNotIn('putters', q_str)
|
||||
self.assertNotIn('tasks', q_str)
|
||||
|
||||
q.put(None)
|
||||
q.put(None)
|
||||
# Now the queue is full, this putter blocks.
|
||||
q.put(None)
|
||||
|
||||
for q_str in repr(q), str(q):
|
||||
self.assertNotIn('getters', q_str)
|
||||
self.assertIn('putters[1]', q_str)
|
||||
self.assertIn('tasks=2', q_str)
|
||||
|
||||
def test_order(self):
|
||||
q = queues.Queue()
|
||||
for i in [1, 3, 2]:
|
||||
q.put_nowait(i)
|
||||
|
||||
items = [q.get_nowait() for _ in range(3)]
|
||||
self.assertEqual([1, 3, 2], items)
|
||||
|
||||
@gen_test
|
||||
def test_maxsize(self):
|
||||
self.assertRaises(TypeError, queues.Queue, maxsize=None)
|
||||
self.assertRaises(ValueError, queues.Queue, maxsize=-1)
|
||||
|
||||
q = queues.Queue(maxsize=2)
|
||||
self.assertTrue(q.empty())
|
||||
self.assertFalse(q.full())
|
||||
self.assertEqual(2, q.maxsize)
|
||||
self.assertTrue(q.put(0).done())
|
||||
self.assertTrue(q.put(1).done())
|
||||
self.assertFalse(q.empty())
|
||||
self.assertTrue(q.full())
|
||||
put2 = q.put(2)
|
||||
self.assertFalse(put2.done())
|
||||
self.assertEqual(0, (yield q.get())) # Make room.
|
||||
self.assertTrue(put2.done())
|
||||
self.assertFalse(q.empty())
|
||||
self.assertTrue(q.full())
|
||||
|
||||
|
||||
class QueueGetTest(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_blocking_get(self):
|
||||
q = queues.Queue()
|
||||
q.put_nowait(0)
|
||||
self.assertEqual(0, (yield q.get()))
|
||||
|
||||
def test_nonblocking_get(self):
|
||||
q = queues.Queue()
|
||||
q.put_nowait(0)
|
||||
self.assertEqual(0, q.get_nowait())
|
||||
|
||||
def test_nonblocking_get_exception(self):
|
||||
q = queues.Queue()
|
||||
self.assertRaises(queues.QueueEmpty, q.get_nowait)
|
||||
|
||||
@gen_test
|
||||
def test_get_with_putters(self):
|
||||
q = queues.Queue(1)
|
||||
q.put_nowait(0)
|
||||
put = q.put(1)
|
||||
self.assertEqual(0, (yield q.get()))
|
||||
self.assertIsNone((yield put))
|
||||
|
||||
@gen_test
|
||||
def test_blocking_get_wait(self):
|
||||
q = queues.Queue()
|
||||
q.put(0)
|
||||
self.io_loop.call_later(0.01, q.put, 1)
|
||||
self.io_loop.call_later(0.02, q.put, 2)
|
||||
self.assertEqual(0, (yield q.get(timeout=timedelta(seconds=1))))
|
||||
self.assertEqual(1, (yield q.get(timeout=timedelta(seconds=1))))
|
||||
|
||||
@gen_test
|
||||
def test_get_timeout(self):
|
||||
q = queues.Queue()
|
||||
get_timeout = q.get(timeout=timedelta(seconds=0.01))
|
||||
get = q.get()
|
||||
with self.assertRaises(TimeoutError):
|
||||
yield get_timeout
|
||||
|
||||
q.put_nowait(0)
|
||||
self.assertEqual(0, (yield get))
|
||||
|
||||
@gen_test
|
||||
def test_get_timeout_preempted(self):
|
||||
q = queues.Queue()
|
||||
get = q.get(timeout=timedelta(seconds=0.01))
|
||||
q.put(0)
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(0, (yield get))
|
||||
|
||||
@gen_test
|
||||
def test_get_clears_timed_out_putters(self):
|
||||
q = queues.Queue(1)
|
||||
# First putter succeeds, remainder block.
|
||||
putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)]
|
||||
put = q.put(10)
|
||||
self.assertEqual(10, len(q._putters))
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(10, len(q._putters))
|
||||
self.assertFalse(put.done()) # Final waiter is still active.
|
||||
q.put(11)
|
||||
self.assertEqual(0, (yield q.get())) # get() clears the waiters.
|
||||
self.assertEqual(1, len(q._putters))
|
||||
for putter in putters[1:]:
|
||||
self.assertRaises(TimeoutError, putter.result)
|
||||
|
||||
@gen_test
|
||||
def test_get_clears_timed_out_getters(self):
|
||||
q = queues.Queue()
|
||||
getters = [q.get(timedelta(seconds=0.01)) for _ in range(10)]
|
||||
get = q.get()
|
||||
self.assertEqual(11, len(q._getters))
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(11, len(q._getters))
|
||||
self.assertFalse(get.done()) # Final waiter is still active.
|
||||
q.get() # get() clears the waiters.
|
||||
self.assertEqual(2, len(q._getters))
|
||||
for getter in getters:
|
||||
self.assertRaises(TimeoutError, getter.result)
|
||||
|
||||
|
||||
class QueuePutTest(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_blocking_put(self):
|
||||
q = queues.Queue()
|
||||
q.put(0)
|
||||
self.assertEqual(0, q.get_nowait())
|
||||
|
||||
def test_nonblocking_put_exception(self):
|
||||
q = queues.Queue(1)
|
||||
q.put(0)
|
||||
self.assertRaises(queues.QueueFull, q.put_nowait, 1)
|
||||
|
||||
@gen_test
|
||||
def test_put_with_getters(self):
|
||||
q = queues.Queue()
|
||||
get0 = q.get()
|
||||
get1 = q.get()
|
||||
yield q.put(0)
|
||||
self.assertEqual(0, (yield get0))
|
||||
yield q.put(1)
|
||||
self.assertEqual(1, (yield get1))
|
||||
|
||||
@gen_test
|
||||
def test_nonblocking_put_with_getters(self):
|
||||
q = queues.Queue()
|
||||
get0 = q.get()
|
||||
get1 = q.get()
|
||||
q.put_nowait(0)
|
||||
# put_nowait does *not* immediately unblock getters.
|
||||
yield gen.moment
|
||||
self.assertEqual(0, (yield get0))
|
||||
q.put_nowait(1)
|
||||
yield gen.moment
|
||||
self.assertEqual(1, (yield get1))
|
||||
|
||||
@gen_test
|
||||
def test_blocking_put_wait(self):
|
||||
q = queues.Queue(1)
|
||||
q.put_nowait(0)
|
||||
self.io_loop.call_later(0.01, q.get)
|
||||
self.io_loop.call_later(0.02, q.get)
|
||||
futures = [q.put(0), q.put(1)]
|
||||
self.assertFalse(any(f.done() for f in futures))
|
||||
yield futures
|
||||
|
||||
@gen_test
|
||||
def test_put_timeout(self):
|
||||
q = queues.Queue(1)
|
||||
q.put_nowait(0) # Now it's full.
|
||||
put_timeout = q.put(1, timeout=timedelta(seconds=0.01))
|
||||
put = q.put(2)
|
||||
with self.assertRaises(TimeoutError):
|
||||
yield put_timeout
|
||||
|
||||
self.assertEqual(0, q.get_nowait())
|
||||
# 1 was never put in the queue.
|
||||
self.assertEqual(2, (yield q.get()))
|
||||
|
||||
# Final get() unblocked this putter.
|
||||
yield put
|
||||
|
||||
@gen_test
|
||||
def test_put_timeout_preempted(self):
|
||||
q = queues.Queue(1)
|
||||
q.put_nowait(0)
|
||||
put = q.put(1, timeout=timedelta(seconds=0.01))
|
||||
q.get()
|
||||
yield gen.sleep(0.02)
|
||||
yield put # No TimeoutError.
|
||||
|
||||
@gen_test
|
||||
def test_put_clears_timed_out_putters(self):
|
||||
q = queues.Queue(1)
|
||||
# First putter succeeds, remainder block.
|
||||
putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)]
|
||||
put = q.put(10)
|
||||
self.assertEqual(10, len(q._putters))
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(10, len(q._putters))
|
||||
self.assertFalse(put.done()) # Final waiter is still active.
|
||||
q.put(11) # put() clears the waiters.
|
||||
self.assertEqual(2, len(q._putters))
|
||||
for putter in putters[1:]:
|
||||
self.assertRaises(TimeoutError, putter.result)
|
||||
|
||||
@gen_test
|
||||
def test_put_clears_timed_out_getters(self):
|
||||
q = queues.Queue()
|
||||
getters = [q.get(timedelta(seconds=0.01)) for _ in range(10)]
|
||||
get = q.get()
|
||||
q.get()
|
||||
self.assertEqual(12, len(q._getters))
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(12, len(q._getters))
|
||||
self.assertFalse(get.done()) # Final waiters still active.
|
||||
q.put(0) # put() clears the waiters.
|
||||
self.assertEqual(1, len(q._getters))
|
||||
self.assertEqual(0, (yield get))
|
||||
for getter in getters:
|
||||
self.assertRaises(TimeoutError, getter.result)
|
||||
|
||||
@gen_test
|
||||
def test_float_maxsize(self):
|
||||
# Non-int maxsize must round down: http://bugs.python.org/issue21723
|
||||
q = queues.Queue(maxsize=1.3)
|
||||
self.assertTrue(q.empty())
|
||||
self.assertFalse(q.full())
|
||||
q.put_nowait(0)
|
||||
q.put_nowait(1)
|
||||
self.assertFalse(q.empty())
|
||||
self.assertTrue(q.full())
|
||||
self.assertRaises(queues.QueueFull, q.put_nowait, 2)
|
||||
self.assertEqual(0, q.get_nowait())
|
||||
self.assertFalse(q.empty())
|
||||
self.assertFalse(q.full())
|
||||
|
||||
yield q.put(2)
|
||||
put = q.put(3)
|
||||
self.assertFalse(put.done())
|
||||
self.assertEqual(1, (yield q.get()))
|
||||
yield put
|
||||
self.assertTrue(q.full())
|
||||
|
||||
|
||||
class QueueJoinTest(AsyncTestCase):
|
||||
queue_class = queues.Queue
|
||||
|
||||
def test_task_done_underflow(self):
|
||||
q = self.queue_class()
|
||||
self.assertRaises(ValueError, q.task_done)
|
||||
|
||||
@gen_test
|
||||
def test_task_done(self):
|
||||
q = self.queue_class()
|
||||
for i in range(100):
|
||||
q.put_nowait(i)
|
||||
|
||||
self.accumulator = 0
|
||||
|
||||
@gen.coroutine
|
||||
def worker():
|
||||
while True:
|
||||
item = yield q.get()
|
||||
self.accumulator += item
|
||||
q.task_done()
|
||||
yield gen.sleep(random() * 0.01)
|
||||
|
||||
# Two coroutines share work.
|
||||
worker()
|
||||
worker()
|
||||
yield q.join()
|
||||
self.assertEqual(sum(range(100)), self.accumulator)
|
||||
|
||||
@gen_test
|
||||
def test_task_done_delay(self):
|
||||
# Verify it is task_done(), not get(), that unblocks join().
|
||||
q = self.queue_class()
|
||||
q.put_nowait(0)
|
||||
join = q.join()
|
||||
self.assertFalse(join.done())
|
||||
yield q.get()
|
||||
self.assertFalse(join.done())
|
||||
yield gen.moment
|
||||
self.assertFalse(join.done())
|
||||
q.task_done()
|
||||
self.assertTrue(join.done())
|
||||
|
||||
@gen_test
|
||||
def test_join_empty_queue(self):
|
||||
q = self.queue_class()
|
||||
yield q.join()
|
||||
yield q.join()
|
||||
|
||||
@gen_test
|
||||
def test_join_timeout(self):
|
||||
q = self.queue_class()
|
||||
q.put(0)
|
||||
with self.assertRaises(TimeoutError):
|
||||
yield q.join(timeout=timedelta(seconds=0.01))
|
||||
|
||||
|
||||
class PriorityQueueJoinTest(QueueJoinTest):
|
||||
queue_class = queues.PriorityQueue
|
||||
|
||||
@gen_test
|
||||
def test_order(self):
|
||||
q = self.queue_class(maxsize=2)
|
||||
q.put_nowait((1, 'a'))
|
||||
q.put_nowait((0, 'b'))
|
||||
self.assertTrue(q.full())
|
||||
q.put((3, 'c'))
|
||||
q.put((2, 'd'))
|
||||
self.assertEqual((0, 'b'), q.get_nowait())
|
||||
self.assertEqual((1, 'a'), (yield q.get()))
|
||||
self.assertEqual((2, 'd'), q.get_nowait())
|
||||
self.assertEqual((3, 'c'), (yield q.get()))
|
||||
self.assertTrue(q.empty())
|
||||
|
||||
|
||||
class LifoQueueJoinTest(QueueJoinTest):
|
||||
queue_class = queues.LifoQueue
|
||||
|
||||
@gen_test
|
||||
def test_order(self):
|
||||
q = self.queue_class(maxsize=2)
|
||||
q.put_nowait(1)
|
||||
q.put_nowait(0)
|
||||
self.assertTrue(q.full())
|
||||
q.put(3)
|
||||
q.put(2)
|
||||
self.assertEqual(3, q.get_nowait())
|
||||
self.assertEqual(2, (yield q.get()))
|
||||
self.assertEqual(0, q.get_nowait())
|
||||
self.assertEqual(1, (yield q.get()))
|
||||
self.assertTrue(q.empty())
|
||||
|
||||
|
||||
class ProducerConsumerTest(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_producer_consumer(self):
|
||||
q = queues.Queue(maxsize=3)
|
||||
history = []
|
||||
|
||||
# We don't yield between get() and task_done(), so get() must wait for
|
||||
# the next tick. Otherwise we'd immediately call task_done and unblock
|
||||
# join() before q.put() resumes, and we'd only process the first four
|
||||
# items.
|
||||
@gen.coroutine
|
||||
def consumer():
|
||||
while True:
|
||||
history.append((yield q.get()))
|
||||
q.task_done()
|
||||
|
||||
@gen.coroutine
|
||||
def producer():
|
||||
for item in range(10):
|
||||
yield q.put(item)
|
||||
|
||||
consumer()
|
||||
yield producer()
|
||||
yield q.join()
|
||||
self.assertEqual(list(range(10)), history)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -8,6 +8,7 @@ import operator
|
|||
import textwrap
|
||||
import sys
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.netutil import Resolver
|
||||
from tornado.options import define, options, add_parse_callback
|
||||
|
@ -22,6 +23,7 @@ TEST_MODULES = [
|
|||
'tornado.httputil.doctests',
|
||||
'tornado.iostream.doctests',
|
||||
'tornado.util.doctests',
|
||||
'tornado.test.asyncio_test',
|
||||
'tornado.test.auth_test',
|
||||
'tornado.test.concurrent_test',
|
||||
'tornado.test.curl_httpclient_test',
|
||||
|
@ -34,13 +36,16 @@ TEST_MODULES = [
|
|||
'tornado.test.ioloop_test',
|
||||
'tornado.test.iostream_test',
|
||||
'tornado.test.locale_test',
|
||||
'tornado.test.locks_test',
|
||||
'tornado.test.netutil_test',
|
||||
'tornado.test.log_test',
|
||||
'tornado.test.options_test',
|
||||
'tornado.test.process_test',
|
||||
'tornado.test.queues_test',
|
||||
'tornado.test.simple_httpclient_test',
|
||||
'tornado.test.stack_context_test',
|
||||
'tornado.test.tcpclient_test',
|
||||
'tornado.test.tcpserver_test',
|
||||
'tornado.test.template_test',
|
||||
'tornado.test.testing_test',
|
||||
'tornado.test.twisted_test',
|
||||
|
@ -67,6 +72,21 @@ class TornadoTextTestRunner(unittest.TextTestRunner):
|
|||
return result
|
||||
|
||||
|
||||
class LogCounter(logging.Filter):
|
||||
"""Counts the number of WARNING or higher log records."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Can't use super() because logging.Filter is an old-style class in py26
|
||||
logging.Filter.__init__(self, *args, **kwargs)
|
||||
self.warning_count = self.error_count = 0
|
||||
|
||||
def filter(self, record):
|
||||
if record.levelno >= logging.ERROR:
|
||||
self.error_count += 1
|
||||
elif record.levelno >= logging.WARNING:
|
||||
self.warning_count += 1
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
# The -W command-line option does not work in a virtualenv with
|
||||
# python 3 (as of virtualenv 1.7), so configure warnings
|
||||
|
@ -92,12 +112,21 @@ def main():
|
|||
# 2.7 and 3.2
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning,
|
||||
message="Please use assert.* instead")
|
||||
# unittest2 0.6 on py26 reports these as PendingDeprecationWarnings
|
||||
# instead of DeprecationWarnings.
|
||||
warnings.filterwarnings("ignore", category=PendingDeprecationWarning,
|
||||
message="Please use assert.* instead")
|
||||
# Twisted 15.0.0 triggers some warnings on py3 with -bb.
|
||||
warnings.filterwarnings("ignore", category=BytesWarning,
|
||||
module=r"twisted\..*")
|
||||
|
||||
logging.getLogger("tornado.access").setLevel(logging.CRITICAL)
|
||||
|
||||
define('httpclient', type=str, default=None,
|
||||
callback=lambda s: AsyncHTTPClient.configure(
|
||||
s, defaults=dict(allow_ipv6=False)))
|
||||
define('httpserver', type=str, default=None,
|
||||
callback=HTTPServer.configure)
|
||||
define('ioloop', type=str, default=None)
|
||||
define('ioloop_time_monotonic', default=False)
|
||||
define('resolver', type=str, default=None,
|
||||
|
@ -121,6 +150,10 @@ def main():
|
|||
IOLoop.configure(options.ioloop, **kwargs)
|
||||
add_parse_callback(configure_ioloop)
|
||||
|
||||
log_counter = LogCounter()
|
||||
add_parse_callback(
|
||||
lambda: logging.getLogger().handlers[0].addFilter(log_counter))
|
||||
|
||||
import tornado.testing
|
||||
kwargs = {}
|
||||
if sys.version_info >= (3, 2):
|
||||
|
@ -131,7 +164,16 @@ def main():
|
|||
# detail. http://bugs.python.org/issue15626
|
||||
kwargs['warnings'] = False
|
||||
kwargs['testRunner'] = TornadoTextTestRunner
|
||||
try:
|
||||
tornado.testing.main(**kwargs)
|
||||
finally:
|
||||
# The tests should run clean; consider it a failure if they logged
|
||||
# any warnings or errors. We'd like to ban info logs too, but
|
||||
# we can't count them cleanly due to interactions with LogTrapTestCase.
|
||||
if log_counter.warning_count > 0 or log_counter.error_count > 0:
|
||||
logging.error("logged %d warnings and %d errors",
|
||||
log_counter.warning_count, log_counter.error_count)
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -8,19 +8,20 @@ import logging
|
|||
import os
|
||||
import re
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
|
||||
from tornado import gen
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.httputil import HTTPHeaders
|
||||
from tornado.httputil import HTTPHeaders, ResponseStartLine
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.log import gen_log
|
||||
from tornado.netutil import Resolver, bind_sockets
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs
|
||||
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler
|
||||
from tornado.test import httpclient_test
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
|
||||
from tornado.test.util import skipOnTravis, skipIfNoIPv6
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
|
||||
from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port, unittest
|
||||
from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body
|
||||
|
||||
|
||||
|
@ -97,15 +98,18 @@ class HostEchoHandler(RequestHandler):
|
|||
|
||||
|
||||
class NoContentLengthHandler(RequestHandler):
|
||||
@gen.coroutine
|
||||
@asynchronous
|
||||
def get(self):
|
||||
if self.request.version.startswith('HTTP/1'):
|
||||
# Emulate the old HTTP/1.0 behavior of returning a body with no
|
||||
# content-length. Tornado handles content-length at the framework
|
||||
# level so we have to go around it.
|
||||
stream = self.request.connection.stream
|
||||
yield stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
|
||||
stream = self.request.connection.detach()
|
||||
stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
|
||||
b"hello")
|
||||
stream.close()
|
||||
else:
|
||||
self.finish('HTTP/1 required')
|
||||
|
||||
|
||||
class EchoPostHandler(RequestHandler):
|
||||
|
@ -191,9 +195,6 @@ class SimpleHTTPClientTestMixin(object):
|
|||
response = self.wait()
|
||||
response.rethrow()
|
||||
|
||||
def test_default_certificates_exist(self):
|
||||
open(_default_ca_certs()).close()
|
||||
|
||||
def test_gzip(self):
|
||||
# All the tests in this file should be using gzip, but this test
|
||||
# ensures that it is in fact getting compressed.
|
||||
|
@ -235,9 +236,16 @@ class SimpleHTTPClientTestMixin(object):
|
|||
|
||||
@skipOnTravis
|
||||
def test_request_timeout(self):
|
||||
response = self.fetch('/trigger?wake=false', request_timeout=0.1)
|
||||
timeout = 0.1
|
||||
timeout_min, timeout_max = 0.099, 0.15
|
||||
if os.name == 'nt':
|
||||
timeout = 0.5
|
||||
timeout_min, timeout_max = 0.4, 0.6
|
||||
|
||||
response = self.fetch('/trigger?wake=false', request_timeout=timeout)
|
||||
self.assertEqual(response.code, 599)
|
||||
self.assertTrue(0.099 < response.request_time < 0.15, response.request_time)
|
||||
self.assertTrue(timeout_min < response.request_time < timeout_max,
|
||||
response.request_time)
|
||||
self.assertEqual(str(response.error), "HTTP 599: Timeout")
|
||||
# trigger the hanging request to let it clean up after itself
|
||||
self.triggers.popleft()()
|
||||
|
@ -315,10 +323,10 @@ class SimpleHTTPClientTestMixin(object):
|
|||
self.assertTrue(host_re.match(response.body), response.body)
|
||||
|
||||
def test_connection_refused(self):
|
||||
server_socket, port = bind_unused_port()
|
||||
server_socket.close()
|
||||
cleanup_func, port = refusing_port()
|
||||
self.addCleanup(cleanup_func)
|
||||
with ExpectLog(gen_log, ".*", required=False):
|
||||
self.http_client.fetch("http://localhost:%d/" % port, self.stop)
|
||||
self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
|
||||
response = self.wait()
|
||||
self.assertEqual(599, response.code)
|
||||
|
||||
|
@ -352,6 +360,9 @@ class SimpleHTTPClientTestMixin(object):
|
|||
|
||||
def test_no_content_length(self):
|
||||
response = self.fetch("/no_content_length")
|
||||
if response.body == b"HTTP/1 required":
|
||||
self.skipTest("requires HTTP/1.x")
|
||||
else:
|
||||
self.assertEquals(b"hello", response.body)
|
||||
|
||||
def sync_body_producer(self, write):
|
||||
|
@ -425,6 +436,33 @@ class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
|
|||
defaults=dict(validate_cert=False),
|
||||
**kwargs)
|
||||
|
||||
def test_ssl_options(self):
|
||||
resp = self.fetch("/hello", ssl_options={})
|
||||
self.assertEqual(resp.body, b"Hello world!")
|
||||
|
||||
@unittest.skipIf(not hasattr(ssl, 'SSLContext'),
|
||||
'ssl.SSLContext not present')
|
||||
def test_ssl_context(self):
|
||||
resp = self.fetch("/hello",
|
||||
ssl_options=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
|
||||
self.assertEqual(resp.body, b"Hello world!")
|
||||
|
||||
def test_ssl_options_handshake_fail(self):
|
||||
with ExpectLog(gen_log, "SSL Error|Uncaught exception",
|
||||
required=False):
|
||||
resp = self.fetch(
|
||||
"/hello", ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED))
|
||||
self.assertRaises(ssl.SSLError, resp.rethrow)
|
||||
|
||||
@unittest.skipIf(not hasattr(ssl, 'SSLContext'),
|
||||
'ssl.SSLContext not present')
|
||||
def test_ssl_context_handshake_fail(self):
|
||||
with ExpectLog(gen_log, "SSL Error|Uncaught exception"):
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
resp = self.fetch("/hello", ssl_options=ctx)
|
||||
self.assertRaises(ssl.SSLError, resp.rethrow)
|
||||
|
||||
|
||||
class CreateAsyncHTTPClientTestCase(AsyncTestCase):
|
||||
def setUp(self):
|
||||
|
@ -460,6 +498,12 @@ class CreateAsyncHTTPClientTestCase(AsyncTestCase):
|
|||
|
||||
class HTTP100ContinueTestCase(AsyncHTTPTestCase):
|
||||
def respond_100(self, request):
|
||||
self.http1 = request.version.startswith('HTTP/1.')
|
||||
if not self.http1:
|
||||
request.connection.write_headers(ResponseStartLine('', 200, 'OK'),
|
||||
HTTPHeaders())
|
||||
request.connection.finish()
|
||||
return
|
||||
self.request = request
|
||||
self.request.connection.stream.write(
|
||||
b"HTTP/1.1 100 CONTINUE\r\n\r\n",
|
||||
|
@ -476,11 +520,20 @@ class HTTP100ContinueTestCase(AsyncHTTPTestCase):
|
|||
|
||||
def test_100_continue(self):
|
||||
res = self.fetch('/')
|
||||
if not self.http1:
|
||||
self.skipTest("requires HTTP/1.x")
|
||||
self.assertEqual(res.body, b'A')
|
||||
|
||||
|
||||
class HTTP204NoContentTestCase(AsyncHTTPTestCase):
|
||||
def respond_204(self, request):
|
||||
self.http1 = request.version.startswith('HTTP/1.')
|
||||
if not self.http1:
|
||||
# Close the request cleanly in HTTP/2; it will be skipped anyway.
|
||||
request.connection.write_headers(ResponseStartLine('', 200, 'OK'),
|
||||
HTTPHeaders())
|
||||
request.connection.finish()
|
||||
return
|
||||
# A 204 response never has a body, even if doesn't have a content-length
|
||||
# (which would otherwise mean read-until-close). Tornado always
|
||||
# sends a content-length, so we simulate here a server that sends
|
||||
|
@ -488,14 +541,18 @@ class HTTP204NoContentTestCase(AsyncHTTPTestCase):
|
|||
#
|
||||
# Tests of a 204 response with a Content-Length header are included
|
||||
# in SimpleHTTPClientTestMixin.
|
||||
request.connection.stream.write(
|
||||
stream = request.connection.detach()
|
||||
stream.write(
|
||||
b"HTTP/1.1 204 No content\r\n\r\n")
|
||||
stream.close()
|
||||
|
||||
def get_app(self):
|
||||
return self.respond_204
|
||||
|
||||
def test_204_no_content(self):
|
||||
resp = self.fetch('/')
|
||||
if not self.http1:
|
||||
self.skipTest("requires HTTP/1.x")
|
||||
self.assertEqual(resp.code, 204)
|
||||
self.assertEqual(resp.body, b'')
|
||||
|
||||
|
@ -574,3 +631,49 @@ class MaxHeaderSizeTest(AsyncHTTPTestCase):
|
|||
with ExpectLog(gen_log, "Unsatisfiable read"):
|
||||
response = self.fetch('/large')
|
||||
self.assertEqual(response.code, 599)
|
||||
|
||||
|
||||
class MaxBodySizeTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
class SmallBody(RequestHandler):
|
||||
def get(self):
|
||||
self.write("a"*1024*64)
|
||||
|
||||
class LargeBody(RequestHandler):
|
||||
def get(self):
|
||||
self.write("a"*1024*100)
|
||||
|
||||
return Application([('/small', SmallBody),
|
||||
('/large', LargeBody)])
|
||||
|
||||
def get_http_client(self):
|
||||
return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_body_size=1024*64)
|
||||
|
||||
def test_small_body(self):
|
||||
response = self.fetch('/small')
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b'a'*1024*64)
|
||||
|
||||
def test_large_body(self):
|
||||
with ExpectLog(gen_log, "Malformed HTTP message from None: Content-Length too long"):
|
||||
response = self.fetch('/large')
|
||||
self.assertEqual(response.code, 599)
|
||||
|
||||
|
||||
class MaxBufferSizeTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
|
||||
class LargeBody(RequestHandler):
|
||||
def get(self):
|
||||
self.write("a"*1024*100)
|
||||
|
||||
return Application([('/large', LargeBody)])
|
||||
|
||||
def get_http_client(self):
|
||||
# 100KB body with 64KB buffer
|
||||
return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_body_size=1024*100, max_buffer_size=1024*64)
|
||||
|
||||
def test_large_body(self):
|
||||
response = self.fetch('/large')
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b'a'*1024*100)
|
||||
|
|
|
@ -24,8 +24,8 @@ from tornado.concurrent import Future
|
|||
from tornado.netutil import bind_sockets, Resolver
|
||||
from tornado.tcpclient import TCPClient, _Connector
|
||||
from tornado.tcpserver import TCPServer
|
||||
from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
|
||||
from tornado.test.util import skipIfNoIPv6, unittest
|
||||
from tornado.testing import AsyncTestCase, gen_test
|
||||
from tornado.test.util import skipIfNoIPv6, unittest, refusing_port
|
||||
|
||||
# Fake address families for testing. Used in place of AF_INET
|
||||
# and AF_INET6 because some installations do not have AF_INET6.
|
||||
|
@ -120,8 +120,8 @@ class TCPClientTest(AsyncTestCase):
|
|||
|
||||
@gen_test
|
||||
def test_refused_ipv4(self):
|
||||
sock, port = bind_unused_port()
|
||||
sock.close()
|
||||
cleanup_func, port = refusing_port()
|
||||
self.addCleanup(cleanup_func)
|
||||
with self.assertRaises(IOError):
|
||||
yield self.client.connect('127.0.0.1', port)
|
||||
|
||||
|
|
38
tornado/test/tcpserver_test.py
Normal file
38
tornado/test/tcpserver_test.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
import socket
|
||||
|
||||
from tornado import gen
|
||||
from tornado.iostream import IOStream
|
||||
from tornado.log import app_log
|
||||
from tornado.stack_context import NullContext
|
||||
from tornado.tcpserver import TCPServer
|
||||
from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
|
||||
|
||||
|
||||
class TCPServerTest(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_handle_stream_coroutine_logging(self):
|
||||
# handle_stream may be a coroutine and any exception in its
|
||||
# Future will be logged.
|
||||
class TestServer(TCPServer):
|
||||
@gen.coroutine
|
||||
def handle_stream(self, stream, address):
|
||||
yield gen.moment
|
||||
stream.close()
|
||||
1/0
|
||||
|
||||
server = client = None
|
||||
try:
|
||||
sock, port = bind_unused_port()
|
||||
with NullContext():
|
||||
server = TestServer()
|
||||
server.add_socket(sock)
|
||||
client = IOStream(socket.socket())
|
||||
with ExpectLog(app_log, "Exception in callback"):
|
||||
yield client.connect(('localhost', port))
|
||||
yield client.read_until_close()
|
||||
yield gen.moment
|
||||
finally:
|
||||
if server is not None:
|
||||
server.stop()
|
||||
if client is not None:
|
||||
client.close()
|
|
@ -19,15 +19,18 @@ Unittest for the twisted-style reactor.
|
|||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
try:
|
||||
import fcntl
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
|
||||
from twisted.internet.interfaces import IReadDescriptor, IWriteDescriptor
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.python import log
|
||||
|
@ -40,10 +43,12 @@ except ImportError:
|
|||
# The core of Twisted 12.3.0 is available on python 3, but twisted.web is not
|
||||
# so test for it separately.
|
||||
try:
|
||||
from twisted.web.client import Agent
|
||||
from twisted.web.client import Agent, readBody
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import Site
|
||||
have_twisted_web = True
|
||||
# As of Twisted 15.0.0, twisted.web is present but fails our
|
||||
# tests due to internal str/bytes errors.
|
||||
have_twisted_web = sys.version_info < (3,)
|
||||
except ImportError:
|
||||
have_twisted_web = False
|
||||
|
||||
|
@ -52,6 +57,8 @@ try:
|
|||
except ImportError:
|
||||
import _thread as thread # py3
|
||||
|
||||
from tornado.escape import utf8
|
||||
from tornado import gen
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.ioloop import IOLoop
|
||||
|
@ -65,6 +72,9 @@ from tornado.web import RequestHandler, Application
|
|||
skipIfNoTwisted = unittest.skipUnless(have_twisted,
|
||||
"twisted module not present")
|
||||
|
||||
skipIfNoSingleDispatch = unittest.skipIf(
|
||||
gen.singledispatch is None, "singledispatch module not present")
|
||||
|
||||
|
||||
def save_signal_handlers():
|
||||
saved = {}
|
||||
|
@ -407,7 +417,7 @@ class CompatibilityTests(unittest.TestCase):
|
|||
# http://twistedmatrix.com/documents/current/web/howto/client.html
|
||||
chunks = []
|
||||
client = Agent(self.reactor)
|
||||
d = client.request('GET', url)
|
||||
d = client.request(b'GET', utf8(url))
|
||||
|
||||
class Accumulator(Protocol):
|
||||
def __init__(self, finished):
|
||||
|
@ -425,37 +435,98 @@ class CompatibilityTests(unittest.TestCase):
|
|||
return finished
|
||||
d.addCallback(callback)
|
||||
|
||||
def shutdown(ignored):
|
||||
def shutdown(failure):
|
||||
if hasattr(self, 'stop_loop'):
|
||||
self.stop_loop()
|
||||
elif failure is not None:
|
||||
# loop hasn't been initialized yet; try our best to
|
||||
# get an error message out. (the runner() interaction
|
||||
# should probably be refactored).
|
||||
try:
|
||||
failure.raiseException()
|
||||
except:
|
||||
logging.error('exception before starting loop', exc_info=True)
|
||||
d.addBoth(shutdown)
|
||||
runner()
|
||||
self.assertTrue(chunks)
|
||||
return ''.join(chunks)
|
||||
|
||||
def twisted_coroutine_fetch(self, url, runner):
|
||||
body = [None]
|
||||
|
||||
@gen.coroutine
|
||||
def f():
|
||||
# This is simpler than the non-coroutine version, but it cheats
|
||||
# by reading the body in one blob instead of streaming it with
|
||||
# a Protocol.
|
||||
client = Agent(self.reactor)
|
||||
response = yield client.request(b'GET', utf8(url))
|
||||
with warnings.catch_warnings():
|
||||
# readBody has a buggy DeprecationWarning in Twisted 15.0:
|
||||
# https://twistedmatrix.com/trac/changeset/43379
|
||||
warnings.simplefilter('ignore', category=DeprecationWarning)
|
||||
body[0] = yield readBody(response)
|
||||
self.stop_loop()
|
||||
self.io_loop.add_callback(f)
|
||||
runner()
|
||||
return body[0]
|
||||
|
||||
def testTwistedServerTornadoClientIOLoop(self):
|
||||
self.start_twisted_server()
|
||||
response = self.tornado_fetch(
|
||||
'http://localhost:%d' % self.twisted_port, self.run_ioloop)
|
||||
'http://127.0.0.1:%d' % self.twisted_port, self.run_ioloop)
|
||||
self.assertEqual(response.body, 'Hello from twisted!')
|
||||
|
||||
def testTwistedServerTornadoClientReactor(self):
|
||||
self.start_twisted_server()
|
||||
response = self.tornado_fetch(
|
||||
'http://localhost:%d' % self.twisted_port, self.run_reactor)
|
||||
'http://127.0.0.1:%d' % self.twisted_port, self.run_reactor)
|
||||
self.assertEqual(response.body, 'Hello from twisted!')
|
||||
|
||||
def testTornadoServerTwistedClientIOLoop(self):
|
||||
self.start_tornado_server()
|
||||
response = self.twisted_fetch(
|
||||
'http://localhost:%d' % self.tornado_port, self.run_ioloop)
|
||||
'http://127.0.0.1:%d' % self.tornado_port, self.run_ioloop)
|
||||
self.assertEqual(response, 'Hello from tornado!')
|
||||
|
||||
def testTornadoServerTwistedClientReactor(self):
|
||||
self.start_tornado_server()
|
||||
response = self.twisted_fetch(
|
||||
'http://localhost:%d' % self.tornado_port, self.run_reactor)
|
||||
'http://127.0.0.1:%d' % self.tornado_port, self.run_reactor)
|
||||
self.assertEqual(response, 'Hello from tornado!')
|
||||
|
||||
@skipIfNoSingleDispatch
|
||||
def testTornadoServerTwistedCoroutineClientIOLoop(self):
|
||||
self.start_tornado_server()
|
||||
response = self.twisted_coroutine_fetch(
|
||||
'http://127.0.0.1:%d' % self.tornado_port, self.run_ioloop)
|
||||
self.assertEqual(response, 'Hello from tornado!')
|
||||
|
||||
|
||||
@skipIfNoTwisted
|
||||
@skipIfNoSingleDispatch
|
||||
class ConvertDeferredTest(unittest.TestCase):
|
||||
def test_success(self):
|
||||
@inlineCallbacks
|
||||
def fn():
|
||||
if False:
|
||||
# inlineCallbacks doesn't work with regular functions;
|
||||
# must have a yield even if it's unreachable.
|
||||
yield
|
||||
returnValue(42)
|
||||
f = gen.convert_yielded(fn())
|
||||
self.assertEqual(f.result(), 42)
|
||||
|
||||
def test_failure(self):
|
||||
@inlineCallbacks
|
||||
def fn():
|
||||
if False:
|
||||
yield
|
||||
1 / 0
|
||||
f = gen.convert_yielded(fn())
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
f.result()
|
||||
|
||||
|
||||
if have_twisted:
|
||||
# Import and run as much of twisted's test suite as possible.
|
||||
|
@ -483,7 +554,7 @@ if have_twisted:
|
|||
'test_changeUID',
|
||||
],
|
||||
# Process tests appear to work on OSX 10.7, but not 10.6
|
||||
#'twisted.internet.test.test_process.PTYProcessTestsBuilder': [
|
||||
# 'twisted.internet.test.test_process.PTYProcessTestsBuilder': [
|
||||
# 'test_systemCallUninterruptedByChildExit',
|
||||
# ],
|
||||
'twisted.internet.test.test_tcp.TCPClientTestsBuilder': [
|
||||
|
@ -502,7 +573,7 @@ if have_twisted:
|
|||
'twisted.internet.test.test_threads.ThreadTestsBuilder': [],
|
||||
'twisted.internet.test.test_time.TimeTestsBuilder': [],
|
||||
# Extra third-party dependencies (pyOpenSSL)
|
||||
#'twisted.internet.test.test_tls.SSLClientTestsMixin': [],
|
||||
# 'twisted.internet.test.test_tls.SSLClientTestsMixin': [],
|
||||
'twisted.internet.test.test_udp.UDPServerTestsBuilder': [],
|
||||
'twisted.internet.test.test_unix.UNIXTestsBuilder': [
|
||||
# Platform-specific. These tests would be skipped automatically
|
||||
|
@ -588,13 +659,13 @@ if have_twisted:
|
|||
correctly. In some tests another TornadoReactor is layered on top
|
||||
of the whole stack.
|
||||
"""
|
||||
def initialize(self):
|
||||
def initialize(self, **kwargs):
|
||||
# When configured to use LayeredTwistedIOLoop we can't easily
|
||||
# get the next-best IOLoop implementation, so use the lowest common
|
||||
# denominator.
|
||||
self.real_io_loop = SelectIOLoop()
|
||||
reactor = TornadoReactor(io_loop=self.real_io_loop)
|
||||
super(LayeredTwistedIOLoop, self).initialize(reactor=reactor)
|
||||
super(LayeredTwistedIOLoop, self).initialize(reactor=reactor, **kwargs)
|
||||
self.add_callback(self.make_current)
|
||||
|
||||
def close(self, all_fds=False):
|
||||
|
|
|
@ -4,6 +4,8 @@ import os
|
|||
import socket
|
||||
import sys
|
||||
|
||||
from tornado.testing import bind_unused_port
|
||||
|
||||
# Encapsulate the choice of unittest or unittest2 here.
|
||||
# To be used as 'from tornado.test.util import unittest'.
|
||||
if sys.version_info < (2, 7):
|
||||
|
@ -28,3 +30,23 @@ skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ,
|
|||
'network access disabled')
|
||||
|
||||
skipIfNoIPv6 = unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present')
|
||||
|
||||
|
||||
def refusing_port():
|
||||
"""Returns a local port number that will refuse all connections.
|
||||
|
||||
Return value is (cleanup_func, port); the cleanup function
|
||||
must be called to free the port to be reused.
|
||||
"""
|
||||
# On travis-ci, port numbers are reassigned frequently. To avoid
|
||||
# collisions with other tests, we use an open client-side socket's
|
||||
# ephemeral port number to ensure that nothing can listen on that
|
||||
# port.
|
||||
server_socket, port = bind_unused_port()
|
||||
server_socket.setblocking(1)
|
||||
client_socket = socket.socket()
|
||||
client_socket.connect(("127.0.0.1", port))
|
||||
conn, client_addr = server_socket.accept()
|
||||
conn.close()
|
||||
server_socket.close()
|
||||
return (client_socket.close, client_addr[1])
|
||||
|
|
|
@ -3,8 +3,9 @@ from __future__ import absolute_import, division, print_function, with_statement
|
|||
import sys
|
||||
import datetime
|
||||
|
||||
import tornado.escape
|
||||
from tornado.escape import utf8
|
||||
from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer, timedelta_to_seconds
|
||||
from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer, timedelta_to_seconds, import_object
|
||||
from tornado.test.util import unittest
|
||||
|
||||
try:
|
||||
|
@ -45,13 +46,15 @@ class TestConfigurable(Configurable):
|
|||
|
||||
|
||||
class TestConfig1(TestConfigurable):
|
||||
def initialize(self, a=None):
|
||||
def initialize(self, pos_arg=None, a=None):
|
||||
self.a = a
|
||||
self.pos_arg = pos_arg
|
||||
|
||||
|
||||
class TestConfig2(TestConfigurable):
|
||||
def initialize(self, b=None):
|
||||
def initialize(self, pos_arg=None, b=None):
|
||||
self.b = b
|
||||
self.pos_arg = pos_arg
|
||||
|
||||
|
||||
class ConfigurableTest(unittest.TestCase):
|
||||
|
@ -101,9 +104,10 @@ class ConfigurableTest(unittest.TestCase):
|
|||
self.assertIsInstance(obj, TestConfig1)
|
||||
self.assertEqual(obj.a, 3)
|
||||
|
||||
obj = TestConfigurable(a=4)
|
||||
obj = TestConfigurable(42, a=4)
|
||||
self.assertIsInstance(obj, TestConfig1)
|
||||
self.assertEqual(obj.a, 4)
|
||||
self.assertEqual(obj.pos_arg, 42)
|
||||
|
||||
self.checkSubclasses()
|
||||
# args bound in configure don't apply when using the subclass directly
|
||||
|
@ -116,9 +120,10 @@ class ConfigurableTest(unittest.TestCase):
|
|||
self.assertIsInstance(obj, TestConfig2)
|
||||
self.assertEqual(obj.b, 5)
|
||||
|
||||
obj = TestConfigurable(b=6)
|
||||
obj = TestConfigurable(42, b=6)
|
||||
self.assertIsInstance(obj, TestConfig2)
|
||||
self.assertEqual(obj.b, 6)
|
||||
self.assertEqual(obj.pos_arg, 42)
|
||||
|
||||
self.checkSubclasses()
|
||||
# args bound in configure don't apply when using the subclass directly
|
||||
|
@ -177,3 +182,20 @@ class TimedeltaToSecondsTest(unittest.TestCase):
|
|||
def test_timedelta_to_seconds(self):
|
||||
time_delta = datetime.timedelta(hours=1)
|
||||
self.assertEqual(timedelta_to_seconds(time_delta), 3600.0)
|
||||
|
||||
|
||||
class ImportObjectTest(unittest.TestCase):
|
||||
def test_import_member(self):
|
||||
self.assertIs(import_object('tornado.escape.utf8'), utf8)
|
||||
|
||||
def test_import_member_unicode(self):
|
||||
self.assertIs(import_object(u('tornado.escape.utf8')), utf8)
|
||||
|
||||
def test_import_module(self):
|
||||
self.assertIs(import_object('tornado.escape'), tornado.escape)
|
||||
|
||||
def test_import_module_unicode(self):
|
||||
# The internal implementation of __import__ differs depending on
|
||||
# whether the thing being imported is a module or not.
|
||||
# This variant requires a byte string in python 2.
|
||||
self.assertIs(import_object(u('tornado.escape')), tornado.escape)
|
||||
|
|
|
@ -11,7 +11,7 @@ from tornado.template import DictLoader
|
|||
from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test
|
||||
from tornado.test.util import unittest
|
||||
from tornado.util import u, ObjectDict, unicode_type, timedelta_to_seconds
|
||||
from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler
|
||||
from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler, get_signature_key_version
|
||||
|
||||
import binascii
|
||||
import contextlib
|
||||
|
@ -71,10 +71,14 @@ class HelloHandler(RequestHandler):
|
|||
|
||||
class CookieTestRequestHandler(RequestHandler):
|
||||
# stub out enough methods to make the secure_cookie functions work
|
||||
def __init__(self):
|
||||
def __init__(self, cookie_secret='0123456789', key_version=None):
|
||||
# don't call super.__init__
|
||||
self._cookies = {}
|
||||
self.application = ObjectDict(settings=dict(cookie_secret='0123456789'))
|
||||
if key_version is None:
|
||||
self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret))
|
||||
else:
|
||||
self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret,
|
||||
key_version=key_version))
|
||||
|
||||
def get_cookie(self, name):
|
||||
return self._cookies.get(name)
|
||||
|
@ -128,6 +132,51 @@ class SecureCookieV1Test(unittest.TestCase):
|
|||
self.assertEqual(handler.get_secure_cookie('foo', min_version=1), b'\xe9')
|
||||
|
||||
|
||||
# See SignedValueTest below for more.
|
||||
class SecureCookieV2Test(unittest.TestCase):
|
||||
KEY_VERSIONS = {
|
||||
0: 'ajklasdf0ojaisdf',
|
||||
1: 'aslkjasaolwkjsdf'
|
||||
}
|
||||
|
||||
def test_round_trip(self):
|
||||
handler = CookieTestRequestHandler()
|
||||
handler.set_secure_cookie('foo', b'bar', version=2)
|
||||
self.assertEqual(handler.get_secure_cookie('foo', min_version=2), b'bar')
|
||||
|
||||
def test_key_version_roundtrip(self):
|
||||
handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
|
||||
key_version=0)
|
||||
handler.set_secure_cookie('foo', b'bar')
|
||||
self.assertEqual(handler.get_secure_cookie('foo'), b'bar')
|
||||
|
||||
def test_key_version_roundtrip_differing_version(self):
|
||||
handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
|
||||
key_version=1)
|
||||
handler.set_secure_cookie('foo', b'bar')
|
||||
self.assertEqual(handler.get_secure_cookie('foo'), b'bar')
|
||||
|
||||
def test_key_version_increment_version(self):
|
||||
handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
|
||||
key_version=0)
|
||||
handler.set_secure_cookie('foo', b'bar')
|
||||
new_handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
|
||||
key_version=1)
|
||||
new_handler._cookies = handler._cookies
|
||||
self.assertEqual(new_handler.get_secure_cookie('foo'), b'bar')
|
||||
|
||||
def test_key_version_invalidate_version(self):
|
||||
handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
|
||||
key_version=0)
|
||||
handler.set_secure_cookie('foo', b'bar')
|
||||
new_key_versions = self.KEY_VERSIONS.copy()
|
||||
new_key_versions.pop(0)
|
||||
new_handler = CookieTestRequestHandler(cookie_secret=new_key_versions,
|
||||
key_version=1)
|
||||
new_handler._cookies = handler._cookies
|
||||
self.assertEqual(new_handler.get_secure_cookie('foo'), None)
|
||||
|
||||
|
||||
class CookieTest(WebTestCase):
|
||||
def get_handlers(self):
|
||||
class SetCookieHandler(RequestHandler):
|
||||
|
@ -171,6 +220,13 @@ class CookieTest(WebTestCase):
|
|||
def get(self):
|
||||
self.set_cookie("foo", "bar", expires_days=10)
|
||||
|
||||
class SetCookieFalsyFlags(RequestHandler):
|
||||
def get(self):
|
||||
self.set_cookie("a", "1", secure=True)
|
||||
self.set_cookie("b", "1", secure=False)
|
||||
self.set_cookie("c", "1", httponly=True)
|
||||
self.set_cookie("d", "1", httponly=False)
|
||||
|
||||
return [("/set", SetCookieHandler),
|
||||
("/get", GetCookieHandler),
|
||||
("/set_domain", SetCookieDomainHandler),
|
||||
|
@ -178,6 +234,7 @@ class CookieTest(WebTestCase):
|
|||
("/set_overwrite", SetCookieOverwriteHandler),
|
||||
("/set_max_age", SetCookieMaxAgeHandler),
|
||||
("/set_expires_days", SetCookieExpiresDaysHandler),
|
||||
("/set_falsy_flags", SetCookieFalsyFlags)
|
||||
]
|
||||
|
||||
def test_set_cookie(self):
|
||||
|
@ -249,6 +306,16 @@ class CookieTest(WebTestCase):
|
|||
*email.utils.parsedate(match.groupdict()["expires"])[:6])
|
||||
self.assertTrue(abs(timedelta_to_seconds(expires - header_expires)) < 10)
|
||||
|
||||
def test_set_cookie_false_flags(self):
|
||||
response = self.fetch("/set_falsy_flags")
|
||||
headers = sorted(response.headers.get_list("Set-Cookie"))
|
||||
# The secure and httponly headers are capitalized in py35 and
|
||||
# lowercase in older versions.
|
||||
self.assertEqual(headers[0].lower(), 'a=1; path=/; secure')
|
||||
self.assertEqual(headers[1].lower(), 'b=1; path=/')
|
||||
self.assertEqual(headers[2].lower(), 'c=1; httponly; path=/')
|
||||
self.assertEqual(headers[3].lower(), 'd=1; path=/')
|
||||
|
||||
|
||||
class AuthRedirectRequestHandler(RequestHandler):
|
||||
def initialize(self, login_url):
|
||||
|
@ -305,7 +372,7 @@ class ConnectionCloseTest(WebTestCase):
|
|||
|
||||
def test_connection_close(self):
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
|
||||
s.connect(("localhost", self.get_http_port()))
|
||||
s.connect(("127.0.0.1", self.get_http_port()))
|
||||
self.stream = IOStream(s, io_loop=self.io_loop)
|
||||
self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
|
||||
self.wait()
|
||||
|
@ -379,6 +446,12 @@ class RequestEncodingTest(WebTestCase):
|
|||
path_args=["a/b", "c/d"],
|
||||
args={}))
|
||||
|
||||
def test_error(self):
|
||||
# Percent signs (encoded as %25) should not mess up printf-style
|
||||
# messages in logs
|
||||
with ExpectLog(gen_log, ".*Invalid unicode"):
|
||||
self.fetch("/group/?arg=%25%e9")
|
||||
|
||||
|
||||
class TypeCheckHandler(RequestHandler):
|
||||
def prepare(self):
|
||||
|
@ -579,6 +652,7 @@ class WSGISafeWebTest(WebTestCase):
|
|||
url("/redirect", RedirectHandler),
|
||||
url("/web_redirect_permanent", WebRedirectHandler, {"url": "/web_redirect_newpath"}),
|
||||
url("/web_redirect", WebRedirectHandler, {"url": "/web_redirect_newpath", "permanent": False}),
|
||||
url("//web_redirect_double_slash", WebRedirectHandler, {"url": '/web_redirect_newpath'}),
|
||||
url("/header_injection", HeaderInjectionHandler),
|
||||
url("/get_argument", GetArgumentHandler),
|
||||
url("/get_arguments", GetArgumentsHandler),
|
||||
|
@ -712,6 +786,11 @@ js_embed()
|
|||
self.assertEqual(response.code, 302)
|
||||
self.assertEqual(response.headers['Location'], '/web_redirect_newpath')
|
||||
|
||||
def test_web_redirect_double_slash(self):
|
||||
response = self.fetch("//web_redirect_double_slash", follow_redirects=False)
|
||||
self.assertEqual(response.code, 301)
|
||||
self.assertEqual(response.headers['Location'], '/web_redirect_newpath')
|
||||
|
||||
def test_header_injection(self):
|
||||
response = self.fetch("/header_injection")
|
||||
self.assertEqual(response.body, b"ok")
|
||||
|
@ -1517,6 +1596,22 @@ class ExceptionHandlerTest(SimpleHandlerTestCase):
|
|||
self.assertEqual(response.code, 403)
|
||||
|
||||
|
||||
@wsgi_safe
|
||||
class BuggyLoggingTest(SimpleHandlerTestCase):
|
||||
class Handler(RequestHandler):
|
||||
def get(self):
|
||||
1/0
|
||||
|
||||
def log_exception(self, typ, value, tb):
|
||||
1/0
|
||||
|
||||
def test_buggy_log_exception(self):
|
||||
# Something gets logged even though the application's
|
||||
# logger is broken.
|
||||
with ExpectLog(app_log, '.*'):
|
||||
self.fetch('/')
|
||||
|
||||
|
||||
@wsgi_safe
|
||||
class UIMethodUIModuleTest(SimpleHandlerTestCase):
|
||||
"""Test that UI methods and modules are created correctly and
|
||||
|
@ -1533,6 +1628,7 @@ class UIMethodUIModuleTest(SimpleHandlerTestCase):
|
|||
def my_ui_method(handler, x):
|
||||
return "In my_ui_method(%s) with handler value %s." % (
|
||||
x, handler.value())
|
||||
|
||||
class MyModule(UIModule):
|
||||
def render(self, x):
|
||||
return "In MyModule(%s) with handler value %s." % (
|
||||
|
@ -1907,7 +2003,7 @@ class StreamingRequestBodyTest(WebTestCase):
|
|||
def connect(self, url, connection_close):
|
||||
# Use a raw connection so we can control the sending of data.
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
|
||||
s.connect(("localhost", self.get_http_port()))
|
||||
s.connect(("127.0.0.1", self.get_http_port()))
|
||||
stream = IOStream(s, io_loop=self.io_loop)
|
||||
stream.write(b"GET " + url + b" HTTP/1.1\r\n")
|
||||
if connection_close:
|
||||
|
@ -1988,7 +2084,9 @@ class StreamingRequestFlowControlTest(WebTestCase):
|
|||
|
||||
@gen.coroutine
|
||||
def prepare(self):
|
||||
with self.in_method('prepare'):
|
||||
# Note that asynchronous prepare() does not block data_received,
|
||||
# so we don't use in_method here.
|
||||
self.methods.append('prepare')
|
||||
yield gen.Task(IOLoop.current().add_callback)
|
||||
|
||||
@gen.coroutine
|
||||
|
@ -2051,9 +2149,10 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase):
|
|||
# When the content-length is too high, the connection is simply
|
||||
# closed without completing the response. An error is logged on
|
||||
# the server.
|
||||
with ExpectLog(app_log, "Uncaught exception"):
|
||||
with ExpectLog(app_log, "(Uncaught exception|Exception in callback)"):
|
||||
with ExpectLog(gen_log,
|
||||
"Cannot send error response after headers written"):
|
||||
"(Cannot send error response after headers written"
|
||||
"|Failed to flush partial response)"):
|
||||
response = self.fetch("/high")
|
||||
self.assertEqual(response.code, 599)
|
||||
self.assertEqual(str(self.server_error),
|
||||
|
@ -2063,9 +2162,10 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase):
|
|||
# When the content-length is too low, the connection is closed
|
||||
# without writing the last chunk, so the client never sees the request
|
||||
# complete (which would be a framing error).
|
||||
with ExpectLog(app_log, "Uncaught exception"):
|
||||
with ExpectLog(app_log, "(Uncaught exception|Exception in callback)"):
|
||||
with ExpectLog(gen_log,
|
||||
"Cannot send error response after headers written"):
|
||||
"(Cannot send error response after headers written"
|
||||
"|Failed to flush partial response)"):
|
||||
response = self.fetch("/low")
|
||||
self.assertEqual(response.code, 599)
|
||||
self.assertEqual(str(self.server_error),
|
||||
|
@ -2075,6 +2175,7 @@ class IncorrectContentLengthTest(SimpleHandlerTestCase):
|
|||
class ClientCloseTest(SimpleHandlerTestCase):
|
||||
class Handler(RequestHandler):
|
||||
def get(self):
|
||||
if self.request.version.startswith('HTTP/1'):
|
||||
# Simulate a connection closed by the client during
|
||||
# request processing. The client will see an error, but the
|
||||
# server should respond gracefully (without logging errors
|
||||
|
@ -2082,14 +2183,20 @@ class ClientCloseTest(SimpleHandlerTestCase):
|
|||
# Content-Length said we would)
|
||||
self.request.connection.stream.close()
|
||||
self.write('hello')
|
||||
else:
|
||||
# TODO: add a HTTP2-compatible version of this test.
|
||||
self.write('requires HTTP/1.x')
|
||||
|
||||
def test_client_close(self):
|
||||
response = self.fetch('/')
|
||||
if response.body == b'requires HTTP/1.x':
|
||||
self.skipTest('requires HTTP/1.x')
|
||||
self.assertEqual(response.code, 599)
|
||||
|
||||
|
||||
class SignedValueTest(unittest.TestCase):
|
||||
SECRET = "It's a secret to everybody"
|
||||
SECRET_DICT = {0: "asdfbasdf", 1: "12312312", 2: "2342342"}
|
||||
|
||||
def past(self):
|
||||
return self.present() - 86400 * 32
|
||||
|
@ -2151,6 +2258,7 @@ class SignedValueTest(unittest.TestCase):
|
|||
def test_payload_tampering(self):
|
||||
# These cookies are variants of the one in test_known_values.
|
||||
sig = "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152"
|
||||
|
||||
def validate(prefix):
|
||||
return (b'value' ==
|
||||
decode_signed_value(SignedValueTest.SECRET, "key",
|
||||
|
@ -2165,6 +2273,7 @@ class SignedValueTest(unittest.TestCase):
|
|||
|
||||
def test_signature_tampering(self):
|
||||
prefix = "2|1:0|10:1300000000|3:key|8:dmFsdWU=|"
|
||||
|
||||
def validate(sig):
|
||||
return (b'value' ==
|
||||
decode_signed_value(SignedValueTest.SECRET, "key",
|
||||
|
@ -2194,6 +2303,43 @@ class SignedValueTest(unittest.TestCase):
|
|||
clock=self.present)
|
||||
self.assertEqual(value, decoded)
|
||||
|
||||
def test_key_versioning_read_write_default_key(self):
|
||||
value = b"\xe9"
|
||||
signed = create_signed_value(SignedValueTest.SECRET_DICT,
|
||||
"key", value, clock=self.present,
|
||||
key_version=0)
|
||||
decoded = decode_signed_value(SignedValueTest.SECRET_DICT,
|
||||
"key", signed, clock=self.present)
|
||||
self.assertEqual(value, decoded)
|
||||
|
||||
def test_key_versioning_read_write_non_default_key(self):
|
||||
value = b"\xe9"
|
||||
signed = create_signed_value(SignedValueTest.SECRET_DICT,
|
||||
"key", value, clock=self.present,
|
||||
key_version=1)
|
||||
decoded = decode_signed_value(SignedValueTest.SECRET_DICT,
|
||||
"key", signed, clock=self.present)
|
||||
self.assertEqual(value, decoded)
|
||||
|
||||
def test_key_versioning_invalid_key(self):
|
||||
value = b"\xe9"
|
||||
signed = create_signed_value(SignedValueTest.SECRET_DICT,
|
||||
"key", value, clock=self.present,
|
||||
key_version=0)
|
||||
newkeys = SignedValueTest.SECRET_DICT.copy()
|
||||
newkeys.pop(0)
|
||||
decoded = decode_signed_value(newkeys,
|
||||
"key", signed, clock=self.present)
|
||||
self.assertEqual(None, decoded)
|
||||
|
||||
def test_key_version_retrieval(self):
|
||||
value = b"\xe9"
|
||||
signed = create_signed_value(SignedValueTest.SECRET_DICT,
|
||||
"key", value, clock=self.present,
|
||||
key_version=1)
|
||||
key_version = get_signature_key_version(signed)
|
||||
self.assertEqual(1, key_version)
|
||||
|
||||
|
||||
@wsgi_safe
|
||||
class XSRFTest(SimpleHandlerTestCase):
|
||||
|
@ -2372,6 +2518,7 @@ class FinishExceptionTest(SimpleHandlerTestCase):
|
|||
self.assertEqual(b'authentication required', response.body)
|
||||
|
||||
|
||||
@wsgi_safe
|
||||
class DecoratorTest(WebTestCase):
|
||||
def get_handlers(self):
|
||||
class RemoveSlashHandler(RequestHandler):
|
||||
|
@ -2405,3 +2552,85 @@ class DecoratorTest(WebTestCase):
|
|||
response = self.fetch("/addslash?foo=bar", follow_redirects=False)
|
||||
self.assertEqual(response.code, 301)
|
||||
self.assertEqual(response.headers['Location'], "/addslash/?foo=bar")
|
||||
|
||||
|
||||
@wsgi_safe
|
||||
class CacheTest(WebTestCase):
|
||||
def get_handlers(self):
|
||||
class EtagHandler(RequestHandler):
|
||||
def get(self, computed_etag):
|
||||
self.write(computed_etag)
|
||||
|
||||
def compute_etag(self):
|
||||
return self._write_buffer[0]
|
||||
|
||||
return [
|
||||
('/etag/(.*)', EtagHandler)
|
||||
]
|
||||
|
||||
def test_wildcard_etag(self):
|
||||
computed_etag = '"xyzzy"'
|
||||
etags = '*'
|
||||
self._test_etag(computed_etag, etags, 304)
|
||||
|
||||
def test_strong_etag_match(self):
|
||||
computed_etag = '"xyzzy"'
|
||||
etags = '"xyzzy"'
|
||||
self._test_etag(computed_etag, etags, 304)
|
||||
|
||||
def test_multiple_strong_etag_match(self):
|
||||
computed_etag = '"xyzzy1"'
|
||||
etags = '"xyzzy1", "xyzzy2"'
|
||||
self._test_etag(computed_etag, etags, 304)
|
||||
|
||||
def test_strong_etag_not_match(self):
|
||||
computed_etag = '"xyzzy"'
|
||||
etags = '"xyzzy1"'
|
||||
self._test_etag(computed_etag, etags, 200)
|
||||
|
||||
def test_multiple_strong_etag_not_match(self):
|
||||
computed_etag = '"xyzzy"'
|
||||
etags = '"xyzzy1", "xyzzy2"'
|
||||
self._test_etag(computed_etag, etags, 200)
|
||||
|
||||
def test_weak_etag_match(self):
|
||||
computed_etag = '"xyzzy1"'
|
||||
etags = 'W/"xyzzy1"'
|
||||
self._test_etag(computed_etag, etags, 304)
|
||||
|
||||
def test_multiple_weak_etag_match(self):
|
||||
computed_etag = '"xyzzy2"'
|
||||
etags = 'W/"xyzzy1", W/"xyzzy2"'
|
||||
self._test_etag(computed_etag, etags, 304)
|
||||
|
||||
def test_weak_etag_not_match(self):
|
||||
computed_etag = '"xyzzy2"'
|
||||
etags = 'W/"xyzzy1"'
|
||||
self._test_etag(computed_etag, etags, 200)
|
||||
|
||||
def test_multiple_weak_etag_not_match(self):
|
||||
computed_etag = '"xyzzy3"'
|
||||
etags = 'W/"xyzzy1", W/"xyzzy2"'
|
||||
self._test_etag(computed_etag, etags, 200)
|
||||
|
||||
def _test_etag(self, computed_etag, etags, status_code):
|
||||
response = self.fetch(
|
||||
'/etag/' + computed_etag,
|
||||
headers={'If-None-Match': etags}
|
||||
)
|
||||
self.assertEqual(response.code, status_code)
|
||||
|
||||
|
||||
@wsgi_safe
|
||||
class RequestSummaryTest(SimpleHandlerTestCase):
|
||||
class Handler(RequestHandler):
|
||||
def get(self):
|
||||
# remote_ip is optional, although it's set by
|
||||
# both HTTPServer and WSGIAdapter.
|
||||
# Clobber it to make sure it doesn't break logging.
|
||||
self.request.remote_ip = None
|
||||
self.finish(self._request_summary())
|
||||
|
||||
def test_missing_remote_ip(self):
|
||||
resp = self.fetch("/")
|
||||
self.assertEqual(resp.body, b"GET / (None)")
|
||||
|
|
|
@ -12,7 +12,7 @@ from tornado.web import Application, RequestHandler
|
|||
from tornado.util import u
|
||||
|
||||
try:
|
||||
import tornado.websocket
|
||||
import tornado.websocket # noqa
|
||||
from tornado.util import _websocket_mask_python
|
||||
except ImportError:
|
||||
# The unittest module presents misleading errors on ImportError
|
||||
|
@ -53,7 +53,7 @@ class EchoHandler(TestWebSocketHandler):
|
|||
|
||||
class ErrorInOnMessageHandler(TestWebSocketHandler):
|
||||
def on_message(self, message):
|
||||
1/0
|
||||
1 / 0
|
||||
|
||||
|
||||
class HeaderHandler(TestWebSocketHandler):
|
||||
|
@ -75,6 +75,7 @@ class NonWebSocketHandler(RequestHandler):
|
|||
|
||||
class CloseReasonHandler(TestWebSocketHandler):
|
||||
def open(self):
|
||||
self.on_close_called = False
|
||||
self.close(1001, "goodbye")
|
||||
|
||||
|
||||
|
@ -91,7 +92,7 @@ class WebSocketBaseTestCase(AsyncHTTPTestCase):
|
|||
@gen.coroutine
|
||||
def ws_connect(self, path, compression_options=None):
|
||||
ws = yield websocket_connect(
|
||||
'ws://localhost:%d%s' % (self.get_http_port(), path),
|
||||
'ws://127.0.0.1:%d%s' % (self.get_http_port(), path),
|
||||
compression_options=compression_options)
|
||||
raise gen.Return(ws)
|
||||
|
||||
|
@ -105,6 +106,7 @@ class WebSocketBaseTestCase(AsyncHTTPTestCase):
|
|||
ws.close()
|
||||
yield self.close_future
|
||||
|
||||
|
||||
class WebSocketTest(WebSocketBaseTestCase):
|
||||
def get_app(self):
|
||||
self.close_future = Future()
|
||||
|
@ -135,7 +137,7 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
|
||||
def test_websocket_callbacks(self):
|
||||
websocket_connect(
|
||||
'ws://localhost:%d/echo' % self.get_http_port(),
|
||||
'ws://127.0.0.1:%d/echo' % self.get_http_port(),
|
||||
io_loop=self.io_loop, callback=self.stop)
|
||||
ws = self.wait().result()
|
||||
ws.write_message('hello')
|
||||
|
@ -189,14 +191,14 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
with self.assertRaises(IOError):
|
||||
with ExpectLog(gen_log, ".*"):
|
||||
yield websocket_connect(
|
||||
'ws://localhost:%d/' % port,
|
||||
'ws://127.0.0.1:%d/' % port,
|
||||
io_loop=self.io_loop,
|
||||
connect_timeout=3600)
|
||||
|
||||
@gen_test
|
||||
def test_websocket_close_buffered_data(self):
|
||||
ws = yield websocket_connect(
|
||||
'ws://localhost:%d/echo' % self.get_http_port())
|
||||
'ws://127.0.0.1:%d/echo' % self.get_http_port())
|
||||
ws.write_message('hello')
|
||||
ws.write_message('world')
|
||||
# Close the underlying stream.
|
||||
|
@ -207,7 +209,7 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
def test_websocket_headers(self):
|
||||
# Ensure that arbitrary headers can be passed through websocket_connect.
|
||||
ws = yield websocket_connect(
|
||||
HTTPRequest('ws://localhost:%d/header' % self.get_http_port(),
|
||||
HTTPRequest('ws://127.0.0.1:%d/header' % self.get_http_port(),
|
||||
headers={'X-Test': 'hello'}))
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, 'hello')
|
||||
|
@ -221,6 +223,8 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
self.assertIs(msg, None)
|
||||
self.assertEqual(ws.close_code, 1001)
|
||||
self.assertEqual(ws.close_reason, "goodbye")
|
||||
# The on_close callback is called no matter which side closed.
|
||||
yield self.close_future
|
||||
|
||||
@gen_test
|
||||
def test_client_close_reason(self):
|
||||
|
@ -243,8 +247,8 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
def test_check_origin_valid_no_path(self):
|
||||
port = self.get_http_port()
|
||||
|
||||
url = 'ws://localhost:%d/echo' % port
|
||||
headers = {'Origin': 'http://localhost:%d' % port}
|
||||
url = 'ws://127.0.0.1:%d/echo' % port
|
||||
headers = {'Origin': 'http://127.0.0.1:%d' % port}
|
||||
|
||||
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
|
||||
io_loop=self.io_loop)
|
||||
|
@ -257,8 +261,8 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
def test_check_origin_valid_with_path(self):
|
||||
port = self.get_http_port()
|
||||
|
||||
url = 'ws://localhost:%d/echo' % port
|
||||
headers = {'Origin': 'http://localhost:%d/something' % port}
|
||||
url = 'ws://127.0.0.1:%d/echo' % port
|
||||
headers = {'Origin': 'http://127.0.0.1:%d/something' % port}
|
||||
|
||||
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
|
||||
io_loop=self.io_loop)
|
||||
|
@ -271,8 +275,8 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
def test_check_origin_invalid_partial_url(self):
|
||||
port = self.get_http_port()
|
||||
|
||||
url = 'ws://localhost:%d/echo' % port
|
||||
headers = {'Origin': 'localhost:%d' % port}
|
||||
url = 'ws://127.0.0.1:%d/echo' % port
|
||||
headers = {'Origin': '127.0.0.1:%d' % port}
|
||||
|
||||
with self.assertRaises(HTTPError) as cm:
|
||||
yield websocket_connect(HTTPRequest(url, headers=headers),
|
||||
|
@ -283,8 +287,8 @@ class WebSocketTest(WebSocketBaseTestCase):
|
|||
def test_check_origin_invalid(self):
|
||||
port = self.get_http_port()
|
||||
|
||||
url = 'ws://localhost:%d/echo' % port
|
||||
# Host is localhost, which should not be accessible from some other
|
||||
url = 'ws://127.0.0.1:%d/echo' % port
|
||||
# Host is 127.0.0.1, which should not be accessible from some other
|
||||
# domain
|
||||
headers = {'Origin': 'http://somewhereelse.com'}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ try:
|
|||
from tornado.simple_httpclient import SimpleAsyncHTTPClient
|
||||
from tornado.ioloop import IOLoop, TimeoutError
|
||||
from tornado import netutil
|
||||
from tornado.process import Subprocess
|
||||
except ImportError:
|
||||
# These modules are not importable on app engine. Parts of this module
|
||||
# won't work, but e.g. LogTrapTestCase and main() will.
|
||||
|
@ -28,6 +29,7 @@ except ImportError:
|
|||
IOLoop = None
|
||||
netutil = None
|
||||
SimpleAsyncHTTPClient = None
|
||||
Subprocess = None
|
||||
from tornado.log import gen_log, app_log
|
||||
from tornado.stack_context import ExceptionStackContext
|
||||
from tornado.util import raise_exc_info, basestring_type
|
||||
|
@ -214,6 +216,8 @@ class AsyncTestCase(unittest.TestCase):
|
|||
self.io_loop.make_current()
|
||||
|
||||
def tearDown(self):
|
||||
# Clean up Subprocess, so it can be used again with a new ioloop.
|
||||
Subprocess.uninitialize()
|
||||
self.io_loop.clear_current()
|
||||
if (not IOLoop.initialized() or
|
||||
self.io_loop is not IOLoop.instance()):
|
||||
|
@ -413,9 +417,7 @@ class AsyncHTTPSTestCase(AsyncHTTPTestCase):
|
|||
Interface is generally the same as `AsyncHTTPTestCase`.
|
||||
"""
|
||||
def get_http_client(self):
|
||||
# Some versions of libcurl have deadlock bugs with ssl,
|
||||
# so always run these tests with SimpleAsyncHTTPClient.
|
||||
return SimpleAsyncHTTPClient(io_loop=self.io_loop, force_instance=True,
|
||||
return AsyncHTTPClient(io_loop=self.io_loop, force_instance=True,
|
||||
defaults=dict(validate_cert=False))
|
||||
|
||||
def get_httpserver_options(self):
|
||||
|
@ -539,6 +541,9 @@ class LogTrapTestCase(unittest.TestCase):
|
|||
`logging.basicConfig` and the "pretty logging" configured by
|
||||
`tornado.options`. It is not compatible with other log buffering
|
||||
mechanisms, such as those provided by some test runners.
|
||||
|
||||
.. deprecated:: 4.1
|
||||
Use the unittest module's ``--buffer`` option instead, or `.ExpectLog`.
|
||||
"""
|
||||
def run(self, result=None):
|
||||
logger = logging.getLogger()
|
||||
|
|
|
@ -78,6 +78,25 @@ class GzipDecompressor(object):
|
|||
return self.decompressobj.flush()
|
||||
|
||||
|
||||
# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for
|
||||
# literal strings, and alternative solutions like "from __future__ import
|
||||
# unicode_literals" have other problems (see PEP 414). u() can be applied
|
||||
# to ascii strings that include \u escapes (but they must not contain
|
||||
# literal non-ascii characters).
|
||||
if not isinstance(b'', type('')):
|
||||
def u(s):
|
||||
return s
|
||||
unicode_type = str
|
||||
basestring_type = str
|
||||
else:
|
||||
def u(s):
|
||||
return s.decode('unicode_escape')
|
||||
# These names don't exist in py3, so use noqa comments to disable
|
||||
# warnings in flake8.
|
||||
unicode_type = unicode # noqa
|
||||
basestring_type = basestring # noqa
|
||||
|
||||
|
||||
def import_object(name):
|
||||
"""Imports an object by name.
|
||||
|
||||
|
@ -96,6 +115,9 @@ def import_object(name):
|
|||
...
|
||||
ImportError: No module named missing_module
|
||||
"""
|
||||
if isinstance(name, unicode_type) and str is not unicode_type:
|
||||
# On python 2 a byte string is required.
|
||||
name = name.encode('utf-8')
|
||||
if name.count('.') == 0:
|
||||
return __import__(name, None, None)
|
||||
|
||||
|
@ -107,22 +129,6 @@ def import_object(name):
|
|||
raise ImportError("No module named %s" % parts[-1])
|
||||
|
||||
|
||||
# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for
|
||||
# literal strings, and alternative solutions like "from __future__ import
|
||||
# unicode_literals" have other problems (see PEP 414). u() can be applied
|
||||
# to ascii strings that include \u escapes (but they must not contain
|
||||
# literal non-ascii characters).
|
||||
if type('') is not type(b''):
|
||||
def u(s):
|
||||
return s
|
||||
unicode_type = str
|
||||
basestring_type = str
|
||||
else:
|
||||
def u(s):
|
||||
return s.decode('unicode_escape')
|
||||
unicode_type = unicode
|
||||
basestring_type = basestring
|
||||
|
||||
# Deprecated alias that was used before we dropped py25 support.
|
||||
# Left here in case anyone outside Tornado is using it.
|
||||
bytes_type = bytes
|
||||
|
@ -192,21 +198,21 @@ class Configurable(object):
|
|||
__impl_class = None
|
||||
__impl_kwargs = None
|
||||
|
||||
def __new__(cls, **kwargs):
|
||||
def __new__(cls, *args, **kwargs):
|
||||
base = cls.configurable_base()
|
||||
args = {}
|
||||
init_kwargs = {}
|
||||
if cls is base:
|
||||
impl = cls.configured_class()
|
||||
if base.__impl_kwargs:
|
||||
args.update(base.__impl_kwargs)
|
||||
init_kwargs.update(base.__impl_kwargs)
|
||||
else:
|
||||
impl = cls
|
||||
args.update(kwargs)
|
||||
init_kwargs.update(kwargs)
|
||||
instance = super(Configurable, cls).__new__(impl)
|
||||
# initialize vs __init__ chosen for compatibility with AsyncHTTPClient
|
||||
# singleton magic. If we get rid of that we can switch to __init__
|
||||
# here too.
|
||||
instance.initialize(**args)
|
||||
instance.initialize(*args, **init_kwargs)
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
|
@ -227,6 +233,9 @@ class Configurable(object):
|
|||
"""Initialize a `Configurable` subclass instance.
|
||||
|
||||
Configurable classes should use `initialize` instead of ``__init__``.
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
Now accepts positional arguments in addition to keyword arguments.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
|
|
324
tornado/web.py
324
tornado/web.py
|
@ -19,7 +19,9 @@ features that allow it to scale to large numbers of open connections,
|
|||
making it ideal for `long polling
|
||||
<http://en.wikipedia.org/wiki/Push_technology#Long_polling>`_.
|
||||
|
||||
Here is a simple "Hello, world" example app::
|
||||
Here is a simple "Hello, world" example app:
|
||||
|
||||
.. testcode::
|
||||
|
||||
import tornado.ioloop
|
||||
import tornado.web
|
||||
|
@ -33,7 +35,11 @@ Here is a simple "Hello, world" example app::
|
|||
(r"/", MainHandler),
|
||||
])
|
||||
application.listen(8888)
|
||||
tornado.ioloop.IOLoop.instance().start()
|
||||
tornado.ioloop.IOLoop.current().start()
|
||||
|
||||
.. testoutput::
|
||||
:hide:
|
||||
|
||||
|
||||
See the :doc:`guide` for additional information.
|
||||
|
||||
|
@ -50,7 +56,8 @@ request.
|
|||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import (absolute_import, division,
|
||||
print_function, with_statement)
|
||||
|
||||
|
||||
import base64
|
||||
|
@ -84,7 +91,9 @@ from tornado.log import access_log, app_log, gen_log
|
|||
from tornado import stack_context
|
||||
from tornado import template
|
||||
from tornado.escape import utf8, _unicode
|
||||
from tornado.util import import_object, ObjectDict, raise_exc_info, unicode_type, _websocket_mask
|
||||
from tornado.util import (import_object, ObjectDict, raise_exc_info,
|
||||
unicode_type, _websocket_mask)
|
||||
from tornado.httputil import split_host_and_port
|
||||
|
||||
|
||||
try:
|
||||
|
@ -130,12 +139,11 @@ May be overridden by passing a ``version`` keyword argument.
|
|||
DEFAULT_SIGNED_VALUE_MIN_VERSION = 1
|
||||
"""The oldest signed value accepted by `.RequestHandler.get_secure_cookie`.
|
||||
|
||||
May be overrided by passing a ``min_version`` keyword argument.
|
||||
May be overridden by passing a ``min_version`` keyword argument.
|
||||
|
||||
.. versionadded:: 3.2.1
|
||||
"""
|
||||
|
||||
|
||||
class RequestHandler(object):
|
||||
"""Subclass this class and define `get()` or `post()` to make a handler.
|
||||
|
||||
|
@ -267,6 +275,7 @@ class RequestHandler(object):
|
|||
if _has_stream_request_body(self.__class__):
|
||||
if not self.request.body.done():
|
||||
self.request.body.set_exception(iostream.StreamClosedError())
|
||||
self.request.body.exception()
|
||||
|
||||
def clear(self):
|
||||
"""Resets all headers and content for this response."""
|
||||
|
@ -382,6 +391,12 @@ class RequestHandler(object):
|
|||
|
||||
The returned values are always unicode.
|
||||
"""
|
||||
|
||||
# Make sure `get_arguments` isn't accidentally being called with a
|
||||
# positional argument that's assumed to be a default (like in
|
||||
# `get_argument`.)
|
||||
assert isinstance(strip, bool)
|
||||
|
||||
return self._get_arguments(name, self.request.arguments, strip)
|
||||
|
||||
def get_body_argument(self, name, default=_ARG_DEFAULT, strip=True):
|
||||
|
@ -398,7 +413,8 @@ class RequestHandler(object):
|
|||
|
||||
.. versionadded:: 3.2
|
||||
"""
|
||||
return self._get_argument(name, default, self.request.body_arguments, strip)
|
||||
return self._get_argument(name, default, self.request.body_arguments,
|
||||
strip)
|
||||
|
||||
def get_body_arguments(self, name, strip=True):
|
||||
"""Returns a list of the body arguments with the given name.
|
||||
|
@ -425,7 +441,8 @@ class RequestHandler(object):
|
|||
|
||||
.. versionadded:: 3.2
|
||||
"""
|
||||
return self._get_argument(name, default, self.request.query_arguments, strip)
|
||||
return self._get_argument(name, default,
|
||||
self.request.query_arguments, strip)
|
||||
|
||||
def get_query_arguments(self, name, strip=True):
|
||||
"""Returns a list of the query arguments with the given name.
|
||||
|
@ -480,7 +497,8 @@ class RequestHandler(object):
|
|||
|
||||
@property
|
||||
def cookies(self):
|
||||
"""An alias for `self.request.cookies <.httputil.HTTPServerRequest.cookies>`."""
|
||||
"""An alias for
|
||||
`self.request.cookies <.httputil.HTTPServerRequest.cookies>`."""
|
||||
return self.request.cookies
|
||||
|
||||
def get_cookie(self, name, default=None):
|
||||
|
@ -522,6 +540,12 @@ class RequestHandler(object):
|
|||
for k, v in kwargs.items():
|
||||
if k == 'max_age':
|
||||
k = 'max-age'
|
||||
|
||||
# skip falsy values for httponly and secure flags because
|
||||
# SimpleCookie sets them regardless
|
||||
if k in ['httponly', 'secure'] and not v:
|
||||
continue
|
||||
|
||||
morsel[k] = v
|
||||
|
||||
def clear_cookie(self, name, path="/", domain=None):
|
||||
|
@ -588,8 +612,15 @@ class RequestHandler(object):
|
|||
and made it the default.
|
||||
"""
|
||||
self.require_setting("cookie_secret", "secure cookies")
|
||||
return create_signed_value(self.application.settings["cookie_secret"],
|
||||
name, value, version=version)
|
||||
secret = self.application.settings["cookie_secret"]
|
||||
key_version = None
|
||||
if isinstance(secret, dict):
|
||||
if self.application.settings.get("key_version") is None:
|
||||
raise Exception("key_version setting must be used for secret_key dicts")
|
||||
key_version = self.application.settings["key_version"]
|
||||
|
||||
return create_signed_value(secret, name, value, version=version,
|
||||
key_version=key_version)
|
||||
|
||||
def get_secure_cookie(self, name, value=None, max_age_days=31,
|
||||
min_version=None):
|
||||
|
@ -610,6 +641,17 @@ class RequestHandler(object):
|
|||
name, value, max_age_days=max_age_days,
|
||||
min_version=min_version)
|
||||
|
||||
def get_secure_cookie_key_version(self, name, value=None):
|
||||
"""Returns the signing key version of the secure cookie.
|
||||
|
||||
The version is returned as int.
|
||||
"""
|
||||
self.require_setting("cookie_secret", "secure cookies")
|
||||
if value is None:
|
||||
value = self.get_cookie(name)
|
||||
return get_signature_key_version(value)
|
||||
|
||||
|
||||
def redirect(self, url, permanent=False, status=None):
|
||||
"""Sends a redirect to the given (optionally relative) URL.
|
||||
|
||||
|
@ -625,8 +667,7 @@ class RequestHandler(object):
|
|||
else:
|
||||
assert isinstance(status, int) and 300 <= status <= 399
|
||||
self.set_status(status)
|
||||
self.set_header("Location", urlparse.urljoin(utf8(self.request.uri),
|
||||
utf8(url)))
|
||||
self.set_header("Location", utf8(url))
|
||||
self.finish()
|
||||
|
||||
def write(self, chunk):
|
||||
|
@ -646,15 +687,13 @@ class RequestHandler(object):
|
|||
https://github.com/facebook/tornado/issues/1009
|
||||
"""
|
||||
if self._finished:
|
||||
raise RuntimeError("Cannot write() after finish(). May be caused "
|
||||
"by using async operations without the "
|
||||
"@asynchronous decorator.")
|
||||
raise RuntimeError("Cannot write() after finish()")
|
||||
if not isinstance(chunk, (bytes, unicode_type, dict)):
|
||||
raise TypeError("write() only accepts bytes, unicode, and dict objects")
|
||||
message = "write() only accepts bytes, unicode, and dict objects"
|
||||
if isinstance(chunk, list):
|
||||
message += ". Lists not accepted for security reasons; see http://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.write"
|
||||
raise TypeError(message)
|
||||
if isinstance(chunk, dict):
|
||||
if 'unwrap_json' in chunk:
|
||||
chunk = chunk['unwrap_json']
|
||||
else:
|
||||
chunk = escape.json_encode(chunk)
|
||||
self.set_header("Content-Type", "application/json; charset=UTF-8")
|
||||
chunk = utf8(chunk)
|
||||
|
@ -786,6 +825,7 @@ class RequestHandler(object):
|
|||
current_user=self.current_user,
|
||||
locale=self.locale,
|
||||
_=self.locale.translate,
|
||||
pgettext=self.locale.pgettext,
|
||||
static_url=self.static_url,
|
||||
xsrf_form_html=self.xsrf_form_html,
|
||||
reverse_url=self.reverse_url
|
||||
|
@ -830,7 +870,8 @@ class RequestHandler(object):
|
|||
for transform in self._transforms:
|
||||
self._status_code, self._headers, chunk = \
|
||||
transform.transform_first_chunk(
|
||||
self._status_code, self._headers, chunk, include_footers)
|
||||
self._status_code, self._headers,
|
||||
chunk, include_footers)
|
||||
# Ignore the chunk and only write the headers for HEAD requests
|
||||
if self.request.method == "HEAD":
|
||||
chunk = None
|
||||
|
@ -842,7 +883,7 @@ class RequestHandler(object):
|
|||
for cookie in self._new_cookie.values():
|
||||
self.add_header("Set-Cookie", cookie.OutputString(None))
|
||||
|
||||
start_line = httputil.ResponseStartLine(self.request.version,
|
||||
start_line = httputil.ResponseStartLine('',
|
||||
self._status_code,
|
||||
self._reason)
|
||||
return self.request.connection.write_headers(
|
||||
|
@ -861,9 +902,7 @@ class RequestHandler(object):
|
|||
def finish(self, chunk=None):
|
||||
"""Finishes this response, ending the HTTP request."""
|
||||
if self._finished:
|
||||
raise RuntimeError("finish() called twice. May be caused "
|
||||
"by using async operations without the "
|
||||
"@asynchronous decorator.")
|
||||
raise RuntimeError("finish() called twice")
|
||||
|
||||
if chunk is not None:
|
||||
self.write(chunk)
|
||||
|
@ -915,7 +954,15 @@ class RequestHandler(object):
|
|||
if self._headers_written:
|
||||
gen_log.error("Cannot send error response after headers written")
|
||||
if not self._finished:
|
||||
# If we get an error between writing headers and finishing,
|
||||
# we are unlikely to be able to finish due to a
|
||||
# Content-Length mismatch. Try anyway to release the
|
||||
# socket.
|
||||
try:
|
||||
self.finish()
|
||||
except Exception:
|
||||
gen_log.error("Failed to flush partial response",
|
||||
exc_info=True)
|
||||
return
|
||||
self.clear()
|
||||
|
||||
|
@ -1122,11 +1169,15 @@ class RequestHandler(object):
|
|||
"""Convert a cookie string into a the tuple form returned by
|
||||
_get_raw_xsrf_token.
|
||||
"""
|
||||
|
||||
try:
|
||||
m = _signed_value_version_re.match(utf8(cookie))
|
||||
|
||||
if m:
|
||||
version = int(m.group(1))
|
||||
if version == 2:
|
||||
_, mask, masked_token, timestamp = cookie.split("|")
|
||||
|
||||
mask = binascii.a2b_hex(utf8(mask))
|
||||
token = _websocket_mask(
|
||||
mask, binascii.a2b_hex(utf8(masked_token)))
|
||||
|
@ -1134,7 +1185,7 @@ class RequestHandler(object):
|
|||
return version, token, timestamp
|
||||
else:
|
||||
# Treat unknown versions as not present instead of failing.
|
||||
return None, None, None
|
||||
raise Exception("Unknown xsrf cookie version")
|
||||
else:
|
||||
version = 1
|
||||
try:
|
||||
|
@ -1144,6 +1195,11 @@ class RequestHandler(object):
|
|||
# We don't have a usable timestamp in older versions.
|
||||
timestamp = int(time.time())
|
||||
return (version, token, timestamp)
|
||||
except Exception:
|
||||
# Catch exceptions and return nothing instead of failing.
|
||||
gen_log.debug("Uncaught exception in _decode_xsrf_token",
|
||||
exc_info=True)
|
||||
return None, None, None
|
||||
|
||||
def check_xsrf_cookie(self):
|
||||
"""Verifies that the ``_xsrf`` cookie matches the ``_xsrf`` argument.
|
||||
|
@ -1282,9 +1338,27 @@ class RequestHandler(object):
|
|||
before completing the request. The ``Etag`` header should be set
|
||||
(perhaps with `set_etag_header`) before calling this method.
|
||||
"""
|
||||
etag = self._headers.get("Etag")
|
||||
inm = utf8(self.request.headers.get("If-None-Match", ""))
|
||||
return bool(etag and inm and inm.find(etag) >= 0)
|
||||
computed_etag = utf8(self._headers.get("Etag", ""))
|
||||
# Find all weak and strong etag values from If-None-Match header
|
||||
# because RFC 7232 allows multiple etag values in a single header.
|
||||
etags = re.findall(
|
||||
br'\*|(?:W/)?"[^"]*"',
|
||||
utf8(self.request.headers.get("If-None-Match", ""))
|
||||
)
|
||||
if not computed_etag or not etags:
|
||||
return False
|
||||
|
||||
match = False
|
||||
if etags[0] == b'*':
|
||||
match = True
|
||||
else:
|
||||
# Use a weak comparison when comparing entity-tags.
|
||||
val = lambda x: x[2:] if x.startswith(b'W/') else x
|
||||
for etag in etags:
|
||||
if val(etag) == val(computed_etag):
|
||||
match = True
|
||||
break
|
||||
return match
|
||||
|
||||
def _stack_context_handle_exception(self, type, value, traceback):
|
||||
try:
|
||||
|
@ -1344,7 +1418,10 @@ class RequestHandler(object):
|
|||
if self._auto_finish and not self._finished:
|
||||
self.finish()
|
||||
except Exception as e:
|
||||
try:
|
||||
self._handle_request_exception(e)
|
||||
except Exception:
|
||||
app_log.error("Exception in exception handler", exc_info=True)
|
||||
if (self._prepared_future is not None and
|
||||
not self._prepared_future.done()):
|
||||
# In case we failed before setting _prepared_future, do it
|
||||
|
@ -1369,8 +1446,8 @@ class RequestHandler(object):
|
|||
self.application.log_request(self)
|
||||
|
||||
def _request_summary(self):
|
||||
return self.request.method + " " + self.request.uri + \
|
||||
" (" + self.request.remote_ip + ")"
|
||||
return "%s %s (%s)" % (self.request.method, self.request.uri,
|
||||
self.request.remote_ip)
|
||||
|
||||
def _handle_request_exception(self, e):
|
||||
if isinstance(e, Finish):
|
||||
|
@ -1378,7 +1455,12 @@ class RequestHandler(object):
|
|||
if not self._finished:
|
||||
self.finish()
|
||||
return
|
||||
try:
|
||||
self.log_exception(*sys.exc_info())
|
||||
except Exception:
|
||||
# An error here should still get a best-effort send_error()
|
||||
# to avoid leaking the connection.
|
||||
app_log.error("Error in exception logger", exc_info=True)
|
||||
if self._finished:
|
||||
# Extra errors after the request has been finished should
|
||||
# be logged, but there is no reason to continue to try and
|
||||
|
@ -1441,10 +1523,11 @@ class RequestHandler(object):
|
|||
def asynchronous(method):
|
||||
"""Wrap request handler methods with this if they are asynchronous.
|
||||
|
||||
This decorator is unnecessary if the method is also decorated with
|
||||
``@gen.coroutine`` (it is legal but unnecessary to use the two
|
||||
decorators together, in which case ``@asynchronous`` must be
|
||||
first).
|
||||
This decorator is for callback-style asynchronous methods; for
|
||||
coroutines, use the ``@gen.coroutine`` decorator without
|
||||
``@asynchronous``. (It is legal for legacy reasons to use the two
|
||||
decorators together provided ``@asynchronous`` is first, but
|
||||
``@asynchronous`` will be ignored in this case)
|
||||
|
||||
This decorator should only be applied to the :ref:`HTTP verb
|
||||
methods <verbs>`; its behavior is undefined for any other method.
|
||||
|
@ -1457,10 +1540,12 @@ def asynchronous(method):
|
|||
method returns. It is up to the request handler to call
|
||||
`self.finish() <RequestHandler.finish>` to finish the HTTP
|
||||
request. Without this decorator, the request is automatically
|
||||
finished when the ``get()`` or ``post()`` method returns. Example::
|
||||
finished when the ``get()`` or ``post()`` method returns. Example:
|
||||
|
||||
class MyRequestHandler(web.RequestHandler):
|
||||
@web.asynchronous
|
||||
.. testcode::
|
||||
|
||||
class MyRequestHandler(RequestHandler):
|
||||
@asynchronous
|
||||
def get(self):
|
||||
http = httpclient.AsyncHTTPClient()
|
||||
http.fetch("http://friendfeed.com/", self._on_download)
|
||||
|
@ -1469,18 +1554,23 @@ def asynchronous(method):
|
|||
self.write("Downloaded!")
|
||||
self.finish()
|
||||
|
||||
.. testoutput::
|
||||
:hide:
|
||||
|
||||
.. versionadded:: 3.1
|
||||
The ability to use ``@gen.coroutine`` without ``@asynchronous``.
|
||||
|
||||
"""
|
||||
# Delay the IOLoop import because it's not available on app engine.
|
||||
from tornado.ioloop import IOLoop
|
||||
|
||||
@functools.wraps(method)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
self._auto_finish = False
|
||||
with stack_context.ExceptionStackContext(
|
||||
self._stack_context_handle_exception):
|
||||
result = method(self, *args, **kwargs)
|
||||
if isinstance(result, Future):
|
||||
if is_future(result):
|
||||
# If @asynchronous is used with @gen.coroutine, (but
|
||||
# not @gen.engine), we can automatically finish the
|
||||
# request when the future resolves. Additionally,
|
||||
|
@ -1521,7 +1611,7 @@ def stream_request_body(cls):
|
|||
the entire body has been read.
|
||||
|
||||
There is a subtle interaction between ``data_received`` and asynchronous
|
||||
``prepare``: The first call to ``data_recieved`` may occur at any point
|
||||
``prepare``: The first call to ``data_received`` may occur at any point
|
||||
after the call to ``prepare`` has returned *or yielded*.
|
||||
"""
|
||||
if not issubclass(cls, RequestHandler):
|
||||
|
@ -1591,7 +1681,7 @@ class Application(httputil.HTTPServerConnectionDelegate):
|
|||
])
|
||||
http_server = httpserver.HTTPServer(application)
|
||||
http_server.listen(8080)
|
||||
ioloop.IOLoop.instance().start()
|
||||
ioloop.IOLoop.current().start()
|
||||
|
||||
The constructor for this class takes in a list of `URLSpec` objects
|
||||
or (regexp, request_class) tuples. When we receive requests, we
|
||||
|
@ -1689,7 +1779,7 @@ class Application(httputil.HTTPServerConnectionDelegate):
|
|||
`.TCPServer.bind`/`.TCPServer.start` methods directly.
|
||||
|
||||
Note that after calling this method you still need to call
|
||||
``IOLoop.instance().start()`` to start the server.
|
||||
``IOLoop.current().start()`` to start the server.
|
||||
"""
|
||||
# import is here rather than top level because HTTPServer
|
||||
# is not importable on appengine
|
||||
|
@ -1732,7 +1822,7 @@ class Application(httputil.HTTPServerConnectionDelegate):
|
|||
self.transforms.append(transform_class)
|
||||
|
||||
def _get_host_handlers(self, request):
|
||||
host = request.host.lower().split(':')[0]
|
||||
host = split_host_and_port(request.host.lower())[0]
|
||||
matches = []
|
||||
for pattern, handlers in self.handlers:
|
||||
if pattern.match(host):
|
||||
|
@ -1773,9 +1863,9 @@ class Application(httputil.HTTPServerConnectionDelegate):
|
|||
except TypeError:
|
||||
pass
|
||||
|
||||
def start_request(self, connection):
|
||||
def start_request(self, server_conn, request_conn):
|
||||
# Modern HTTPServer interface
|
||||
return _RequestDispatcher(self, connection)
|
||||
return _RequestDispatcher(self, request_conn)
|
||||
|
||||
def __call__(self, request):
|
||||
# Legacy HTTPServer interface
|
||||
|
@ -1831,7 +1921,8 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate):
|
|||
|
||||
def headers_received(self, start_line, headers):
|
||||
self.set_request(httputil.HTTPServerRequest(
|
||||
connection=self.connection, start_line=start_line, headers=headers))
|
||||
connection=self.connection, start_line=start_line,
|
||||
headers=headers))
|
||||
if self.stream_request_body:
|
||||
self.request.body = Future()
|
||||
return self.execute()
|
||||
|
@ -1848,7 +1939,9 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate):
|
|||
handlers = app._get_host_handlers(self.request)
|
||||
if not handlers:
|
||||
self.handler_class = RedirectHandler
|
||||
self.handler_kwargs = dict(url="http://" + app.default_host + "/")
|
||||
self.handler_kwargs = dict(url="%s://%s/"
|
||||
% (self.request.protocol,
|
||||
app.default_host))
|
||||
return
|
||||
for spec in handlers:
|
||||
match = spec.regex.match(self.request.path)
|
||||
|
@ -1914,11 +2007,14 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate):
|
|||
if self.stream_request_body:
|
||||
self.handler._prepared_future = Future()
|
||||
# Note that if an exception escapes handler._execute it will be
|
||||
# trapped in the Future it returns (which we are ignoring here).
|
||||
# trapped in the Future it returns (which we are ignoring here,
|
||||
# leaving it to be logged when the Future is GC'd).
|
||||
# However, that shouldn't happen because _execute has a blanket
|
||||
# except handler, and we cannot easily access the IOLoop here to
|
||||
# call add_future.
|
||||
self.handler._execute(transforms, *self.path_args, **self.path_kwargs)
|
||||
# call add_future (because of the requirement to remain compatible
|
||||
# with WSGI)
|
||||
f = self.handler._execute(transforms, *self.path_args,
|
||||
**self.path_kwargs)
|
||||
# If we are streaming the request body, then execute() is finished
|
||||
# when the handler has prepared to receive the body. If not,
|
||||
# it doesn't matter when execute() finishes (so we return None)
|
||||
|
@ -1952,6 +2048,8 @@ class HTTPError(Exception):
|
|||
self.log_message = log_message
|
||||
self.args = args
|
||||
self.reason = kwargs.get('reason', None)
|
||||
if log_message and not args:
|
||||
self.log_message = log_message.replace('%', '%%')
|
||||
|
||||
def __str__(self):
|
||||
message = "HTTP %d: %s" % (
|
||||
|
@ -2212,7 +2310,8 @@ class StaticFileHandler(RequestHandler):
|
|||
if content_type:
|
||||
self.set_header("Content-Type", content_type)
|
||||
|
||||
cache_time = self.get_cache_time(self.path, self.modified, content_type)
|
||||
cache_time = self.get_cache_time(self.path, self.modified,
|
||||
content_type)
|
||||
if cache_time > 0:
|
||||
self.set_header("Expires", datetime.datetime.utcnow() +
|
||||
datetime.timedelta(seconds=cache_time))
|
||||
|
@ -2381,7 +2480,8 @@ class StaticFileHandler(RequestHandler):
|
|||
.. versionadded:: 3.1
|
||||
"""
|
||||
stat_result = self._stat()
|
||||
modified = datetime.datetime.utcfromtimestamp(stat_result[stat.ST_MTIME])
|
||||
modified = datetime.datetime.utcfromtimestamp(
|
||||
stat_result[stat.ST_MTIME])
|
||||
return modified
|
||||
|
||||
def get_content_type(self):
|
||||
|
@ -2624,6 +2724,8 @@ class UIModule(object):
|
|||
UI modules often execute additional queries, and they can include
|
||||
additional CSS and JavaScript that will be included in the output
|
||||
page, which is automatically inserted on page render.
|
||||
|
||||
Subclasses of UIModule must override the `render` method.
|
||||
"""
|
||||
def __init__(self, handler):
|
||||
self.handler = handler
|
||||
|
@ -2636,31 +2738,45 @@ class UIModule(object):
|
|||
return self.handler.current_user
|
||||
|
||||
def render(self, *args, **kwargs):
|
||||
"""Overridden in subclasses to return this module's output."""
|
||||
"""Override in subclasses to return this module's output."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def embedded_javascript(self):
|
||||
"""Returns a JavaScript string that will be embedded in the page."""
|
||||
"""Override to return a JavaScript string
|
||||
to be embedded in the page."""
|
||||
return None
|
||||
|
||||
def javascript_files(self):
|
||||
"""Returns a list of JavaScript files required by this module."""
|
||||
"""Override to return a list of JavaScript files needed by this module.
|
||||
|
||||
If the return values are relative paths, they will be passed to
|
||||
`RequestHandler.static_url`; otherwise they will be used as-is.
|
||||
"""
|
||||
return None
|
||||
|
||||
def embedded_css(self):
|
||||
"""Returns a CSS string that will be embedded in the page."""
|
||||
"""Override to return a CSS string
|
||||
that will be embedded in the page."""
|
||||
return None
|
||||
|
||||
def css_files(self):
|
||||
"""Returns a list of CSS files required by this module."""
|
||||
"""Override to returns a list of CSS files required by this module.
|
||||
|
||||
If the return values are relative paths, they will be passed to
|
||||
`RequestHandler.static_url`; otherwise they will be used as-is.
|
||||
"""
|
||||
return None
|
||||
|
||||
def html_head(self):
|
||||
"""Returns a CSS string that will be put in the <head/> element"""
|
||||
"""Override to return an HTML string that will be put in the <head/>
|
||||
element.
|
||||
"""
|
||||
return None
|
||||
|
||||
def html_body(self):
|
||||
"""Returns an HTML string that will be put in the <body/> element"""
|
||||
"""Override to return an HTML string that will be put at the end of
|
||||
the <body/> element.
|
||||
"""
|
||||
return None
|
||||
|
||||
def render_string(self, path, **kwargs):
|
||||
|
@ -2862,11 +2978,13 @@ else:
|
|||
return result == 0
|
||||
|
||||
|
||||
def create_signed_value(secret, name, value, version=None, clock=None):
|
||||
def create_signed_value(secret, name, value, version=None, clock=None,
|
||||
key_version=None):
|
||||
if version is None:
|
||||
version = DEFAULT_SIGNED_VALUE_VERSION
|
||||
if clock is None:
|
||||
clock = time.time
|
||||
|
||||
timestamp = utf8(str(int(clock())))
|
||||
value = base64.b64encode(utf8(value))
|
||||
if version == 1:
|
||||
|
@ -2883,7 +3001,7 @@ def create_signed_value(secret, name, value, version=None, clock=None):
|
|||
#
|
||||
# The fields are:
|
||||
# - format version (i.e. 2; no length prefix)
|
||||
# - key version (currently 0; reserved for future key rotation features)
|
||||
# - key version (integer, default is 0)
|
||||
# - timestamp (integer seconds since epoch)
|
||||
# - name (not encoded; assumed to be ~alphanumeric)
|
||||
# - value (base64-encoded)
|
||||
|
@ -2891,34 +3009,32 @@ def create_signed_value(secret, name, value, version=None, clock=None):
|
|||
def format_field(s):
|
||||
return utf8("%d:" % len(s)) + utf8(s)
|
||||
to_sign = b"|".join([
|
||||
b"2|1:0",
|
||||
b"2",
|
||||
format_field(str(key_version or 0)),
|
||||
format_field(timestamp),
|
||||
format_field(name),
|
||||
format_field(value),
|
||||
b''])
|
||||
|
||||
if isinstance(secret, dict):
|
||||
assert key_version is not None, 'Key version must be set when sign key dict is used'
|
||||
assert version >= 2, 'Version must be at least 2 for key version support'
|
||||
secret = secret[key_version]
|
||||
|
||||
signature = _create_signature_v2(secret, to_sign)
|
||||
return to_sign + signature
|
||||
else:
|
||||
raise ValueError("Unsupported version %d" % version)
|
||||
|
||||
# A leading version number in decimal with no leading zeros, followed by a pipe.
|
||||
# A leading version number in decimal
|
||||
# with no leading zeros, followed by a pipe.
|
||||
_signed_value_version_re = re.compile(br"^([1-9][0-9]*)\|(.*)$")
|
||||
|
||||
|
||||
def decode_signed_value(secret, name, value, max_age_days=31, clock=None, min_version=None):
|
||||
if clock is None:
|
||||
clock = time.time
|
||||
if min_version is None:
|
||||
min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION
|
||||
if min_version > 2:
|
||||
raise ValueError("Unsupported min_version %d" % min_version)
|
||||
if not value:
|
||||
return None
|
||||
|
||||
# Figure out what version this is. Version 1 did not include an
|
||||
def _get_version(value):
|
||||
# Figures out what version value is. Version 1 did not include an
|
||||
# explicit version field and started with arbitrary base64 data,
|
||||
# which makes this tricky.
|
||||
value = utf8(value)
|
||||
m = _signed_value_version_re.match(value)
|
||||
if m is None:
|
||||
version = 1
|
||||
|
@ -2935,13 +3051,31 @@ def decode_signed_value(secret, name, value, max_age_days=31, clock=None, min_ve
|
|||
version = 1
|
||||
except ValueError:
|
||||
version = 1
|
||||
return version
|
||||
|
||||
|
||||
def decode_signed_value(secret, name, value, max_age_days=31,
|
||||
clock=None, min_version=None):
|
||||
if clock is None:
|
||||
clock = time.time
|
||||
if min_version is None:
|
||||
min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION
|
||||
if min_version > 2:
|
||||
raise ValueError("Unsupported min_version %d" % min_version)
|
||||
if not value:
|
||||
return None
|
||||
|
||||
value = utf8(value)
|
||||
version = _get_version(value)
|
||||
|
||||
if version < min_version:
|
||||
return None
|
||||
if version == 1:
|
||||
return _decode_signed_value_v1(secret, name, value, max_age_days, clock)
|
||||
return _decode_signed_value_v1(secret, name, value,
|
||||
max_age_days, clock)
|
||||
elif version == 2:
|
||||
return _decode_signed_value_v2(secret, name, value, max_age_days, clock)
|
||||
return _decode_signed_value_v2(secret, name, value,
|
||||
max_age_days, clock)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
@ -2964,7 +3098,8 @@ def _decode_signed_value_v1(secret, name, value, max_age_days, clock):
|
|||
# digits from the payload to the timestamp without altering the
|
||||
# signature. For backwards compatibility, sanity-check timestamp
|
||||
# here instead of modifying _cookie_signature.
|
||||
gen_log.warning("Cookie timestamp in future; possible tampering %r", value)
|
||||
gen_log.warning("Cookie timestamp in future; possible tampering %r",
|
||||
value)
|
||||
return None
|
||||
if parts[1].startswith(b"0"):
|
||||
gen_log.warning("Tampered cookie %r", value)
|
||||
|
@ -2975,7 +3110,7 @@ def _decode_signed_value_v1(secret, name, value, max_age_days, clock):
|
|||
return None
|
||||
|
||||
|
||||
def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
|
||||
def _decode_fields_v2(value):
|
||||
def _consume_field(s):
|
||||
length, _, rest = s.partition(b':')
|
||||
n = int(length)
|
||||
|
@ -2986,16 +3121,28 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
|
|||
raise ValueError("malformed v2 signed value field")
|
||||
rest = rest[n + 1:]
|
||||
return field_value, rest
|
||||
|
||||
rest = value[2:] # remove version number
|
||||
try:
|
||||
key_version, rest = _consume_field(rest)
|
||||
timestamp, rest = _consume_field(rest)
|
||||
name_field, rest = _consume_field(rest)
|
||||
value_field, rest = _consume_field(rest)
|
||||
value_field, passed_sig = _consume_field(rest)
|
||||
return int(key_version), timestamp, name_field, value_field, passed_sig
|
||||
|
||||
|
||||
def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
|
||||
try:
|
||||
key_version, timestamp, name_field, value_field, passed_sig = _decode_fields_v2(value)
|
||||
except ValueError:
|
||||
return None
|
||||
passed_sig = rest
|
||||
signed_string = value[:-len(passed_sig)]
|
||||
|
||||
if isinstance(secret, dict):
|
||||
try:
|
||||
secret = secret[key_version]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
expected_sig = _create_signature_v2(secret, signed_string)
|
||||
if not _time_independent_equals(passed_sig, expected_sig):
|
||||
return None
|
||||
|
@ -3011,6 +3158,19 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
|
|||
return None
|
||||
|
||||
|
||||
def get_signature_key_version(value):
|
||||
value = utf8(value)
|
||||
version = _get_version(value)
|
||||
if version < 2:
|
||||
return None
|
||||
try:
|
||||
key_version, _, _, _, _ = _decode_fields_v2(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return key_version
|
||||
|
||||
|
||||
def _create_signature_v1(secret, *parts):
|
||||
hash = hmac.new(utf8(secret), digestmod=hashlib.sha1)
|
||||
for part in parts:
|
||||
|
|
|
@ -16,7 +16,8 @@ the protocol (known as "draft 76") and are not compatible with this module.
|
|||
Removed support for the draft 76 protocol version.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function, with_statement
|
||||
from __future__ import (absolute_import, division,
|
||||
print_function, with_statement)
|
||||
# Author: Jacob Kristhammar, 2010
|
||||
|
||||
import base64
|
||||
|
@ -74,17 +75,22 @@ class WebSocketHandler(tornado.web.RequestHandler):
|
|||
http://tools.ietf.org/html/rfc6455.
|
||||
|
||||
Here is an example WebSocket handler that echos back all received messages
|
||||
back to the client::
|
||||
back to the client:
|
||||
|
||||
class EchoWebSocket(websocket.WebSocketHandler):
|
||||
.. testcode::
|
||||
|
||||
class EchoWebSocket(tornado.websocket.WebSocketHandler):
|
||||
def open(self):
|
||||
print "WebSocket opened"
|
||||
print("WebSocket opened")
|
||||
|
||||
def on_message(self, message):
|
||||
self.write_message(u"You said: " + message)
|
||||
|
||||
def on_close(self):
|
||||
print "WebSocket closed"
|
||||
print("WebSocket closed")
|
||||
|
||||
.. testoutput::
|
||||
:hide:
|
||||
|
||||
WebSockets are not standard HTTP connections. The "handshake" is
|
||||
HTTP, but after the handshake, the protocol is
|
||||
|
@ -129,6 +135,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
|
|||
self.close_code = None
|
||||
self.close_reason = None
|
||||
self.stream = None
|
||||
self._on_close_called = False
|
||||
|
||||
@tornado.web.asynchronous
|
||||
def get(self, *args, **kwargs):
|
||||
|
@ -138,16 +145,22 @@ class WebSocketHandler(tornado.web.RequestHandler):
|
|||
# Upgrade header should be present and should be equal to WebSocket
|
||||
if self.request.headers.get("Upgrade", "").lower() != 'websocket':
|
||||
self.set_status(400)
|
||||
self.finish("Can \"Upgrade\" only to \"WebSocket\".")
|
||||
log_msg = "Can \"Upgrade\" only to \"WebSocket\"."
|
||||
self.finish(log_msg)
|
||||
gen_log.debug(log_msg)
|
||||
return
|
||||
|
||||
# Connection header should be upgrade. Some proxy servers/load balancers
|
||||
# Connection header should be upgrade.
|
||||
# Some proxy servers/load balancers
|
||||
# might mess with it.
|
||||
headers = self.request.headers
|
||||
connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
|
||||
connection = map(lambda s: s.strip().lower(),
|
||||
headers.get("Connection", "").split(","))
|
||||
if 'upgrade' not in connection:
|
||||
self.set_status(400)
|
||||
self.finish("\"Connection\" must be \"Upgrade\".")
|
||||
log_msg = "\"Connection\" must be \"Upgrade\"."
|
||||
self.finish(log_msg)
|
||||
gen_log.debug(log_msg)
|
||||
return
|
||||
|
||||
# Handle WebSocket Origin naming convention differences
|
||||
|
@ -159,30 +172,29 @@ class WebSocketHandler(tornado.web.RequestHandler):
|
|||
else:
|
||||
origin = self.request.headers.get("Sec-Websocket-Origin", None)
|
||||
|
||||
|
||||
# If there was an origin header, check to make sure it matches
|
||||
# according to check_origin. When the origin is None, we assume it
|
||||
# did not come from a browser and that it can be passed on.
|
||||
if origin is not None and not self.check_origin(origin):
|
||||
self.set_status(403)
|
||||
self.finish("Cross origin websockets not allowed")
|
||||
log_msg = "Cross origin websockets not allowed"
|
||||
self.finish(log_msg)
|
||||
gen_log.debug(log_msg)
|
||||
return
|
||||
|
||||
self.stream = self.request.connection.detach()
|
||||
self.stream.set_close_callback(self.on_connection_close)
|
||||
|
||||
if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
|
||||
self.ws_connection = WebSocketProtocol13(
|
||||
self, compression_options=self.get_compression_options())
|
||||
self.ws_connection = self.get_websocket_protocol()
|
||||
if self.ws_connection:
|
||||
self.ws_connection.accept_connection()
|
||||
else:
|
||||
if not self.stream.closed():
|
||||
self.stream.write(tornado.escape.utf8(
|
||||
"HTTP/1.1 426 Upgrade Required\r\n"
|
||||
"Sec-WebSocket-Version: 8\r\n\r\n"))
|
||||
"Sec-WebSocket-Version: 7, 8, 13\r\n\r\n"))
|
||||
self.stream.close()
|
||||
|
||||
|
||||
def write_message(self, message, binary=False):
|
||||
"""Sends the given message to the client of this Web Socket.
|
||||
|
||||
|
@ -229,7 +241,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
|
|||
"""
|
||||
return None
|
||||
|
||||
def open(self):
|
||||
def open(self, *args, **kwargs):
|
||||
"""Invoked when a new WebSocket is opened.
|
||||
|
||||
The arguments to `open` are extracted from the `tornado.web.URLSpec`
|
||||
|
@ -350,6 +362,8 @@ class WebSocketHandler(tornado.web.RequestHandler):
|
|||
if self.ws_connection:
|
||||
self.ws_connection.on_connection_close()
|
||||
self.ws_connection = None
|
||||
if not self._on_close_called:
|
||||
self._on_close_called = True
|
||||
self.on_close()
|
||||
|
||||
def send_error(self, *args, **kwargs):
|
||||
|
@ -362,6 +376,13 @@ class WebSocketHandler(tornado.web.RequestHandler):
|
|||
# we can close the connection more gracefully.
|
||||
self.stream.close()
|
||||
|
||||
def get_websocket_protocol(self):
|
||||
websocket_version = self.request.headers.get("Sec-WebSocket-Version")
|
||||
if websocket_version in ("7", "8", "13"):
|
||||
return WebSocketProtocol13(
|
||||
self, compression_options=self.get_compression_options())
|
||||
|
||||
|
||||
def _wrap_method(method):
|
||||
def _disallow_for_websocket(self, *args, **kwargs):
|
||||
if self.stream is None:
|
||||
|
@ -499,7 +520,8 @@ class WebSocketProtocol13(WebSocketProtocol):
|
|||
self._handle_websocket_headers()
|
||||
self._accept_connection()
|
||||
except ValueError:
|
||||
gen_log.debug("Malformed WebSocket request received", exc_info=True)
|
||||
gen_log.debug("Malformed WebSocket request received",
|
||||
exc_info=True)
|
||||
self._abort()
|
||||
return
|
||||
|
||||
|
@ -535,7 +557,8 @@ class WebSocketProtocol13(WebSocketProtocol):
|
|||
selected = self.handler.select_subprotocol(subprotocols)
|
||||
if selected:
|
||||
assert selected in subprotocols
|
||||
subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
|
||||
subprotocol_header = ("Sec-WebSocket-Protocol: %s\r\n"
|
||||
% selected)
|
||||
|
||||
extension_header = ''
|
||||
extensions = self._parse_extensions_header(self.request.headers)
|
||||
|
@ -703,7 +726,8 @@ class WebSocketProtocol13(WebSocketProtocol):
|
|||
if self._masked_frame:
|
||||
self.stream.read_bytes(4, self._on_masking_key)
|
||||
else:
|
||||
self.stream.read_bytes(self._frame_length, self._on_frame_data)
|
||||
self.stream.read_bytes(self._frame_length,
|
||||
self._on_frame_data)
|
||||
elif payloadlen == 126:
|
||||
self.stream.read_bytes(2, self._on_frame_length_16)
|
||||
elif payloadlen == 127:
|
||||
|
@ -737,7 +761,8 @@ class WebSocketProtocol13(WebSocketProtocol):
|
|||
self._wire_bytes_in += len(data)
|
||||
self._frame_mask = data
|
||||
try:
|
||||
self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
|
||||
self.stream.read_bytes(self._frame_length,
|
||||
self._on_masked_frame_data)
|
||||
except StreamClosedError:
|
||||
self._abort()
|
||||
|
||||
|
@ -852,12 +877,15 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
|
|||
This class should not be instantiated directly; use the
|
||||
`websocket_connect` function instead.
|
||||
"""
|
||||
def __init__(self, io_loop, request, compression_options=None):
|
||||
def __init__(self, io_loop, request, on_message_callback=None,
|
||||
compression_options=None):
|
||||
self.compression_options = compression_options
|
||||
self.connect_future = TracebackFuture()
|
||||
self.protocol = None
|
||||
self.read_future = None
|
||||
self.read_queue = collections.deque()
|
||||
self.key = base64.b64encode(os.urandom(16))
|
||||
self._on_message_callback = on_message_callback
|
||||
|
||||
scheme, sep, rest = request.url.partition(':')
|
||||
scheme = {'ws': 'http', 'wss': 'https'}[scheme]
|
||||
|
@ -880,7 +908,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
|
|||
self.tcp_client = TCPClient(io_loop=io_loop)
|
||||
super(WebSocketClientConnection, self).__init__(
|
||||
io_loop, None, request, lambda: None, self._on_http_response,
|
||||
104857600, self.tcp_client, 65536)
|
||||
104857600, self.tcp_client, 65536, 104857600)
|
||||
|
||||
def close(self, code=None, reason=None):
|
||||
"""Closes the websocket connection.
|
||||
|
@ -919,9 +947,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
|
|||
start_line, headers)
|
||||
|
||||
self.headers = headers
|
||||
self.protocol = WebSocketProtocol13(
|
||||
self, mask_outgoing=True,
|
||||
compression_options=self.compression_options)
|
||||
self.protocol = self.get_websocket_protocol()
|
||||
self.protocol._process_server_headers(self.key, self.headers)
|
||||
self.protocol._receive_frame()
|
||||
|
||||
|
@ -946,6 +972,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
|
|||
def read_message(self, callback=None):
|
||||
"""Reads a message from the WebSocket server.
|
||||
|
||||
If on_message_callback was specified at WebSocket
|
||||
initialization, this function will never return messages
|
||||
|
||||
Returns a future whose result is the message, or None
|
||||
if the connection is closed. If a callback argument
|
||||
is given it will be called with the future when it is
|
||||
|
@ -962,7 +991,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
|
|||
return future
|
||||
|
||||
def on_message(self, message):
|
||||
if self.read_future is not None:
|
||||
if self._on_message_callback:
|
||||
self._on_message_callback(message)
|
||||
elif self.read_future is not None:
|
||||
self.read_future.set_result(message)
|
||||
self.read_future = None
|
||||
else:
|
||||
|
@ -971,9 +1002,13 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
|
|||
def on_pong(self, data):
|
||||
pass
|
||||
|
||||
def get_websocket_protocol(self):
|
||||
return WebSocketProtocol13(self, mask_outgoing=True,
|
||||
compression_options=self.compression_options)
|
||||
|
||||
|
||||
def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
|
||||
compression_options=None):
|
||||
on_message_callback=None, compression_options=None):
|
||||
"""Client-side websocket support.
|
||||
|
||||
Takes a url and returns a Future whose result is a
|
||||
|
@ -982,11 +1017,26 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
|
|||
``compression_options`` is interpreted in the same way as the
|
||||
return value of `.WebSocketHandler.get_compression_options`.
|
||||
|
||||
The connection supports two styles of operation. In the coroutine
|
||||
style, the application typically calls
|
||||
`~.WebSocketClientConnection.read_message` in a loop::
|
||||
|
||||
conn = yield websocket_connection(loop)
|
||||
while True:
|
||||
msg = yield conn.read_message()
|
||||
if msg is None: break
|
||||
# Do something with msg
|
||||
|
||||
In the callback style, pass an ``on_message_callback`` to
|
||||
``websocket_connect``. In both styles, a message of ``None``
|
||||
indicates that the connection has been closed.
|
||||
|
||||
.. versionchanged:: 3.2
|
||||
Also accepts ``HTTPRequest`` objects in place of urls.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
Added ``compression_options``.
|
||||
Added ``compression_options`` and ``on_message_callback``.
|
||||
The ``io_loop`` argument is deprecated.
|
||||
"""
|
||||
if io_loop is None:
|
||||
io_loop = IOLoop.current()
|
||||
|
@ -1000,7 +1050,9 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
|
|||
request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
|
||||
request = httpclient._RequestProxy(
|
||||
request, httpclient.HTTPRequest._DEFAULTS)
|
||||
conn = WebSocketClientConnection(io_loop, request, compression_options)
|
||||
conn = WebSocketClientConnection(io_loop, request,
|
||||
on_message_callback=on_message_callback,
|
||||
compression_options=compression_options)
|
||||
if callback is not None:
|
||||
io_loop.add_future(conn.connect_future, callback)
|
||||
return conn.connect_future
|
||||
|
|
|
@ -207,7 +207,7 @@ class WSGIAdapter(object):
|
|||
body = environ["wsgi.input"].read(
|
||||
int(headers["Content-Length"]))
|
||||
else:
|
||||
body = ""
|
||||
body = b""
|
||||
protocol = environ["wsgi.url_scheme"]
|
||||
remote_ip = environ.get("REMOTE_ADDR", "")
|
||||
if environ.get("HTTP_HOST"):
|
||||
|
@ -253,7 +253,7 @@ class WSGIContainer(object):
|
|||
container = tornado.wsgi.WSGIContainer(simple_app)
|
||||
http_server = tornado.httpserver.HTTPServer(container)
|
||||
http_server.listen(8888)
|
||||
tornado.ioloop.IOLoop.instance().start()
|
||||
tornado.ioloop.IOLoop.current().start()
|
||||
|
||||
This class is intended to let other frameworks (Django, web.py, etc)
|
||||
run on the Tornado HTTP server and I/O loop.
|
||||
|
@ -284,7 +284,8 @@ class WSGIContainer(object):
|
|||
if not data:
|
||||
raise Exception("WSGI app did not call start_response")
|
||||
|
||||
status_code = int(data["status"].split()[0])
|
||||
status_code, reason = data["status"].split(' ', 1)
|
||||
status_code = int(status_code)
|
||||
headers = data["headers"]
|
||||
header_set = set(k.lower() for (k, v) in headers)
|
||||
body = escape.utf8(body)
|
||||
|
@ -296,13 +297,12 @@ class WSGIContainer(object):
|
|||
if "server" not in header_set:
|
||||
headers.append(("Server", "TornadoServer/%s" % tornado.version))
|
||||
|
||||
parts = [escape.utf8("HTTP/1.1 " + data["status"] + "\r\n")]
|
||||
start_line = httputil.ResponseStartLine("HTTP/1.1", status_code, reason)
|
||||
header_obj = httputil.HTTPHeaders()
|
||||
for key, value in headers:
|
||||
parts.append(escape.utf8(key) + b": " + escape.utf8(value) + b"\r\n")
|
||||
parts.append(b"\r\n")
|
||||
parts.append(body)
|
||||
request.write(b"".join(parts))
|
||||
request.finish()
|
||||
header_obj.add(key, value)
|
||||
request.connection.write_headers(start_line, header_obj, chunk=body)
|
||||
request.connection.finish()
|
||||
self._log(status_code, request)
|
||||
|
||||
@staticmethod
|
||||
|
|
Loading…
Reference in a new issue