Merge remote-tracking branch 'origin/dev'

This commit is contained in:
echel0n 2014-06-17 11:17:30 -07:00
commit e65fb9d09b
89 changed files with 5797 additions and 1696 deletions

View file

@ -389,9 +389,9 @@ def main():
io_loop.add_timeout(datetime.timedelta(seconds=5), startup)
# autoreload.
tornado.autoreload.add_reload_hook(autoreload_shutdown)
if sickbeard.AUTO_UPDATE:
tornado.autoreload.start(io_loop)
tornado.autoreload.add_reload_hook(autoreload_shutdown)
# start IOLoop.
io_loop.start()

View file

Before

Width:  |  Height:  |  Size: 1.5 KiB

After

Width:  |  Height:  |  Size: 1.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 977 B

After

Width:  |  Height:  |  Size: 1.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 977 B

After

Width:  |  Height:  |  Size: 1.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 977 B

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 977 B

After

Width:  |  Height:  |  Size: 1.4 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 977 B

After

Width:  |  Height:  |  Size: 1.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1 KiB

After

Width:  |  Height:  |  Size: 815 B

View file

Before

Width:  |  Height:  |  Size: 1 KiB

After

Width:  |  Height:  |  Size: 1 KiB

View file

Before

Width:  |  Height:  |  Size: 1 KiB

After

Width:  |  Height:  |  Size: 1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.1 KiB

After

Width:  |  Height:  |  Size: 954 B

View file

Before

Width:  |  Height:  |  Size: 1.5 KiB

After

Width:  |  Height:  |  Size: 1.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.7 KiB

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 915 B

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 986 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 1,005 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 KiB

View file

Before

Width:  |  Height:  |  Size: 1.5 KiB

After

Width:  |  Height:  |  Size: 1.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.3 KiB

After

Width:  |  Height:  |  Size: 1.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.3 KiB

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

View file

@ -1824,7 +1824,9 @@ def getEpList(epIDs, showid=None):
def autoreload_shutdown():
logger.log('SickRage is now auto-reloading, please stand by ...')
webserveInit.server.stop()
# halt all tasks
halt()
saveAll()
cleanup_tornado_sockets(IOLoop.current())
# save settings
saveAll()

View file

@ -101,11 +101,11 @@ def foldersAtPath(path, includeParent=False):
class WebFileBrowser(RequestHandler):
def index(self, path=''):
def index(self, path='', *args, **kwargs):
self.set_header("Content-Type", "application/json")
return self.finish(json.dumps(foldersAtPath(path, True)))
return json.dumps(foldersAtPath(path, True))
def complete(self, term):
self.set_header("Content-Type", "application/json")
paths = [entry['path'] for entry in foldersAtPath(os.path.dirname(term)) if 'path' in entry]
return self.finish(json.dumps(paths))
return json.dumps(paths)

View file

@ -88,4 +88,4 @@ class AddSceneExceptionsRefresh(AddSceneExceptionsCustom):
def execute(self):
self.connection.action(
"CREATE TABLE scene_exceptions_refresh (list TEXT, last_refreshed INTEGER)")
"CREATE TABLE scene_exceptions_refresh (list TEXT PRIMARY KEY, last_refreshed INTEGER)")

View file

@ -78,7 +78,8 @@ def logFailed(release):
provider = sql_results[0]["provider"]
if not hasFailed(release, size, provider):
myDB.action("INSERT INTO failed (release, size, provider) VALUES (?, ?, ?)", [release, size, provider])
with db.DBConnection('failed.db') as myDB:
myDB.action("INSERT INTO failed (release, size, provider) VALUES (?, ?, ?)", [release, size, provider])
deleteLoggedSnatch(release, size, provider)

View file

@ -48,6 +48,10 @@ reverseNames = {u'ERROR': ERROR,
u'DEBUG': DEBUG,
u'DB': DB}
# send logging to null
class NullHandler(logging.Handler):
def emit(self, record):
pass
class SBRotatingLogHandler(object):
def __init__(self, log_file, num_files, num_bytes):
@ -143,8 +147,7 @@ class SBRotatingLogHandler(object):
logging.getLogger('subliminal').setLevel(log_level)
logging.getLogger('imdbpy').setLevel(log_level)
# send logging to null
logging.getLogger('tornado.access').addHandler(logging.NullHandler())
logging.getLogger('tornado.access').addHandler(NullHandler())
# already logging in new log folder, close the old handler
if old_handler:

View file

@ -50,7 +50,7 @@ class TorrentRssProvider(generic.TorrentProvider):
if cookies:
self.cookies = cookies
else:
self.cookies = None
self.cookies = ''
def configStr(self):
return self.name + '|' + self.url + '|' + self.cookies + '|' + str(int(self.enabled)) + '|' + self.search_mode + '|' + str(int(self.search_fallback)) + '|' + str(int(self.backlog_only))

View file

@ -23,6 +23,7 @@ import inspect
import os.path
import time
import traceback
import urllib
import re
import threading
@ -91,35 +92,41 @@ from tornado.ioloop import IOLoop
req_headers = None
def require_basic_auth(handler_class):
def basicauth(handler_class):
def wrap_execute(handler_execute):
def require_basic_auth(handler, kwargs):
def get_auth():
def basicauth(handler, transforms, *args, **kwargs):
def _request_basic_auth(handler):
handler.set_status(401)
handler.set_header('WWW-Authenticate', 'Basic realm=Restricted')
handler._transforms = []
handler.finish()
return False
if not sickbeard.WEB_USERNAME and not sickbeard.WEB_PASSWORD:
if not handler.get_secure_cookie("user"):
handler.set_secure_cookie("user", str(time.time()))
return True
try:
auth_hdr = handler.request.headers.get('Authorization')
auth_header = handler.request.headers.get('Authorization')
if auth_header and auth_header.startswith('Basic '):
auth_decoded = base64.decodestring(auth_header[6:])
basicauth_user, basicauth_pass = auth_decoded.split(':', 2)
if basicauth_user == sickbeard.WEB_USERNAME and basicauth_pass == sickbeard.WEB_PASSWORD:
if auth_hdr == None:
return _request_basic_auth(handler)
if not auth_hdr.startswith('Basic '):
return _request_basic_auth(handler)
auth_decoded = base64.decodestring(auth_hdr[6:])
username, password = auth_decoded.split(':', 2)
if username == sickbeard.WEB_USERNAME and password == sickbeard.WEB_PASSWORD:
#logger.log('authenticated user successfully', logger.DEBUG)
if not handler.get_secure_cookie("user"):
handler.set_secure_cookie("user", str(time.time()))
return True
handler.clear_cookie("user")
get_auth()
else:
if handler.get_secure_cookie("user"):
handler.clear_cookie("user")
return _request_basic_auth(handler)
except Exception, e:
handler.clear_cookie("user")
return _request_basic_auth(handler)
return True
def _execute(self, transforms, *args, **kwargs):
if not require_basic_auth(self, kwargs):
if not basicauth(self, transforms, *args, **kwargs):
return False
return handler_execute(self, transforms, *args, **kwargs)
@ -128,12 +135,12 @@ def require_basic_auth(handler_class):
handler_class._execute = wrap_execute(handler_class._execute)
return handler_class
@require_basic_auth
class RedirectHandler(RequestHandler):
def get(self, path, **kwargs):
self.redirect(path, permanent=True)
@basicauth
class IndexHandler(RedirectHandler):
def __init__(self, application, request, **kwargs):
super(IndexHandler, self).__init__(application, request, **kwargs)
@ -154,7 +161,7 @@ class IndexHandler(RedirectHandler):
args[arg] = value[0]
return args
def _dispatch(self):
def _dispatch(self, callback):
args = None
path = self.request.uri.split('?')[0]
@ -193,32 +200,35 @@ class IndexHandler(RedirectHandler):
if func:
if args:
return func(**args)
callback(func(**args))
else:
return func()
callback(func())
if self.request.uri != ('/'):
raise HTTPError(404)
def get_response(self):
raise gen.Return('hello')
callback(HTTPError(404))
def get_current_user(self):
return self.get_secure_cookie("user")
@authenticated
@asynchronous
@gen.coroutine
@gen.engine
def get(self, *args, **kwargs):
resp = yield self.get_response()
self.finish(resp)
@gen.coroutine
def get_response(self):
raise gen.Return(self._dispatch())
try:
result = yield gen.Task(self._dispatch)
self.finish(result)
except Exception as e:
logger.log(ex(e), logger.ERROR)
logger.log(u"Traceback: " + traceback.format_exc(), logger.DEBUG)
self.finish(ex(e))
def post(self, *args, **kwargs):
self.finish(self._dispatch())
try:
result = yield gen.Task(self._dispatch)
self.finish(result)
except Exception as e:
logger.log(ex(e), logger.ERROR)
logger.log(u"Traceback: " + traceback.format_exc(), logger.DEBUG)
self.finish(ex(e))
def robots_txt(self, *args, **kwargs):
""" Keep web crawlers out """
@ -542,7 +552,6 @@ def _getEpisode(show, season=None, episode=None, absolute=None):
return epObj
def ManageMenu():
manageMenu = [
{'title': 'Backlog Overview', 'path': 'manage/backlogOverview/'},
@ -617,17 +626,6 @@ class ManageSearches(IndexHandler):
self.redirect("/manage/manageSearches/")
def forceVersionCheck(self, *args, **kwargs):
# force a check to see if there is a new version
result = sickbeard.versionCheckScheduler.action.check_for_new_version(force=True) # @UndefinedVariable
if result:
logger.log(u"Forcing version check")
self.redirect("/manage/manageSearches/")
class Manage(IndexHandler):
def index(self, *args, **kwargs):
t = PageTemplate(file="manage.tmpl")
@ -2477,6 +2475,14 @@ class HomePostProcess(IndexHandler):
return _munge(t)
def forceVersionCheck(self, *args, **kwargs):
# force a check to see if there is a new version
if sickbeard.versionCheckScheduler.action.check_for_new_version(force=True):
logger.log(u"Forcing version check")
self.redirect("/home/")
def processEpisode(self, dir=None, nzbName=None, jobName=None, quiet=None, process_method=None, force=None,
is_priority=None, failed="0", type="auto"):
@ -3304,7 +3310,6 @@ class Home(IndexHandler):
# auto-reload
tornado.autoreload.start(IOLoop.current())
tornado.autoreload.add_reload_hook(sickbeard.autoreload_shutdown)
updated = sickbeard.versionCheckScheduler.action.update() # @UndefinedVariable

View file

@ -13,6 +13,7 @@ from tornado.ioloop import IOLoop
server = None
class MultiStaticFileHandler(StaticFileHandler):
def initialize(self, paths, default_filename=None):
self.paths = paths
@ -33,6 +34,7 @@ class MultiStaticFileHandler(StaticFileHandler):
# Oops file not found anywhere!
raise HTTPError(404)
def initWebServer(options={}):
options.setdefault('port', 8081)
options.setdefault('host', '0.0.0.0')
@ -100,7 +102,6 @@ def initWebServer(options={}):
app = Application([],
debug=sickbeard.DEBUG,
gzip=True,
autoreload=sickbeard.AUTO_UPDATE,
xheaders=True,
cookie_secret='61oETzKXQAGaYdkL5gEmGeJJFuYh7EQnp2XdTP1o/Vo=',
login_url='/login'
@ -116,13 +117,15 @@ def initWebServer(options={}):
# Static Path Handler
app.add_handlers(".*$", [
('%s/%s/(.*)([^/]*)' % (options['web_root'], 'images'), MultiStaticFileHandler,
(r'/(favicon\.ico)', MultiStaticFileHandler,
{'paths': '%s/%s' % (options['web_root'], 'images/ico/favicon.ico')}),
(r'%s/%s/(.*)(/?)' % (options['web_root'], 'images'), MultiStaticFileHandler,
{'paths': [os.path.join(options['data_root'], 'images'),
os.path.join(sickbeard.CACHE_DIR, 'images'),
os.path.join(sickbeard.CACHE_DIR, 'images', 'thumbnails')]}),
('%s/%s/(.*)([^/]*)' % (options['web_root'], 'css'), MultiStaticFileHandler,
(r'%s/%s/(.*)(/?)' % (options['web_root'], 'css'), MultiStaticFileHandler,
{'paths': [os.path.join(options['data_root'], 'css')]}),
('%s/%s/(.*)([^/]*)' % (options['web_root'], 'js'), MultiStaticFileHandler,
(r'%s/%s/(.*)(/?)' % (options['web_root'], 'js'), MultiStaticFileHandler,
{'paths': [os.path.join(options['data_root'], 'js')]})
])
@ -132,7 +135,7 @@ def initWebServer(options={}):
if enable_https:
protocol = "https"
server = HTTPServer(app, no_keep_alive=True,
ssl_options={"certfile": https_cert, "keyfile": https_key})
ssl_options={"certfile": https_cert, "keyfile": https_key})
else:
protocol = "http"
server = HTTPServer(app, no_keep_alive=True)
@ -140,7 +143,11 @@ def initWebServer(options={}):
logger.log(u"Starting SickRage on " + protocol + "://" + str(options['host']) + ":" + str(
options['port']) + "/")
server.listen(options['port'], options['host'])
try:
server.listen(options['port'], options['host'])
except:
pass
def shutdown():
global server
@ -148,7 +155,6 @@ def shutdown():
logger.log('Shutting down tornado')
try:
IOLoop.current().stop()
server.stop()
except RuntimeError:
pass
except:

View file

@ -25,5 +25,5 @@ from __future__ import absolute_import, division, print_function, with_statement
# is zero for an official release, positive for a development branch,
# or negative for a release candidate or beta (after the base version
# number has been incremented)
version = "3.2.2"
version_info = (3, 2, 2, 0)
version = "4.0.dev1"
version_info = (4, 0, 0, -100)

View file

@ -34,15 +34,29 @@ See the individual service classes below for complete documentation.
Example usage for Google OpenID::
class GoogleLoginHandler(tornado.web.RequestHandler,
tornado.auth.GoogleMixin):
class GoogleOAuth2LoginHandler(tornado.web.RequestHandler,
tornado.auth.GoogleOAuth2Mixin):
@tornado.gen.coroutine
def get(self):
if self.get_argument("openid.mode", None):
user = yield self.get_authenticated_user()
# Save the user with e.g. set_secure_cookie()
if self.get_argument('code', False):
user = yield self.get_authenticated_user(
redirect_uri='http://your.site.com/auth/google',
code=self.get_argument('code'))
# Save the user with e.g. set_secure_cookie
else:
yield self.authenticate_redirect()
yield self.authorize_redirect(
redirect_uri='http://your.site.com/auth/google',
client_id=self.settings['google_oauth']['key'],
scope=['profile', 'email'],
response_type='code',
extra_params={'approval_prompt': 'auto'})
.. versionchanged:: 4.0
All of the callback interfaces in this module are now guaranteed
to run their callback with an argument of ``None`` on error.
Previously some functions would do this while others would simply
terminate the request on their own. This change also ensures that
errors are more consistently reported through the ``Future`` interfaces.
"""
from __future__ import absolute_import, division, print_function, with_statement
@ -61,6 +75,7 @@ from tornado import httpclient
from tornado import escape
from tornado.httputil import url_concat
from tornado.log import gen_log
from tornado.stack_context import ExceptionStackContext
from tornado.util import bytes_type, u, unicode_type, ArgReplacer
try:
@ -73,6 +88,11 @@ try:
except ImportError:
import urllib as urllib_parse # py2
try:
long # py2
except NameError:
long = int # py3
class AuthError(Exception):
pass
@ -103,7 +123,14 @@ def _auth_return_future(f):
if callback is not None:
future.add_done_callback(
functools.partial(_auth_future_to_callback, callback))
f(*args, **kwargs)
def handle_exception(typ, value, tb):
if future.done():
return False
else:
future.set_exc_info((typ, value, tb))
return True
with ExceptionStackContext(handle_exception):
f(*args, **kwargs)
return future
return wrapper
@ -161,7 +188,7 @@ class OpenIdMixin(object):
url = self._OPENID_ENDPOINT
if http_client is None:
http_client = self.get_auth_http_client()
http_client.fetch(url, self.async_callback(
http_client.fetch(url, functools.partial(
self._on_authentication_verified, callback),
method="POST", body=urllib_parse.urlencode(args))
@ -333,7 +360,7 @@ class OAuthMixin(object):
http_client.fetch(
self._oauth_request_token_url(callback_uri=callback_uri,
extra_params=extra_params),
self.async_callback(
functools.partial(
self._on_request_token,
self._OAUTH_AUTHORIZE_URL,
callback_uri,
@ -341,7 +368,7 @@ class OAuthMixin(object):
else:
http_client.fetch(
self._oauth_request_token_url(),
self.async_callback(
functools.partial(
self._on_request_token, self._OAUTH_AUTHORIZE_URL,
callback_uri,
callback))
@ -378,7 +405,7 @@ class OAuthMixin(object):
if http_client is None:
http_client = self.get_auth_http_client()
http_client.fetch(self._oauth_access_token_url(token),
self.async_callback(self._on_access_token, callback))
functools.partial(self._on_access_token, callback))
def _oauth_request_token_url(self, callback_uri=None, extra_params=None):
consumer_token = self._oauth_consumer_token()
@ -455,7 +482,7 @@ class OAuthMixin(object):
access_token = _oauth_parse_response(response.body)
self._oauth_get_user_future(access_token).add_done_callback(
self.async_callback(self._on_oauth_get_user, access_token, future))
functools.partial(self._on_oauth_get_user, access_token, future))
def _oauth_consumer_token(self):
"""Subclasses must override this to return their OAuth consumer keys.
@ -640,7 +667,7 @@ class TwitterMixin(OAuthMixin):
"""
http = self.get_auth_http_client()
http.fetch(self._oauth_request_token_url(callback_uri=callback_uri),
self.async_callback(
functools.partial(
self._on_request_token, self._OAUTH_AUTHENTICATE_URL,
None, callback))
@ -698,7 +725,7 @@ class TwitterMixin(OAuthMixin):
if args:
url += "?" + urllib_parse.urlencode(args)
http = self.get_auth_http_client()
http_callback = self.async_callback(self._on_twitter_request, callback)
http_callback = functools.partial(self._on_twitter_request, callback)
if post_args is not None:
http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args),
callback=http_callback)
@ -815,7 +842,7 @@ class FriendFeedMixin(OAuthMixin):
args.update(oauth)
if args:
url += "?" + urllib_parse.urlencode(args)
callback = self.async_callback(self._on_friendfeed_request, callback)
callback = functools.partial(self._on_friendfeed_request, callback)
http = self.get_auth_http_client()
if post_args is not None:
http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args),
@ -856,6 +883,10 @@ class FriendFeedMixin(OAuthMixin):
class GoogleMixin(OpenIdMixin, OAuthMixin):
"""Google Open ID / OAuth authentication.
*Deprecated:* New applications should use `GoogleOAuth2Mixin`
below instead of this class. As of May 19, 2014, Google has stopped
supporting registration-free authentication.
No application registration is necessary to use Google for
authentication or to access Google resources on behalf of a user.
@ -926,7 +957,7 @@ class GoogleMixin(OpenIdMixin, OAuthMixin):
http = self.get_auth_http_client()
token = dict(key=token, secret="")
http.fetch(self._oauth_access_token_url(token),
self.async_callback(self._on_access_token, callback))
functools.partial(self._on_access_token, callback))
else:
chain_future(OpenIdMixin.get_authenticated_user(self),
callback)
@ -945,6 +976,19 @@ class GoogleMixin(OpenIdMixin, OAuthMixin):
class GoogleOAuth2Mixin(OAuth2Mixin):
"""Google authentication using OAuth2.
In order to use, register your application with Google and copy the
relevant parameters to your application settings.
* Go to the Google Dev Console at http://console.developers.google.com
* Select a project, or create a new one.
* In the sidebar on the left, select APIs & Auth.
* In the list of APIs, find the Google+ API service and set it to ON.
* In the sidebar on the left, select Credentials.
* In the OAuth section of the page, select Create New Client ID.
* Set the Redirect URI to point to your auth handler
* Copy the "Client secret" and "Client ID" to the application settings as
{"google_oauth": {"key": CLIENT_ID, "secret": CLIENT_SECRET}}
.. versionadded:: 3.2
"""
_OAUTH_AUTHORIZE_URL = "https://accounts.google.com/o/oauth2/auth"
@ -958,7 +1002,7 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
Example usage::
class GoogleOAuth2LoginHandler(LoginHandler,
class GoogleOAuth2LoginHandler(tornado.web.RequestHandler,
tornado.auth.GoogleOAuth2Mixin):
@tornado.gen.coroutine
def get(self):
@ -985,7 +1029,7 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
})
http.fetch(self._OAUTH_ACCESS_TOKEN_URL,
self.async_callback(self._on_access_token, callback),
functools.partial(self._on_access_token, callback),
method="POST", headers={'Content-Type': 'application/x-www-form-urlencoded'}, body=body)
def _on_access_token(self, future, response):
@ -1026,7 +1070,7 @@ class FacebookMixin(object):
@tornado.web.asynchronous
def get(self):
if self.get_argument("session", None):
self.get_authenticated_user(self.async_callback(self._on_auth))
self.get_authenticated_user(self._on_auth)
return
yield self.authenticate_redirect()
@ -1112,7 +1156,7 @@ class FacebookMixin(object):
session = escape.json_decode(self.get_argument("session"))
self.facebook_request(
method="facebook.users.getInfo",
callback=self.async_callback(
callback=functools.partial(
self._on_get_user_info, callback, session),
session_key=session["session_key"],
uids=session["uid"],
@ -1138,7 +1182,7 @@ class FacebookMixin(object):
def get(self):
self.facebook_request(
method="stream.get",
callback=self.async_callback(self._on_stream),
callback=self._on_stream,
session_key=self.current_user["session_key"])
def _on_stream(self, stream):
@ -1162,7 +1206,7 @@ class FacebookMixin(object):
url = "http://api.facebook.com/restserver.php?" + \
urllib_parse.urlencode(args)
http = self.get_auth_http_client()
http.fetch(url, callback=self.async_callback(
http.fetch(url, callback=functools.partial(
self._parse_response, callback))
def _on_get_user_info(self, callback, session, users):
@ -1260,7 +1304,7 @@ class FacebookGraphMixin(OAuth2Mixin):
fields.update(extra_fields)
http.fetch(self._oauth_request_token_url(**args),
self.async_callback(self._on_access_token, redirect_uri, client_id,
functools.partial(self._on_access_token, redirect_uri, client_id,
client_secret, callback, fields))
def _on_access_token(self, redirect_uri, client_id, client_secret,
@ -1277,7 +1321,7 @@ class FacebookGraphMixin(OAuth2Mixin):
self.facebook_request(
path="/me",
callback=self.async_callback(
callback=functools.partial(
self._on_get_user_info, future, session, fields),
access_token=session["access_token"],
fields=",".join(fields)
@ -1344,7 +1388,7 @@ class FacebookGraphMixin(OAuth2Mixin):
if all_args:
url += "?" + urllib_parse.urlencode(all_args)
callback = self.async_callback(self._on_facebook_request, callback)
callback = functools.partial(self._on_facebook_request, callback)
http = self.get_auth_http_client()
if post_args is not None:
http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args),

View file

@ -14,7 +14,7 @@
# License for the specific language governing permissions and limitations
# under the License.
"""xAutomatically restart the server when a source file is modified.
"""Automatically restart the server when a source file is modified.
Most applications should not access this module directly. Instead,
pass the keyword argument ``autoreload=True`` to the

View file

@ -40,52 +40,132 @@ class ReturnValueIgnoredError(Exception):
pass
class _DummyFuture(object):
class Future(object):
"""Placeholder for an asynchronous result.
A ``Future`` encapsulates the result of an asynchronous
operation. In synchronous applications ``Futures`` are used
to wait for the result from a thread or process pool; in
Tornado they are normally used with `.IOLoop.add_future` or by
yielding them in a `.gen.coroutine`.
`tornado.concurrent.Future` is similar to
`concurrent.futures.Future`, but not thread-safe (and therefore
faster for use with single-threaded event loops).
In addition to ``exception`` and ``set_exception``, methods ``exc_info``
and ``set_exc_info`` are supported to capture tracebacks in Python 2.
The traceback is automatically available in Python 3, but in the
Python 2 futures backport this information is discarded.
This functionality was previously available in a separate class
``TracebackFuture``, which is now a deprecated alias for this class.
.. versionchanged:: 4.0
`tornado.concurrent.Future` is always a thread-unsafe ``Future``
with support for the ``exc_info`` methods. Previously it would
be an alias for the thread-safe `concurrent.futures.Future`
if that package was available and fall back to the thread-unsafe
implementation if it was not.
"""
def __init__(self):
self._done = False
self._result = None
self._exception = None
self._exc_info = None
self._callbacks = []
def cancel(self):
"""Cancel the operation, if possible.
Tornado ``Futures`` do not support cancellation, so this method always
returns False.
"""
return False
def cancelled(self):
"""Returns True if the operation has been cancelled.
Tornado ``Futures`` do not support cancellation, so this method
always returns False.
"""
return False
def running(self):
"""Returns True if this operation is currently running."""
return not self._done
def done(self):
"""Returns True if the future has finished running."""
return self._done
def result(self, timeout=None):
self._check_done()
if self._exception:
"""If the operation succeeded, return its result. If it failed,
re-raise its exception.
"""
if self._result is not None:
return self._result
if self._exc_info is not None:
raise_exc_info(self._exc_info)
elif self._exception is not None:
raise self._exception
self._check_done()
return self._result
def exception(self, timeout=None):
self._check_done()
if self._exception:
"""If the operation raised an exception, return the `Exception`
object. Otherwise returns None.
"""
if self._exception is not None:
return self._exception
else:
self._check_done()
return None
def add_done_callback(self, fn):
"""Attaches the given callback to the `Future`.
It will be invoked with the `Future` as its argument when the Future
has finished running and its result is available. In Tornado
consider using `.IOLoop.add_future` instead of calling
`add_done_callback` directly.
"""
if self._done:
fn(self)
else:
self._callbacks.append(fn)
def set_result(self, result):
"""Sets the result of a ``Future``.
It is undefined to call any of the ``set`` methods more than once
on the same object.
"""
self._result = result
self._set_done()
def set_exception(self, exception):
"""Sets the exception of a ``Future.``"""
self._exception = exception
self._set_done()
def exc_info(self):
"""Returns a tuple in the same format as `sys.exc_info` or None.
.. versionadded:: 4.0
"""
return self._exc_info
def set_exc_info(self, exc_info):
"""Sets the exception information of a ``Future.``
Preserves tracebacks on Python 2.
.. versionadded:: 4.0
"""
self._exc_info = exc_info
self.set_exception(exc_info[1])
def _check_done(self):
if not self._done:
raise Exception("DummyFuture does not support blocking for results")
@ -97,38 +177,16 @@ class _DummyFuture(object):
cb(self)
self._callbacks = None
TracebackFuture = Future
if futures is None:
Future = _DummyFuture
FUTURES = Future
else:
Future = futures.Future
FUTURES = (futures.Future, Future)
class TracebackFuture(Future):
"""Subclass of `Future` which can store a traceback with
exceptions.
The traceback is automatically available in Python 3, but in the
Python 2 futures backport this information is discarded.
"""
def __init__(self):
super(TracebackFuture, self).__init__()
self.__exc_info = None
def exc_info(self):
return self.__exc_info
def set_exc_info(self, exc_info):
"""Traceback-aware replacement for
`~concurrent.futures.Future.set_exception`.
"""
self.__exc_info = exc_info
self.set_exception(exc_info[1])
def result(self, timeout=None):
if self.__exc_info is not None:
raise_exc_info(self.__exc_info)
else:
return super(TracebackFuture, self).result(timeout=timeout)
def is_future(x):
return isinstance(x, FUTURES)
class DummyExecutor(object):
@ -254,10 +312,13 @@ def return_future(f):
def chain_future(a, b):
"""Chain two futures together so that when one completes, so does the other.
The result (success or failure) of ``a`` will be copied to ``b``.
The result (success or failure) of ``a`` will be copied to ``b``, unless
``b`` has already been completed or cancelled by the time ``a`` finishes.
"""
def copy(future):
assert future is a
if b.done():
return
if (isinstance(a, TracebackFuture) and isinstance(b, TracebackFuture)
and a.exc_info() is not None):
b.set_exc_info(a.exc_info())

View file

@ -87,7 +87,6 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
for curl in self._curls:
curl.close()
self._multi.close()
self._closed = True
super(CurlAsyncHTTPClient, self).close()
def fetch_impl(self, request, callback):
@ -268,6 +267,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
info["callback"](HTTPResponse(
request=info["request"], code=code, headers=info["headers"],
buffer=buffer, effective_url=effective_url, error=error,
reason=info['headers'].get("X-Http-Reason", None),
request_time=time.time() - info["curl_start_time"],
time_info=time_info))
except Exception:
@ -470,7 +470,11 @@ def _curl_header_callback(headers, header_line):
header_line = header_line.strip()
if header_line.startswith("HTTP/"):
headers.clear()
return
try:
(__, __, reason) = httputil.parse_response_start_line(header_line)
header_line = "X-Http-Reason: %s" % reason
except httputil.HTTPInputError:
return
if not header_line:
return
headers.parse_line(header_line)

View file

@ -87,9 +87,9 @@ import itertools
import sys
import types
from tornado.concurrent import Future, TracebackFuture
from tornado.concurrent import Future, TracebackFuture, is_future, chain_future
from tornado.ioloop import IOLoop
from tornado.stack_context import ExceptionStackContext, wrap
from tornado import stack_context
class KeyReuseError(Exception):
@ -112,6 +112,10 @@ class ReturnValueIgnoredError(Exception):
pass
class TimeoutError(Exception):
"""Exception raised by ``with_timeout``."""
def engine(func):
"""Callback-oriented decorator for asynchronous generators.
@ -129,45 +133,20 @@ def engine(func):
`~tornado.web.RequestHandler` :ref:`HTTP verb methods <verbs>`,
which use ``self.finish()`` in place of a callback argument.
"""
func = _make_coroutine_wrapper(func, replace_callback=False)
@functools.wraps(func)
def wrapper(*args, **kwargs):
runner = None
def handle_exception(typ, value, tb):
# if the function throws an exception before its first "yield"
# (or is not a generator at all), the Runner won't exist yet.
# However, in that case we haven't reached anything asynchronous
# yet, so we can just let the exception propagate.
if runner is not None:
return runner.handle_exception(typ, value, tb)
return False
with ExceptionStackContext(handle_exception) as deactivate:
try:
result = func(*args, **kwargs)
except (Return, StopIteration) as e:
result = getattr(e, 'value', None)
else:
if isinstance(result, types.GeneratorType):
def final_callback(value):
if value is not None:
raise ReturnValueIgnoredError(
"@gen.engine functions cannot return values: "
"%r" % (value,))
assert value is None
deactivate()
runner = Runner(result, final_callback)
runner.run()
return
if result is not None:
future = func(*args, **kwargs)
def final_callback(future):
if future.result() is not None:
raise ReturnValueIgnoredError(
"@gen.engine functions cannot return values: %r" %
(result,))
deactivate()
# no yield, so we're done
(future.result(),))
future.add_done_callback(final_callback)
return wrapper
def coroutine(func):
def coroutine(func, replace_callback=True):
"""Decorator for asynchronous generators.
Any generator that yields objects from this module must be wrapped
@ -191,43 +170,56 @@ def coroutine(func):
From the caller's perspective, ``@gen.coroutine`` is similar to
the combination of ``@return_future`` and ``@gen.engine``.
"""
return _make_coroutine_wrapper(func, replace_callback=True)
def _make_coroutine_wrapper(func, replace_callback):
"""The inner workings of ``@gen.coroutine`` and ``@gen.engine``.
The two decorators differ in their treatment of the ``callback``
argument, so we cannot simply implement ``@engine`` in terms of
``@coroutine``.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
runner = None
future = TracebackFuture()
if 'callback' in kwargs:
if replace_callback and 'callback' in kwargs:
callback = kwargs.pop('callback')
IOLoop.current().add_future(
future, lambda future: callback(future.result()))
def handle_exception(typ, value, tb):
try:
if runner is not None and runner.handle_exception(typ, value, tb):
return True
except Exception:
typ, value, tb = sys.exc_info()
future.set_exc_info((typ, value, tb))
return True
with ExceptionStackContext(handle_exception) as deactivate:
try:
result = func(*args, **kwargs)
except (Return, StopIteration) as e:
result = getattr(e, 'value', None)
except Exception:
deactivate()
future.set_exc_info(sys.exc_info())
try:
result = func(*args, **kwargs)
except (Return, StopIteration) as e:
result = getattr(e, 'value', None)
except Exception:
future.set_exc_info(sys.exc_info())
return future
else:
if isinstance(result, types.GeneratorType):
# Inline the first iteration of Runner.run. This lets us
# avoid the cost of creating a Runner when the coroutine
# never actually yields, which in turn allows us to
# use "optional" coroutines in critical path code without
# performance penalty for the synchronous case.
try:
orig_stack_contexts = stack_context._state.contexts
yielded = next(result)
if stack_context._state.contexts is not orig_stack_contexts:
yielded = TracebackFuture()
yielded.set_exception(
stack_context.StackContextInconsistentError(
'stack_context inconsistency (probably caused '
'by yield within a "with StackContext" block)'))
except (StopIteration, Return) as e:
future.set_result(getattr(e, 'value', None))
except Exception:
future.set_exc_info(sys.exc_info())
else:
Runner(result, future, yielded)
return future
else:
if isinstance(result, types.GeneratorType):
def final_callback(value):
deactivate()
future.set_result(value)
runner = Runner(result, final_callback)
runner.run()
return future
deactivate()
future.set_result(result)
future.set_result(result)
return future
return wrapper
@ -348,7 +340,7 @@ class WaitAll(YieldPoint):
return [self.runner.pop_result(key) for key in self.keys]
class Task(YieldPoint):
def Task(func, *args, **kwargs):
"""Runs a single asynchronous operation.
Takes a function (and optional additional arguments) and runs it with
@ -362,25 +354,25 @@ class Task(YieldPoint):
func(args, callback=(yield gen.Callback(key)))
result = yield gen.Wait(key)
.. versionchanged:: 4.0
``gen.Task`` is now a function that returns a `.Future`, instead of
a subclass of `YieldPoint`. It still behaves the same way when
yielded.
"""
def __init__(self, func, *args, **kwargs):
assert "callback" not in kwargs
self.args = args
self.kwargs = kwargs
self.func = func
def start(self, runner):
self.runner = runner
self.key = object()
runner.register_callback(self.key)
self.kwargs["callback"] = runner.result_callback(self.key)
self.func(*self.args, **self.kwargs)
def is_ready(self):
return self.runner.is_ready(self.key)
def get_result(self):
return self.runner.pop_result(self.key)
future = Future()
def handle_exception(typ, value, tb):
if future.done():
return False
future.set_exc_info((typ, value, tb))
return True
def set_result(result):
if future.done():
return
future.set_result(result)
with stack_context.ExceptionStackContext(handle_exception):
func(*args, callback=_argument_adapter(set_result), **kwargs)
return future
class YieldFuture(YieldPoint):
@ -414,10 +406,14 @@ class YieldFuture(YieldPoint):
class Multi(YieldPoint):
"""Runs multiple asynchronous operations in parallel.
Takes a list of ``Tasks`` or other ``YieldPoints`` and returns a list of
Takes a list of ``YieldPoints`` or ``Futures`` and returns a list of
their responses. It is not necessary to call `Multi` explicitly,
since the engine will do so automatically when the generator yields
a list of ``YieldPoints``.
a list of ``YieldPoints`` or a mixture of ``YieldPoints`` and ``Futures``.
Instead of a list, the argument may also be a dictionary whose values are
Futures, in which case a parallel dictionary is returned mapping the same
keys to their results.
"""
def __init__(self, children):
self.keys = None
@ -426,7 +422,7 @@ class Multi(YieldPoint):
children = children.values()
self.children = []
for i in children:
if isinstance(i, Future):
if is_future(i):
i = YieldFuture(i)
self.children.append(i)
assert all(isinstance(i, YieldPoint) for i in self.children)
@ -450,18 +446,127 @@ class Multi(YieldPoint):
return list(result)
class _NullYieldPoint(YieldPoint):
def start(self, runner):
pass
def multi_future(children):
"""Wait for multiple asynchronous futures in parallel.
def is_ready(self):
return True
Takes a list of ``Futures`` (but *not* other ``YieldPoints``) and returns
a new Future that resolves when all the other Futures are done.
If all the ``Futures`` succeeded, the returned Future's result is a list
of their results. If any failed, the returned Future raises the exception
of the first one to fail.
def get_result(self):
return None
Instead of a list, the argument may also be a dictionary whose values are
Futures, in which case a parallel dictionary is returned mapping the same
keys to their results.
It is not necessary to call `multi_future` explcitly, since the engine will
do so automatically when the generator yields a list of `Futures`.
This function is faster than the `Multi` `YieldPoint` because it does not
require the creation of a stack context.
.. versionadded:: 4.0
"""
if isinstance(children, dict):
keys = list(children.keys())
children = children.values()
else:
keys = None
assert all(is_future(i) for i in children)
unfinished_children = set(children)
future = Future()
if not children:
future.set_result({} if keys is not None else [])
def callback(f):
unfinished_children.remove(f)
if not unfinished_children:
try:
result_list = [i.result() for i in children]
except Exception:
future.set_exc_info(sys.exc_info())
else:
if keys is not None:
future.set_result(dict(zip(keys, result_list)))
else:
future.set_result(result_list)
for f in children:
f.add_done_callback(callback)
return future
_null_yield_point = _NullYieldPoint()
def maybe_future(x):
"""Converts ``x`` into a `.Future`.
If ``x`` is already a `.Future`, it is simply returned; otherwise
it is wrapped in a new `.Future`. This is suitable for use as
``result = yield gen.maybe_future(f())`` when you don't know whether
``f()`` returns a `.Future` or not.
"""
if is_future(x):
return x
else:
fut = Future()
fut.set_result(x)
return fut
def with_timeout(timeout, future, io_loop=None):
"""Wraps a `.Future` in a timeout.
Raises `TimeoutError` if the input future does not complete before
``timeout``, which may be specified in any form allowed by
`.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time
relative to `.IOLoop.time`)
Currently only supports Futures, not other `YieldPoint` classes.
.. versionadded:: 4.0
"""
# TODO: allow yield points in addition to futures?
# Tricky to do with stack_context semantics.
#
# It's tempting to optimize this by cancelling the input future on timeout
# instead of creating a new one, but A) we can't know if we are the only
# one waiting on the input future, so cancelling it might disrupt other
# callers and B) concurrent futures can only be cancelled while they are
# in the queue, so cancellation cannot reliably bound our waiting time.
result = Future()
chain_future(future, result)
if io_loop is None:
io_loop = IOLoop.current()
timeout_handle = io_loop.add_timeout(
timeout,
lambda: result.set_exception(TimeoutError("Timeout")))
if isinstance(future, Future):
# We know this future will resolve on the IOLoop, so we don't
# need the extra thread-safety of IOLoop.add_future (and we also
# don't care about StackContext here.
future.add_done_callback(
lambda future: io_loop.remove_timeout(timeout_handle))
else:
# concurrent.futures.Futures may resolve on any thread, so we
# need to route them back to the IOLoop.
io_loop.add_future(
future, lambda future: io_loop.remove_timeout(timeout_handle))
return result
_null_future = Future()
_null_future.set_result(None)
moment = Future()
moment.__doc__ = \
"""A special object which may be yielded to allow the IOLoop to run for
one iteration.
This is not needed in normal use but it can be helpful in long-running
coroutines that are likely to yield Futures that are ready instantly.
Usage: ``yield gen.moment``
.. versionadded:: 4.0
"""
moment.set_result(None)
class Runner(object):
@ -469,35 +574,55 @@ class Runner(object):
Maintains information about pending callbacks and their results.
``final_callback`` is run after the generator exits.
The results of the generator are stored in ``result_future`` (a
`.TracebackFuture`)
"""
def __init__(self, gen, final_callback):
def __init__(self, gen, result_future, first_yielded):
self.gen = gen
self.final_callback = final_callback
self.yield_point = _null_yield_point
self.pending_callbacks = set()
self.results = {}
self.result_future = result_future
self.future = _null_future
self.yield_point = None
self.pending_callbacks = None
self.results = None
self.running = False
self.finished = False
self.exc_info = None
self.had_exception = False
self.io_loop = IOLoop.current()
# For efficiency, we do not create a stack context until we
# reach a YieldPoint (stack contexts are required for the historical
# semantics of YieldPoints, but not for Futures). When we have
# done so, this field will be set and must be called at the end
# of the coroutine.
self.stack_context_deactivate = None
if self.handle_yield(first_yielded):
self.run()
def register_callback(self, key):
"""Adds ``key`` to the list of callbacks."""
if self.pending_callbacks is None:
# Lazily initialize the old-style YieldPoint data structures.
self.pending_callbacks = set()
self.results = {}
if key in self.pending_callbacks:
raise KeyReuseError("key %r is already pending" % (key,))
self.pending_callbacks.add(key)
def is_ready(self, key):
"""Returns true if a result is available for ``key``."""
if key not in self.pending_callbacks:
if self.pending_callbacks is None or key not in self.pending_callbacks:
raise UnknownKeyError("key %r is not pending" % (key,))
return key in self.results
def set_result(self, key, result):
"""Sets the result for ``key`` and attempts to resume the generator."""
self.results[key] = result
self.run()
if self.yield_point is not None and self.yield_point.is_ready():
try:
self.future.set_result(self.yield_point.get_result())
except:
self.future.set_exc_info(sys.exc_info())
self.yield_point = None
self.run()
def pop_result(self, key):
"""Returns the result for ``key`` and unregisters it."""
@ -513,25 +638,27 @@ class Runner(object):
try:
self.running = True
while True:
if self.exc_info is None:
try:
if not self.yield_point.is_ready():
return
next = self.yield_point.get_result()
self.yield_point = None
except Exception:
self.exc_info = sys.exc_info()
future = self.future
if not future.done():
return
self.future = None
try:
if self.exc_info is not None:
orig_stack_contexts = stack_context._state.contexts
try:
value = future.result()
except Exception:
self.had_exception = True
exc_info = self.exc_info
self.exc_info = None
yielded = self.gen.throw(*exc_info)
yielded = self.gen.throw(*sys.exc_info())
else:
yielded = self.gen.send(next)
yielded = self.gen.send(value)
if stack_context._state.contexts is not orig_stack_contexts:
self.gen.throw(
stack_context.StackContextInconsistentError(
'stack_context inconsistency (probably caused '
'by yield within a "with StackContext" block)'))
except (StopIteration, Return) as e:
self.finished = True
self.yield_point = _null_yield_point
self.future = _null_future
if self.pending_callbacks and not self.had_exception:
# If we ran cleanly without waiting on all callbacks
# raise an error (really more of a warning). If we
@ -540,46 +667,105 @@ class Runner(object):
raise LeakedCallbackError(
"finished without waiting for callbacks %r" %
self.pending_callbacks)
self.final_callback(getattr(e, 'value', None))
self.final_callback = None
self.result_future.set_result(getattr(e, 'value', None))
self.result_future = None
self._deactivate_stack_context()
return
except Exception:
self.finished = True
self.yield_point = _null_yield_point
raise
if isinstance(yielded, (list, dict)):
yielded = Multi(yielded)
elif isinstance(yielded, Future):
yielded = YieldFuture(yielded)
if isinstance(yielded, YieldPoint):
self.yield_point = yielded
try:
self.yield_point.start(self)
except Exception:
self.exc_info = sys.exc_info()
else:
self.exc_info = (BadYieldError(
"yielded unknown object %r" % (yielded,)),)
self.future = _null_future
self.result_future.set_exc_info(sys.exc_info())
self.result_future = None
self._deactivate_stack_context()
return
if not self.handle_yield(yielded):
return
finally:
self.running = False
def result_callback(self, key):
def inner(*args, **kwargs):
if kwargs or len(args) > 1:
result = Arguments(args, kwargs)
elif args:
result = args[0]
def handle_yield(self, yielded):
if isinstance(yielded, list):
if all(is_future(f) for f in yielded):
yielded = multi_future(yielded)
else:
result = None
self.set_result(key, result)
return wrap(inner)
yielded = Multi(yielded)
elif isinstance(yielded, dict):
if all(is_future(f) for f in yielded.values()):
yielded = multi_future(yielded)
else:
yielded = Multi(yielded)
if isinstance(yielded, YieldPoint):
self.future = TracebackFuture()
def start_yield_point():
try:
yielded.start(self)
if yielded.is_ready():
self.future.set_result(
yielded.get_result())
else:
self.yield_point = yielded
except Exception:
self.future = TracebackFuture()
self.future.set_exc_info(sys.exc_info())
if self.stack_context_deactivate is None:
# Start a stack context if this is the first
# YieldPoint we've seen.
with stack_context.ExceptionStackContext(
self.handle_exception) as deactivate:
self.stack_context_deactivate = deactivate
def cb():
start_yield_point()
self.run()
self.io_loop.add_callback(cb)
return False
else:
start_yield_point()
elif is_future(yielded):
self.future = yielded
if not self.future.done() or self.future is moment:
self.io_loop.add_future(
self.future, lambda f: self.run())
return False
else:
self.future = TracebackFuture()
self.future.set_exception(BadYieldError(
"yielded unknown object %r" % (yielded,)))
return True
def result_callback(self, key):
return stack_context.wrap(_argument_adapter(
functools.partial(self.set_result, key)))
def handle_exception(self, typ, value, tb):
if not self.running and not self.finished:
self.exc_info = (typ, value, tb)
self.future = TracebackFuture()
self.future.set_exc_info((typ, value, tb))
self.run()
return True
else:
return False
def _deactivate_stack_context(self):
if self.stack_context_deactivate is not None:
self.stack_context_deactivate()
self.stack_context_deactivate = None
Arguments = collections.namedtuple('Arguments', ['args', 'kwargs'])
def _argument_adapter(callback):
"""Returns a function that when invoked runs ``callback`` with one arg.
If the function returned by this function is called with exactly
one argument, that argument is passed to ``callback``. Otherwise
the args tuple and kwargs dict are wrapped in an `Arguments` object.
"""
def wrapper(*args, **kwargs):
if kwargs or len(args) > 1:
callback(Arguments(args, kwargs))
elif args:
callback(args[0])
else:
callback(None)
return wrapper

650
tornado/http1connection.py Normal file
View file

@ -0,0 +1,650 @@
#!/usr/bin/env python
#
# Copyright 2014 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Client and server implementations of HTTP/1.x.
.. versionadded:: 4.0
"""
from __future__ import absolute_import, division, print_function, with_statement
from tornado.concurrent import Future
from tornado.escape import native_str, utf8
from tornado import gen
from tornado import httputil
from tornado import iostream
from tornado.log import gen_log, app_log
from tornado import stack_context
from tornado.util import GzipDecompressor
class _QuietException(Exception):
def __init__(self):
pass
class _ExceptionLoggingContext(object):
"""Used with the ``with`` statement when calling delegate methods to
log any exceptions with the given logger. Any exceptions caught are
converted to _QuietException
"""
def __init__(self, logger):
self.logger = logger
def __enter__(self):
pass
def __exit__(self, typ, value, tb):
if value is not None:
self.logger.error("Uncaught exception", exc_info=(typ, value, tb))
raise _QuietException
class HTTP1ConnectionParameters(object):
"""Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`.
"""
def __init__(self, no_keep_alive=False, chunk_size=None,
max_header_size=None, header_timeout=None, max_body_size=None,
body_timeout=None, use_gzip=False):
"""
:arg bool no_keep_alive: If true, always close the connection after
one request.
:arg int chunk_size: how much data to read into memory at once
:arg int max_header_size: maximum amount of data for HTTP headers
:arg float header_timeout: how long to wait for all headers (seconds)
:arg int max_body_size: maximum amount of data for body
:arg float body_timeout: how long to wait while reading body (seconds)
:arg bool use_gzip: if true, decode incoming ``Content-Encoding: gzip``
"""
self.no_keep_alive = no_keep_alive
self.chunk_size = chunk_size or 65536
self.max_header_size = max_header_size or 65536
self.header_timeout = header_timeout
self.max_body_size = max_body_size
self.body_timeout = body_timeout
self.use_gzip = use_gzip
class HTTP1Connection(httputil.HTTPConnection):
"""Implements the HTTP/1.x protocol.
This class can be on its own for clients, or via `HTTP1ServerConnection`
for servers.
"""
def __init__(self, stream, is_client, params=None, context=None):
"""
:arg stream: an `.IOStream`
:arg bool is_client: client or server
:arg params: a `.HTTP1ConnectionParameters` instance or ``None``
:arg context: an opaque application-defined object that can be accessed
as ``connection.context``.
"""
self.is_client = is_client
self.stream = stream
if params is None:
params = HTTP1ConnectionParameters()
self.params = params
self.context = context
self.no_keep_alive = params.no_keep_alive
# The body limits can be altered by the delegate, so save them
# here instead of just referencing self.params later.
self._max_body_size = (self.params.max_body_size or
self.stream.max_buffer_size)
self._body_timeout = self.params.body_timeout
# _write_finished is set to True when finish() has been called,
# i.e. there will be no more data sent. Data may still be in the
# stream's write buffer.
self._write_finished = False
# True when we have read the entire incoming body.
self._read_finished = False
# _finish_future resolves when all data has been written and flushed
# to the IOStream.
self._finish_future = Future()
# If true, the connection should be closed after this request
# (after the response has been written in the server side,
# and after it has been read in the client)
self._disconnect_on_finish = False
self._clear_callbacks()
# Save the start lines after we read or write them; they
# affect later processing (e.g. 304 responses and HEAD methods
# have content-length but no bodies)
self._request_start_line = None
self._response_start_line = None
self._request_headers = None
# True if we are writing output with chunked encoding.
self._chunking_output = None
# While reading a body with a content-length, this is the
# amount left to read.
self._expected_content_remaining = None
# A Future for our outgoing writes, returned by IOStream.write.
self._pending_write = None
def read_response(self, delegate):
"""Read a single HTTP response.
Typical client-mode usage is to write a request using `write_headers`,
`write`, and `finish`, and then call ``read_response``.
:arg delegate: a `.HTTPMessageDelegate`
Returns a `.Future` that resolves to None after the full response has
been read.
"""
if self.params.use_gzip:
delegate = _GzipMessageDelegate(delegate, self.params.chunk_size)
return self._read_message(delegate)
@gen.coroutine
def _read_message(self, delegate):
need_delegate_close = False
try:
header_future = self.stream.read_until_regex(
b"\r?\n\r?\n",
max_bytes=self.params.max_header_size)
if self.params.header_timeout is None:
header_data = yield header_future
else:
try:
header_data = yield gen.with_timeout(
self.stream.io_loop.time() + self.params.header_timeout,
header_future,
io_loop=self.stream.io_loop)
except gen.TimeoutError:
self.close()
raise gen.Return(False)
start_line, headers = self._parse_headers(header_data)
if self.is_client:
start_line = httputil.parse_response_start_line(start_line)
self._response_start_line = start_line
else:
start_line = httputil.parse_request_start_line(start_line)
self._request_start_line = start_line
self._request_headers = headers
self._disconnect_on_finish = not self._can_keep_alive(
start_line, headers)
need_delegate_close = True
with _ExceptionLoggingContext(app_log):
header_future = delegate.headers_received(start_line, headers)
if header_future is not None:
yield header_future
if self.stream is None:
# We've been detached.
need_delegate_close = False
raise gen.Return(False)
skip_body = False
if self.is_client:
if (self._request_start_line is not None and
self._request_start_line.method == 'HEAD'):
skip_body = True
code = start_line.code
if code == 304:
skip_body = True
if code >= 100 and code < 200:
# TODO: client delegates will get headers_received twice
# in the case of a 100-continue. Document or change?
yield self._read_message(delegate)
else:
if (headers.get("Expect") == "100-continue" and
not self._write_finished):
self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
if not skip_body:
body_future = self._read_body(headers, delegate)
if body_future is not None:
if self._body_timeout is None:
yield body_future
else:
try:
yield gen.with_timeout(
self.stream.io_loop.time() + self._body_timeout,
body_future, self.stream.io_loop)
except gen.TimeoutError:
gen_log.info("Timeout reading body from %s",
self.context)
self.stream.close()
raise gen.Return(False)
self._read_finished = True
if not self._write_finished or self.is_client:
need_delegate_close = False
with _ExceptionLoggingContext(app_log):
delegate.finish()
# If we're waiting for the application to produce an asynchronous
# response, and we're not detached, register a close callback
# on the stream (we didn't need one while we were reading)
if (not self._finish_future.done() and
self.stream is not None and
not self.stream.closed()):
self.stream.set_close_callback(self._on_connection_close)
yield self._finish_future
if self.is_client and self._disconnect_on_finish:
self.close()
if self.stream is None:
raise gen.Return(False)
except httputil.HTTPInputError as e:
gen_log.info("Malformed HTTP message from %s: %s",
self.context, e)
self.close()
raise gen.Return(False)
finally:
if need_delegate_close:
with _ExceptionLoggingContext(app_log):
delegate.on_connection_close()
self._clear_callbacks()
raise gen.Return(True)
def _clear_callbacks(self):
"""Clears the callback attributes.
This allows the request handler to be garbage collected more
quickly in CPython by breaking up reference cycles.
"""
self._write_callback = None
self._write_future = None
self._close_callback = None
if self.stream is not None:
self.stream.set_close_callback(None)
def set_close_callback(self, callback):
"""Sets a callback that will be run when the connection is closed.
.. deprecated:: 4.0
Use `.HTTPMessageDelegate.on_connection_close` instead.
"""
self._close_callback = stack_context.wrap(callback)
def _on_connection_close(self):
# Note that this callback is only registered on the IOStream
# when we have finished reading the request and are waiting for
# the application to produce its response.
if self._close_callback is not None:
callback = self._close_callback
self._close_callback = None
callback()
if not self._finish_future.done():
self._finish_future.set_result(None)
self._clear_callbacks()
def close(self):
if self.stream is not None:
self.stream.close()
self._clear_callbacks()
if not self._finish_future.done():
self._finish_future.set_result(None)
def detach(self):
"""Take control of the underlying stream.
Returns the underlying `.IOStream` object and stops all further
HTTP processing. May only be called during
`.HTTPMessageDelegate.headers_received`. Intended for implementing
protocols like websockets that tunnel over an HTTP handshake.
"""
self._clear_callbacks()
stream = self.stream
self.stream = None
return stream
def set_body_timeout(self, timeout):
"""Sets the body timeout for a single request.
Overrides the value from `.HTTP1ConnectionParameters`.
"""
self._body_timeout = timeout
def set_max_body_size(self, max_body_size):
"""Sets the body size limit for a single request.
Overrides the value from `.HTTP1ConnectionParameters`.
"""
self._max_body_size = max_body_size
def write_headers(self, start_line, headers, chunk=None, callback=None):
"""Implements `.HTTPConnection.write_headers`."""
if self.is_client:
self._request_start_line = start_line
# Client requests with a non-empty body must have either a
# Content-Length or a Transfer-Encoding.
self._chunking_output = (
start_line.method in ('POST', 'PUT', 'PATCH') and
'Content-Length' not in headers and
'Transfer-Encoding' not in headers)
else:
self._response_start_line = start_line
self._chunking_output = (
# TODO: should this use
# self._request_start_line.version or
# start_line.version?
self._request_start_line.version == 'HTTP/1.1' and
# 304 responses have no body (not even a zero-length body), and so
# should not have either Content-Length or Transfer-Encoding.
# headers.
start_line.code != 304 and
# No need to chunk the output if a Content-Length is specified.
'Content-Length' not in headers and
# Applications are discouraged from touching Transfer-Encoding,
# but if they do, leave it alone.
'Transfer-Encoding' not in headers)
# If a 1.0 client asked for keep-alive, add the header.
if (self._request_start_line.version == 'HTTP/1.0' and
(self._request_headers.get('Connection', '').lower()
== 'keep-alive')):
headers['Connection'] = 'Keep-Alive'
if self._chunking_output:
headers['Transfer-Encoding'] = 'chunked'
if (not self.is_client and
(self._request_start_line.method == 'HEAD' or
start_line.code == 304)):
self._expected_content_remaining = 0
elif 'Content-Length' in headers:
self._expected_content_remaining = int(headers['Content-Length'])
else:
self._expected_content_remaining = None
lines = [utf8("%s %s %s" % start_line)]
lines.extend([utf8(n) + b": " + utf8(v) for n, v in headers.get_all()])
for line in lines:
if b'\n' in line:
raise ValueError('Newline in header: ' + repr(line))
future = None
if self.stream.closed():
future = self._write_future = Future()
future.set_exception(iostream.StreamClosedError())
else:
if callback is not None:
self._write_callback = stack_context.wrap(callback)
else:
future = self._write_future = Future()
data = b"\r\n".join(lines) + b"\r\n\r\n"
if chunk:
data += self._format_chunk(chunk)
self._pending_write = self.stream.write(data)
self._pending_write.add_done_callback(self._on_write_complete)
return future
def _format_chunk(self, chunk):
if self._expected_content_remaining is not None:
self._expected_content_remaining -= len(chunk)
if self._expected_content_remaining < 0:
# Close the stream now to stop further framing errors.
self.stream.close()
raise httputil.HTTPOutputError(
"Tried to write more data than Content-Length")
if self._chunking_output and chunk:
# Don't write out empty chunks because that means END-OF-STREAM
# with chunked encoding
return utf8("%x" % len(chunk)) + b"\r\n" + chunk + b"\r\n"
else:
return chunk
def write(self, chunk, callback=None):
"""Implements `.HTTPConnection.write`.
For backwards compatibility is is allowed but deprecated to
skip `write_headers` and instead call `write()` with a
pre-encoded header block.
"""
future = None
if self.stream.closed():
future = self._write_future = Future()
self._write_future.set_exception(iostream.StreamClosedError())
else:
if callback is not None:
self._write_callback = stack_context.wrap(callback)
else:
future = self._write_future = Future()
self._pending_write = self.stream.write(self._format_chunk(chunk))
self._pending_write.add_done_callback(self._on_write_complete)
return future
def finish(self):
"""Implements `.HTTPConnection.finish`."""
if (self._expected_content_remaining is not None and
self._expected_content_remaining != 0 and
not self.stream.closed()):
self.stream.close()
raise httputil.HTTPOutputError(
"Tried to write %d bytes less than Content-Length" %
self._expected_content_remaining)
if self._chunking_output:
if not self.stream.closed():
self._pending_write = self.stream.write(b"0\r\n\r\n")
self._pending_write.add_done_callback(self._on_write_complete)
self._write_finished = True
# If the app finished the request while we're still reading,
# divert any remaining data away from the delegate and
# close the connection when we're done sending our response.
# Closing the connection is the only way to avoid reading the
# whole input body.
if not self._read_finished:
self._disconnect_on_finish = True
# No more data is coming, so instruct TCP to send any remaining
# data immediately instead of waiting for a full packet or ack.
self.stream.set_nodelay(True)
if self._pending_write is None:
self._finish_request(None)
else:
self._pending_write.add_done_callback(self._finish_request)
def _on_write_complete(self, future):
if self._write_callback is not None:
callback = self._write_callback
self._write_callback = None
self.stream.io_loop.add_callback(callback)
if self._write_future is not None:
future = self._write_future
self._write_future = None
future.set_result(None)
def _can_keep_alive(self, start_line, headers):
if self.params.no_keep_alive:
return False
connection_header = headers.get("Connection")
if connection_header is not None:
connection_header = connection_header.lower()
if start_line.version == "HTTP/1.1":
return connection_header != "close"
elif ("Content-Length" in headers
or start_line.method in ("HEAD", "GET")):
return connection_header == "keep-alive"
return False
def _finish_request(self, future):
self._clear_callbacks()
if not self.is_client and self._disconnect_on_finish:
self.close()
return
# Turn Nagle's algorithm back on, leaving the stream in its
# default state for the next request.
self.stream.set_nodelay(False)
if not self._finish_future.done():
self._finish_future.set_result(None)
def _parse_headers(self, data):
data = native_str(data.decode('latin1'))
eol = data.find("\r\n")
start_line = data[:eol]
try:
headers = httputil.HTTPHeaders.parse(data[eol:])
except ValueError:
# probably form split() if there was no ':' in the line
raise httputil.HTTPInputError("Malformed HTTP headers: %r" %
data[eol:100])
return start_line, headers
def _read_body(self, headers, delegate):
content_length = headers.get("Content-Length")
if content_length:
content_length = int(content_length)
if content_length > self._max_body_size:
raise httputil.HTTPInputError("Content-Length too long")
return self._read_fixed_body(content_length, delegate)
if headers.get("Transfer-Encoding") == "chunked":
return self._read_chunked_body(delegate)
if self.is_client:
return self._read_body_until_close(delegate)
return None
@gen.coroutine
def _read_fixed_body(self, content_length, delegate):
while content_length > 0:
body = yield self.stream.read_bytes(
min(self.params.chunk_size, content_length), partial=True)
content_length -= len(body)
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
yield gen.maybe_future(delegate.data_received(body))
@gen.coroutine
def _read_chunked_body(self, delegate):
# TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
total_size = 0
while True:
chunk_len = yield self.stream.read_until(b"\r\n", max_bytes=64)
chunk_len = int(chunk_len.strip(), 16)
if chunk_len == 0:
return
total_size += chunk_len
if total_size > self._max_body_size:
raise httputil.HTTPInputError("chunked body too large")
bytes_to_read = chunk_len
while bytes_to_read:
chunk = yield self.stream.read_bytes(
min(bytes_to_read, self.params.chunk_size), partial=True)
bytes_to_read -= len(chunk)
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
yield gen.maybe_future(delegate.data_received(chunk))
# chunk ends with \r\n
crlf = yield self.stream.read_bytes(2)
assert crlf == b"\r\n"
@gen.coroutine
def _read_body_until_close(self, delegate):
body = yield self.stream.read_until_close()
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
delegate.data_received(body)
class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
"""Wraps an `HTTPMessageDelegate` to decode ``Content-Encoding: gzip``.
"""
def __init__(self, delegate, chunk_size):
self._delegate = delegate
self._chunk_size = chunk_size
self._decompressor = None
def headers_received(self, start_line, headers):
if headers.get("Content-Encoding") == "gzip":
self._decompressor = GzipDecompressor()
# Downstream delegates will only see uncompressed data,
# so rename the content-encoding header.
# (but note that curl_httpclient doesn't do this).
headers.add("X-Consumed-Content-Encoding",
headers["Content-Encoding"])
del headers["Content-Encoding"]
return self._delegate.headers_received(start_line, headers)
@gen.coroutine
def data_received(self, chunk):
if self._decompressor:
compressed_data = chunk
while compressed_data:
decompressed = self._decompressor.decompress(
compressed_data, self._chunk_size)
if decompressed:
yield gen.maybe_future(
self._delegate.data_received(decompressed))
compressed_data = self._decompressor.unconsumed_tail
else:
yield gen.maybe_future(self._delegate.data_received(chunk))
def finish(self):
if self._decompressor is not None:
tail = self._decompressor.flush()
if tail:
# I believe the tail will always be empty (i.e.
# decompress will return all it can). The purpose
# of the flush call is to detect errors such
# as truncated input. But in case it ever returns
# anything, treat it as an extra chunk
self._delegate.data_received(tail)
return self._delegate.finish()
class HTTP1ServerConnection(object):
"""An HTTP/1.x server."""
def __init__(self, stream, params=None, context=None):
"""
:arg stream: an `.IOStream`
:arg params: a `.HTTP1ConnectionParameters` or None
:arg context: an opaque application-defined object that is accessible
as ``connection.context``
"""
self.stream = stream
if params is None:
params = HTTP1ConnectionParameters()
self.params = params
self.context = context
self._serving_future = None
@gen.coroutine
def close(self):
"""Closes the connection.
Returns a `.Future` that resolves after the serving loop has exited.
"""
self.stream.close()
# Block until the serving loop is done, but ignore any exceptions
# (start_serving is already responsible for logging them).
try:
yield self._serving_future
except Exception:
pass
def start_serving(self, delegate):
"""Starts serving requests on this connection.
:arg delegate: a `.HTTPServerConnectionDelegate`
"""
assert isinstance(delegate, httputil.HTTPServerConnectionDelegate)
self._serving_future = self._server_request_loop(delegate)
# Register the future on the IOLoop so its errors get logged.
self.stream.io_loop.add_future(self._serving_future,
lambda f: f.result())
@gen.coroutine
def _server_request_loop(self, delegate):
try:
while True:
conn = HTTP1Connection(self.stream, False,
self.params, self.context)
request_delegate = delegate.start_request(self, conn)
try:
ret = yield conn.read_response(request_delegate)
except (iostream.StreamClosedError,
iostream.UnsatisfiableReadError):
return
except _QuietException:
# This exception was already logged.
conn.close()
return
except Exception:
gen_log.error("Uncaught exception", exc_info=True)
conn.close()
return
if not ret:
return
yield gen.moment
finally:
delegate.on_close(self)

View file

@ -25,6 +25,11 @@ to switch to ``curl_httpclient`` for reasons such as the following:
Note that if you are using ``curl_httpclient``, it is highly recommended that
you use a recent version of ``libcurl`` and ``pycurl``. Currently the minimum
supported version is 7.18.2, and the recommended version is 7.21.1 or newer.
It is highly recommended that your ``libcurl`` installation is built with
asynchronous DNS resolver (threaded or c-ares), otherwise you may encounter
various problems with request timeouts (for more information, see
http://curl.haxx.se/libcurl/c/curl_easy_setopt.html#CURLOPTCONNECTTIMEOUTMS
and comments in curl_httpclient.py).
"""
from __future__ import absolute_import, division, print_function, with_statement
@ -34,7 +39,7 @@ import time
import weakref
from tornado.concurrent import TracebackFuture
from tornado.escape import utf8
from tornado.escape import utf8, native_str
from tornado import httputil, stack_context
from tornado.ioloop import IOLoop
from tornado.util import Configurable
@ -105,10 +110,21 @@ class AsyncHTTPClient(Configurable):
actually creates an instance of an implementation-specific
subclass, and instances are reused as a kind of pseudo-singleton
(one per `.IOLoop`). The keyword argument ``force_instance=True``
can be used to suppress this singleton behavior. Constructor
arguments other than ``io_loop`` and ``force_instance`` are
deprecated. The implementation subclass as well as arguments to
its constructor can be set with the static method `configure()`
can be used to suppress this singleton behavior. Unless
``force_instance=True`` is used, no arguments other than
``io_loop`` should be passed to the `AsyncHTTPClient` constructor.
The implementation subclass as well as arguments to its
constructor can be set with the static method `configure()`
All `AsyncHTTPClient` implementations support a ``defaults``
keyword argument, which can be used to set default values for
`HTTPRequest` attributes. For example::
AsyncHTTPClient.configure(
None, defaults=dict(user_agent="MyUserAgent"))
# or with force_instance:
client = AsyncHTTPClient(force_instance=True,
defaults=dict(user_agent="MyUserAgent"))
"""
@classmethod
def configurable_base(cls):
@ -141,6 +157,7 @@ class AsyncHTTPClient(Configurable):
self.defaults = dict(HTTPRequest._DEFAULTS)
if defaults is not None:
self.defaults.update(defaults)
self._closed = False
def close(self):
"""Destroys this HTTP client, freeing any file descriptors used.
@ -155,6 +172,7 @@ class AsyncHTTPClient(Configurable):
``close()``.
"""
self._closed = True
if self._async_clients().get(self.io_loop) is self:
del self._async_clients()[self.io_loop]
@ -166,7 +184,7 @@ class AsyncHTTPClient(Configurable):
kwargs: ``HTTPRequest(request, **kwargs)``
This method returns a `.Future` whose result is an
`HTTPResponse`. The ``Future`` wil raise an `HTTPError` if
`HTTPResponse`. The ``Future`` will raise an `HTTPError` if
the request returned a non-200 response code.
If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
@ -174,6 +192,8 @@ class AsyncHTTPClient(Configurable):
Instead, you must check the response's ``error`` attribute or
call its `~HTTPResponse.rethrow` method.
"""
if self._closed:
raise RuntimeError("fetch() called on closed AsyncHTTPClient")
if not isinstance(request, HTTPRequest):
request = HTTPRequest(url=request, **kwargs)
# We may modify this (to add Host, Accept-Encoding, etc),
@ -259,14 +279,27 @@ class HTTPRequest(object):
proxy_password=None, allow_nonstandard_methods=None,
validate_cert=None, ca_certs=None,
allow_ipv6=None,
client_key=None, client_cert=None):
client_key=None, client_cert=None, body_producer=None,
expect_100_continue=False):
r"""All parameters except ``url`` are optional.
:arg string url: URL to fetch
:arg string method: HTTP method, e.g. "GET" or "POST"
:arg headers: Additional HTTP headers to pass on the request
:arg body: HTTP body to pass on the request
:type headers: `~tornado.httputil.HTTPHeaders` or `dict`
:arg body: HTTP request body as a string (byte or unicode; if unicode
the utf-8 encoding will be used)
:arg body_producer: Callable used for lazy/asynchronous request bodies.
It is called with one argument, a ``write`` function, and should
return a `.Future`. It should call the write function with new
data as it becomes available. The write function returns a
`.Future` which can be used for flow control.
Only one of ``body`` and ``body_producer`` may
be specified. ``body_producer`` is not supported on
``curl_httpclient``. When using ``body_producer`` it is recommended
to pass a ``Content-Length`` in the headers as otherwise chunked
encoding will be used, and many servers do not support chunked
encoding on requests. New in Tornado 4.0
:arg string auth_username: Username for HTTP authentication
:arg string auth_password: Password for HTTP authentication
:arg string auth_mode: Authentication mode; default is "basic".
@ -319,6 +352,11 @@ class HTTPRequest(object):
note below when used with ``curl_httpclient``.
:arg string client_cert: Filename for client SSL certificate, if any.
See note below when used with ``curl_httpclient``.
:arg bool expect_100_continue: If true, send the
``Expect: 100-continue`` header and wait for a continue response
before sending the request body. Only supported with
simple_httpclient.
.. note::
@ -334,6 +372,9 @@ class HTTPRequest(object):
.. versionadded:: 3.1
The ``auth_mode`` argument.
.. versionadded:: 4.0
The ``body_producer`` and ``expect_100_continue`` arguments.
"""
# Note that some of these attributes go through property setters
# defined below.
@ -348,6 +389,7 @@ class HTTPRequest(object):
self.url = url
self.method = method
self.body = body
self.body_producer = body_producer
self.auth_username = auth_username
self.auth_password = auth_password
self.auth_mode = auth_mode
@ -367,6 +409,7 @@ class HTTPRequest(object):
self.allow_ipv6 = allow_ipv6
self.client_key = client_key
self.client_cert = client_cert
self.expect_100_continue = expect_100_continue
self.start_time = time.time()
@property
@ -388,6 +431,14 @@ class HTTPRequest(object):
def body(self, value):
self._body = utf8(value)
@property
def body_producer(self):
return self._body_producer
@body_producer.setter
def body_producer(self, value):
self._body_producer = stack_context.wrap(value)
@property
def streaming_callback(self):
return self._streaming_callback
@ -423,8 +474,6 @@ class HTTPResponse(object):
* code: numeric HTTP status code, e.g. 200 or 404
* reason: human-readable reason phrase describing the status code
(with curl_httpclient, this is a default value rather than the
server's actual response)
* headers: `tornado.httputil.HTTPHeaders` object
@ -466,7 +515,8 @@ class HTTPResponse(object):
self.effective_url = effective_url
if error is None:
if self.code < 200 or self.code >= 300:
self.error = HTTPError(self.code, response=self)
self.error = HTTPError(self.code, message=self.reason,
response=self)
else:
self.error = None
else:
@ -556,7 +606,7 @@ def main():
if options.print_headers:
print(response.headers)
if options.print_body:
print(response.body)
print(native_str(response.body))
client.close()
if __name__ == "__main__":

View file

@ -20,70 +20,55 @@ Typical applications have little direct interaction with the `HTTPServer`
class except to start a server at the beginning of the process
(and even that is often done indirectly via `tornado.web.Application.listen`).
This module also defines the `HTTPRequest` class which is exposed via
`tornado.web.RequestHandler.request`.
.. versionchanged:: 4.0
The ``HTTPRequest`` class that used to live in this module has been moved
to `tornado.httputil.HTTPServerRequest`. The old name remains as an alias.
"""
from __future__ import absolute_import, division, print_function, with_statement
import socket
import ssl
import time
import copy
from tornado.escape import native_str, parse_qs_bytes
from tornado.escape import native_str
from tornado.http1connection import HTTP1ServerConnection, HTTP1ConnectionParameters
from tornado import gen
from tornado import httputil
from tornado import iostream
from tornado.log import gen_log
from tornado import netutil
from tornado.tcpserver import TCPServer
from tornado import stack_context
from tornado.util import bytes_type
try:
import Cookie # py2
except ImportError:
import http.cookies as Cookie # py3
class HTTPServer(TCPServer):
class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
r"""A non-blocking, single-threaded HTTP server.
A server is defined by a request callback that takes an HTTPRequest
instance as an argument and writes a valid HTTP response with
`HTTPRequest.write`. `HTTPRequest.finish` finishes the request (but does
not necessarily close the connection in the case of HTTP/1.1 keep-alive
requests). A simple example server that echoes back the URI you
requested::
A server is defined by either a request callback that takes a
`.HTTPServerRequest` as an argument or a `.HTTPServerConnectionDelegate`
instance.
A simple example server that echoes back the URI you requested::
import tornado.httpserver
import tornado.ioloop
def handle_request(request):
message = "You requested %s\n" % request.uri
request.write("HTTP/1.1 200 OK\r\nContent-Length: %d\r\n\r\n%s" % (
len(message), message))
request.finish()
request.connection.write_headers(
httputil.ResponseStartLine('HTTP/1.1', 200, 'OK'),
{"Content-Length": str(len(message))})
request.connection.write(message)
request.connection.finish()
http_server = tornado.httpserver.HTTPServer(handle_request)
http_server.listen(8888)
tornado.ioloop.IOLoop.instance().start()
`HTTPServer` is a very basic connection handler. It parses the request
headers and body, but the request callback is responsible for producing
the response exactly as it will appear on the wire. This affords
maximum flexibility for applications to implement whatever parts
of HTTP responses are required.
Applications should use the methods of `.HTTPConnection` to write
their response.
`HTTPServer` supports keep-alive connections by default
(automatically for HTTP/1.1, or for HTTP/1.0 when the client
requests ``Connection: keep-alive``). This means that the request
callback must generate a properly-framed response, using either
the ``Content-Length`` header or ``Transfer-Encoding: chunked``.
Applications that are unable to frame their responses properly
should instead return a ``Connection: close`` header in each
response and pass ``no_keep_alive=True`` to the `HTTPServer`
constructor.
requests ``Connection: keep-alive``).
If ``xheaders`` is ``True``, we support the
``X-Real-Ip``/``X-Forwarded-For`` and
@ -143,407 +128,169 @@ class HTTPServer(TCPServer):
servers if you want to create your listening sockets in some
way other than `tornado.netutil.bind_sockets`.
.. versionchanged:: 4.0
Added ``gzip``, ``chunk_size``, ``max_header_size``,
``idle_connection_timeout``, ``body_timeout``, ``max_body_size``
arguments. Added support for `.HTTPServerConnectionDelegate`
instances as ``request_callback``.
"""
def __init__(self, request_callback, no_keep_alive=False, io_loop=None,
xheaders=False, ssl_options=None, protocol=None, **kwargs):
xheaders=False, ssl_options=None, protocol=None, gzip=False,
chunk_size=None, max_header_size=None,
idle_connection_timeout=None, body_timeout=None,
max_body_size=None, max_buffer_size=None):
self.request_callback = request_callback
self.no_keep_alive = no_keep_alive
self.xheaders = xheaders
self.protocol = protocol
self.conn_params = HTTP1ConnectionParameters(
use_gzip=gzip,
chunk_size=chunk_size,
max_header_size=max_header_size,
header_timeout=idle_connection_timeout or 3600,
max_body_size=max_body_size,
body_timeout=body_timeout)
TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options,
**kwargs)
max_buffer_size=max_buffer_size,
read_chunk_size=chunk_size)
self._connections = set()
@gen.coroutine
def close_all_connections(self):
while self._connections:
# Peek at an arbitrary element of the set
conn = next(iter(self._connections))
yield conn.close()
def handle_stream(self, stream, address):
HTTPConnection(stream, address, self.request_callback,
self.no_keep_alive, self.xheaders, self.protocol)
context = _HTTPRequestContext(stream, address,
self.protocol)
conn = HTTP1ServerConnection(
stream, self.conn_params, context)
self._connections.add(conn)
conn.start_serving(self)
def start_request(self, server_conn, request_conn):
return _ServerRequestAdapter(self, request_conn)
def on_close(self, server_conn):
self._connections.remove(server_conn)
class _BadRequestException(Exception):
"""Exception class for malformed HTTP requests."""
pass
class HTTPConnection(object):
"""Handles a connection to an HTTP client, executing HTTP requests.
We parse HTTP headers and bodies, and execute the request callback
until the HTTP conection is closed.
"""
def __init__(self, stream, address, request_callback, no_keep_alive=False,
xheaders=False, protocol=None):
self.stream = stream
class _HTTPRequestContext(object):
def __init__(self, stream, address, protocol):
self.address = address
self.protocol = protocol
# Save the socket's address family now so we know how to
# interpret self.address even after the stream is closed
# and its socket attribute replaced with None.
self.address_family = stream.socket.family
self.request_callback = request_callback
self.no_keep_alive = no_keep_alive
self.xheaders = xheaders
self.protocol = protocol
self._clear_request_state()
# Save stack context here, outside of any request. This keeps
# contexts from one request from leaking into the next.
self._header_callback = stack_context.wrap(self._on_headers)
self.stream.set_close_callback(self._on_connection_close)
self.stream.read_until(b"\r\n\r\n", self._header_callback)
def _clear_request_state(self):
"""Clears the per-request state.
This is run in between requests to allow the previous handler
to be garbage collected (and prevent spurious close callbacks),
and when the connection is closed (to break up cycles and
facilitate garbage collection in cpython).
"""
self._request = None
self._request_finished = False
self._write_callback = None
self._close_callback = None
def set_close_callback(self, callback):
"""Sets a callback that will be run when the connection is closed.
Use this instead of accessing
`HTTPConnection.stream.set_close_callback
<.BaseIOStream.set_close_callback>` directly (which was the
recommended approach prior to Tornado 3.0).
"""
self._close_callback = stack_context.wrap(callback)
def _on_connection_close(self):
if self._close_callback is not None:
callback = self._close_callback
self._close_callback = None
callback()
# Delete any unfinished callbacks to break up reference cycles.
self._header_callback = None
self._clear_request_state()
def close(self):
self.stream.close()
# Remove this reference to self, which would otherwise cause a
# cycle and delay garbage collection of this connection.
self._header_callback = None
self._clear_request_state()
def write(self, chunk, callback=None):
"""Writes a chunk of output to the stream."""
if not self.stream.closed():
self._write_callback = stack_context.wrap(callback)
self.stream.write(chunk, self._on_write_complete)
def finish(self):
"""Finishes the request."""
self._request_finished = True
# No more data is coming, so instruct TCP to send any remaining
# data immediately instead of waiting for a full packet or ack.
self.stream.set_nodelay(True)
if not self.stream.writing():
self._finish_request()
def _on_write_complete(self):
if self._write_callback is not None:
callback = self._write_callback
self._write_callback = None
callback()
# _on_write_complete is enqueued on the IOLoop whenever the
# IOStream's write buffer becomes empty, but it's possible for
# another callback that runs on the IOLoop before it to
# simultaneously write more data and finish the request. If
# there is still data in the IOStream, a future
# _on_write_complete will be responsible for calling
# _finish_request.
if self._request_finished and not self.stream.writing():
self._finish_request()
def _finish_request(self):
if self.no_keep_alive or self._request is None:
disconnect = True
if stream.socket is not None:
self.address_family = stream.socket.family
else:
connection_header = self._request.headers.get("Connection")
if connection_header is not None:
connection_header = connection_header.lower()
if self._request.supports_http_1_1():
disconnect = connection_header == "close"
elif ("Content-Length" in self._request.headers
or self._request.method in ("HEAD", "GET")):
disconnect = connection_header != "keep-alive"
else:
disconnect = True
self._clear_request_state()
if disconnect:
self.close()
return
try:
# Use a try/except instead of checking stream.closed()
# directly, because in some cases the stream doesn't discover
# that it's closed until you try to read from it.
self.stream.read_until(b"\r\n\r\n", self._header_callback)
# Turn Nagle's algorithm back on, leaving the stream in its
# default state for the next request.
self.stream.set_nodelay(False)
except iostream.StreamClosedError:
self.close()
def _on_headers(self, data):
try:
data = native_str(data.decode('latin1'))
eol = data.find("\r\n")
start_line = data[:eol]
try:
method, uri, version = start_line.split(" ")
except ValueError:
raise _BadRequestException("Malformed HTTP request line")
if not version.startswith("HTTP/"):
raise _BadRequestException("Malformed HTTP version in HTTP Request-Line")
try:
headers = httputil.HTTPHeaders.parse(data[eol:])
except ValueError:
# Probably from split() if there was no ':' in the line
raise _BadRequestException("Malformed HTTP headers")
# HTTPRequest wants an IP, not a full socket address
if self.address_family in (socket.AF_INET, socket.AF_INET6):
remote_ip = self.address[0]
else:
# Unix (or other) socket; fake the remote address
remote_ip = '0.0.0.0'
self._request = HTTPRequest(
connection=self, method=method, uri=uri, version=version,
headers=headers, remote_ip=remote_ip, protocol=self.protocol)
content_length = headers.get("Content-Length")
if content_length:
content_length = int(content_length)
if content_length > self.stream.max_buffer_size:
raise _BadRequestException("Content-Length too long")
if headers.get("Expect") == "100-continue":
self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
self.stream.read_bytes(content_length, self._on_request_body)
return
self.request_callback(self._request)
except _BadRequestException as e:
gen_log.info("Malformed HTTP request from %r: %s",
self.address, e)
self.close()
return
def _on_request_body(self, data):
self._request.body = data
if self._request.method in ("POST", "PATCH", "PUT"):
httputil.parse_body_arguments(
self._request.headers.get("Content-Type", ""), data,
self._request.body_arguments, self._request.files)
for k, v in self._request.body_arguments.items():
self._request.arguments.setdefault(k, []).extend(v)
self.request_callback(self._request)
class HTTPRequest(object):
"""A single HTTP request.
All attributes are type `str` unless otherwise noted.
.. attribute:: method
HTTP request method, e.g. "GET" or "POST"
.. attribute:: uri
The requested uri.
.. attribute:: path
The path portion of `uri`
.. attribute:: query
The query portion of `uri`
.. attribute:: version
HTTP version specified in request, e.g. "HTTP/1.1"
.. attribute:: headers
`.HTTPHeaders` dictionary-like object for request headers. Acts like
a case-insensitive dictionary with additional methods for repeated
headers.
.. attribute:: body
Request body, if present, as a byte string.
.. attribute:: remote_ip
Client's IP address as a string. If ``HTTPServer.xheaders`` is set,
will pass along the real IP address provided by a load balancer
in the ``X-Real-Ip`` or ``X-Forwarded-For`` header.
.. versionchanged:: 3.1
The list format of ``X-Forwarded-For`` is now supported.
.. attribute:: protocol
The protocol used, either "http" or "https". If ``HTTPServer.xheaders``
is set, will pass along the protocol used by a load balancer if
reported via an ``X-Scheme`` header.
.. attribute:: host
The requested hostname, usually taken from the ``Host`` header.
.. attribute:: arguments
GET/POST arguments are available in the arguments property, which
maps arguments names to lists of values (to support multiple values
for individual names). Names are of type `str`, while arguments
are byte strings. Note that this is different from
`.RequestHandler.get_argument`, which returns argument values as
unicode strings.
.. attribute:: query_arguments
Same format as ``arguments``, but contains only arguments extracted
from the query string.
.. versionadded:: 3.2
.. attribute:: body_arguments
Same format as ``arguments``, but contains only arguments extracted
from the request body.
.. versionadded:: 3.2
.. attribute:: files
File uploads are available in the files property, which maps file
names to lists of `.HTTPFile`.
.. attribute:: connection
An HTTP request is attached to a single HTTP connection, which can
be accessed through the "connection" attribute. Since connections
are typically kept open in HTTP/1.1, multiple requests can be handled
sequentially on a single connection.
"""
def __init__(self, method, uri, version="HTTP/1.0", headers=None,
body=None, remote_ip=None, protocol=None, host=None,
files=None, connection=None):
self.method = method
self.uri = uri
self.version = version
self.headers = headers or httputil.HTTPHeaders()
self.body = body or ""
# set remote IP and protocol
self.remote_ip = remote_ip
self.address_family = None
# In HTTPServerRequest we want an IP, not a full socket address.
if (self.address_family in (socket.AF_INET, socket.AF_INET6) and
address is not None):
self.remote_ip = address[0]
else:
# Unix (or other) socket; fake the remote address.
self.remote_ip = '0.0.0.0'
if protocol:
self.protocol = protocol
elif connection and isinstance(connection.stream,
iostream.SSLIOStream):
elif isinstance(stream, iostream.SSLIOStream):
self.protocol = "https"
else:
self.protocol = "http"
self._orig_remote_ip = self.remote_ip
self._orig_protocol = self.protocol
# xheaders can override the defaults
if connection and connection.xheaders:
# Squid uses X-Forwarded-For, others use X-Real-Ip
ip = self.headers.get("X-Forwarded-For", self.remote_ip)
ip = ip.split(',')[-1].strip()
ip = self.headers.get(
"X-Real-Ip", ip)
if netutil.is_valid_ip(ip):
self.remote_ip = ip
# AWS uses X-Forwarded-Proto
proto = self.headers.get(
"X-Scheme", self.headers.get("X-Forwarded-Proto", self.protocol))
if proto in ("http", "https"):
self.protocol = proto
def __str__(self):
if self.address_family in (socket.AF_INET, socket.AF_INET6):
return self.remote_ip
elif isinstance(self.address, bytes):
# Python 3 with the -bb option warns about str(bytes),
# so convert it explicitly.
# Unix socket addresses are str on mac but bytes on linux.
return native_str(self.address)
else:
return str(self.address)
self.host = host or self.headers.get("Host") or "127.0.0.1"
self.files = files or {}
def _apply_xheaders(self, headers):
"""Rewrite the ``remote_ip`` and ``protocol`` fields."""
# Squid uses X-Forwarded-For, others use X-Real-Ip
ip = headers.get("X-Forwarded-For", self.remote_ip)
ip = ip.split(',')[-1].strip()
ip = headers.get("X-Real-Ip", ip)
if netutil.is_valid_ip(ip):
self.remote_ip = ip
# AWS uses X-Forwarded-Proto
proto_header = headers.get(
"X-Scheme", headers.get("X-Forwarded-Proto",
self.protocol))
if proto_header in ("http", "https"):
self.protocol = proto_header
def _unapply_xheaders(self):
"""Undo changes from `_apply_xheaders`.
Xheaders are per-request so they should not leak to the next
request on the same connection.
"""
self.remote_ip = self._orig_remote_ip
self.protocol = self._orig_protocol
class _ServerRequestAdapter(httputil.HTTPMessageDelegate):
"""Adapts the `HTTPMessageDelegate` interface to the interface expected
by our clients.
"""
def __init__(self, server, connection):
self.server = server
self.connection = connection
self._start_time = time.time()
self._finish_time = None
self.request = None
if isinstance(server.request_callback,
httputil.HTTPServerConnectionDelegate):
self.delegate = server.request_callback.start_request(connection)
self._chunks = None
else:
self.delegate = None
self._chunks = []
self.path, sep, self.query = uri.partition('?')
self.arguments = parse_qs_bytes(self.query, keep_blank_values=True)
self.query_arguments = copy.deepcopy(self.arguments)
self.body_arguments = {}
def headers_received(self, start_line, headers):
if self.server.xheaders:
self.connection.context._apply_xheaders(headers)
if self.delegate is None:
self.request = httputil.HTTPServerRequest(
connection=self.connection, start_line=start_line,
headers=headers)
else:
return self.delegate.headers_received(start_line, headers)
def supports_http_1_1(self):
"""Returns True if this request supports HTTP/1.1 semantics"""
return self.version == "HTTP/1.1"
@property
def cookies(self):
"""A dictionary of Cookie.Morsel objects."""
if not hasattr(self, "_cookies"):
self._cookies = Cookie.SimpleCookie()
if "Cookie" in self.headers:
try:
self._cookies.load(
native_str(self.headers["Cookie"]))
except Exception:
self._cookies = {}
return self._cookies
def write(self, chunk, callback=None):
"""Writes the given chunk to the response stream."""
assert isinstance(chunk, bytes_type)
self.connection.write(chunk, callback=callback)
def data_received(self, chunk):
if self.delegate is None:
self._chunks.append(chunk)
else:
return self.delegate.data_received(chunk)
def finish(self):
"""Finishes this HTTP request on the open connection."""
self.connection.finish()
self._finish_time = time.time()
def full_url(self):
"""Reconstructs the full URL for this request."""
return self.protocol + "://" + self.host + self.uri
def request_time(self):
"""Returns the amount of time it took for this request to execute."""
if self._finish_time is None:
return time.time() - self._start_time
if self.delegate is None:
self.request.body = b''.join(self._chunks)
self.request._parse_body()
self.server.request_callback(self.request)
else:
return self._finish_time - self._start_time
self.delegate.finish()
self._cleanup()
def get_ssl_certificate(self, binary_form=False):
"""Returns the client's SSL certificate, if any.
def on_connection_close(self):
if self.delegate is None:
self._chunks = None
else:
self.delegate.on_connection_close()
self._cleanup()
To use client certificates, the HTTPServer must have been constructed
with cert_reqs set in ssl_options, e.g.::
def _cleanup(self):
if self.server.xheaders:
self.connection.context._unapply_xheaders()
server = HTTPServer(app,
ssl_options=dict(
certfile="foo.crt",
keyfile="foo.key",
cert_reqs=ssl.CERT_REQUIRED,
ca_certs="cacert.crt"))
By default, the return value is a dictionary (or None, if no
client certificate is present). If ``binary_form`` is true, a
DER-encoded form of the certificate is returned instead. See
SSLSocket.getpeercert() in the standard library for more
details.
http://docs.python.org/library/ssl.html#sslsocket-objects
"""
try:
return self.connection.stream.socket.getpeercert(
binary_form=binary_form)
except ssl.SSLError:
return None
def __repr__(self):
attrs = ("protocol", "host", "method", "uri", "version", "remote_ip")
args = ", ".join(["%s=%r" % (n, getattr(self, n)) for n in attrs])
return "%s(%s, headers=%s)" % (
self.__class__.__name__, args, dict(self.headers))
HTTPRequest = httputil.HTTPServerRequest

View file

@ -14,20 +14,31 @@
# License for the specific language governing permissions and limitations
# under the License.
"""HTTP utility code shared by clients and servers."""
"""HTTP utility code shared by clients and servers.
This module also defines the `HTTPServerRequest` class which is exposed
via `tornado.web.RequestHandler.request`.
"""
from __future__ import absolute_import, division, print_function, with_statement
import calendar
import collections
import copy
import datetime
import email.utils
import numbers
import re
import time
from tornado.escape import native_str, parse_qs_bytes, utf8
from tornado.log import gen_log
from tornado.util import ObjectDict
from tornado.util import ObjectDict, bytes_type
try:
import Cookie # py2
except ImportError:
import http.cookies as Cookie # py3
try:
from httplib import responses # py2
@ -43,6 +54,13 @@ try:
except ImportError:
from urllib.parse import urlencode # py3
try:
from ssl import SSLError
except ImportError:
# ssl is unavailable on app engine.
class SSLError(Exception):
pass
class _NormalizedHeaderCache(dict):
"""Dynamic cached mapping of header names to Http-Header-Case.
@ -212,6 +230,337 @@ class HTTPHeaders(dict):
return HTTPHeaders(self)
class HTTPServerRequest(object):
"""A single HTTP request.
All attributes are type `str` unless otherwise noted.
.. attribute:: method
HTTP request method, e.g. "GET" or "POST"
.. attribute:: uri
The requested uri.
.. attribute:: path
The path portion of `uri`
.. attribute:: query
The query portion of `uri`
.. attribute:: version
HTTP version specified in request, e.g. "HTTP/1.1"
.. attribute:: headers
`.HTTPHeaders` dictionary-like object for request headers. Acts like
a case-insensitive dictionary with additional methods for repeated
headers.
.. attribute:: body
Request body, if present, as a byte string.
.. attribute:: remote_ip
Client's IP address as a string. If ``HTTPServer.xheaders`` is set,
will pass along the real IP address provided by a load balancer
in the ``X-Real-Ip`` or ``X-Forwarded-For`` header.
.. versionchanged:: 3.1
The list format of ``X-Forwarded-For`` is now supported.
.. attribute:: protocol
The protocol used, either "http" or "https". If ``HTTPServer.xheaders``
is set, will pass along the protocol used by a load balancer if
reported via an ``X-Scheme`` header.
.. attribute:: host
The requested hostname, usually taken from the ``Host`` header.
.. attribute:: arguments
GET/POST arguments are available in the arguments property, which
maps arguments names to lists of values (to support multiple values
for individual names). Names are of type `str`, while arguments
are byte strings. Note that this is different from
`.RequestHandler.get_argument`, which returns argument values as
unicode strings.
.. attribute:: query_arguments
Same format as ``arguments``, but contains only arguments extracted
from the query string.
.. versionadded:: 3.2
.. attribute:: body_arguments
Same format as ``arguments``, but contains only arguments extracted
from the request body.
.. versionadded:: 3.2
.. attribute:: files
File uploads are available in the files property, which maps file
names to lists of `.HTTPFile`.
.. attribute:: connection
An HTTP request is attached to a single HTTP connection, which can
be accessed through the "connection" attribute. Since connections
are typically kept open in HTTP/1.1, multiple requests can be handled
sequentially on a single connection.
.. versionchanged:: 4.0
Moved from ``tornado.httpserver.HTTPRequest``.
"""
def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None,
body=None, host=None, files=None, connection=None,
start_line=None):
if start_line is not None:
method, uri, version = start_line
self.method = method
self.uri = uri
self.version = version
self.headers = headers or HTTPHeaders()
self.body = body or ""
# set remote IP and protocol
context = getattr(connection, 'context', None)
self.remote_ip = getattr(context, 'remote_ip')
self.protocol = getattr(context, 'protocol', "http")
self.host = host or self.headers.get("Host") or "127.0.0.1"
self.files = files or {}
self.connection = connection
self._start_time = time.time()
self._finish_time = None
self.path, sep, self.query = uri.partition('?')
self.arguments = parse_qs_bytes(self.query, keep_blank_values=True)
self.query_arguments = copy.deepcopy(self.arguments)
self.body_arguments = {}
def supports_http_1_1(self):
"""Returns True if this request supports HTTP/1.1 semantics.
.. deprecated:: 4.0
Applications are less likely to need this information with the
introduction of `.HTTPConnection`. If you still need it, access
the ``version`` attribute directly.
"""
return self.version == "HTTP/1.1"
@property
def cookies(self):
"""A dictionary of Cookie.Morsel objects."""
if not hasattr(self, "_cookies"):
self._cookies = Cookie.SimpleCookie()
if "Cookie" in self.headers:
try:
self._cookies.load(
native_str(self.headers["Cookie"]))
except Exception:
self._cookies = {}
return self._cookies
def write(self, chunk, callback=None):
"""Writes the given chunk to the response stream.
.. deprecated:: 4.0
Use ``request.connection`` and the `.HTTPConnection` methods
to write the response.
"""
assert isinstance(chunk, bytes_type)
self.connection.write(chunk, callback=callback)
def finish(self):
"""Finishes this HTTP request on the open connection.
.. deprecated:: 4.0
Use ``request.connection`` and the `.HTTPConnection` methods
to write the response.
"""
self.connection.finish()
self._finish_time = time.time()
def full_url(self):
"""Reconstructs the full URL for this request."""
return self.protocol + "://" + self.host + self.uri
def request_time(self):
"""Returns the amount of time it took for this request to execute."""
if self._finish_time is None:
return time.time() - self._start_time
else:
return self._finish_time - self._start_time
def get_ssl_certificate(self, binary_form=False):
"""Returns the client's SSL certificate, if any.
To use client certificates, the HTTPServer must have been constructed
with cert_reqs set in ssl_options, e.g.::
server = HTTPServer(app,
ssl_options=dict(
certfile="foo.crt",
keyfile="foo.key",
cert_reqs=ssl.CERT_REQUIRED,
ca_certs="cacert.crt"))
By default, the return value is a dictionary (or None, if no
client certificate is present). If ``binary_form`` is true, a
DER-encoded form of the certificate is returned instead. See
SSLSocket.getpeercert() in the standard library for more
details.
http://docs.python.org/library/ssl.html#sslsocket-objects
"""
try:
return self.connection.stream.socket.getpeercert(
binary_form=binary_form)
except SSLError:
return None
def _parse_body(self):
parse_body_arguments(
self.headers.get("Content-Type", ""), self.body,
self.body_arguments, self.files,
self.headers)
for k, v in self.body_arguments.items():
self.arguments.setdefault(k, []).extend(v)
def __repr__(self):
attrs = ("protocol", "host", "method", "uri", "version", "remote_ip")
args = ", ".join(["%s=%r" % (n, getattr(self, n)) for n in attrs])
return "%s(%s, headers=%s)" % (
self.__class__.__name__, args, dict(self.headers))
class HTTPInputError(Exception):
"""Exception class for malformed HTTP requests or responses
from remote sources.
.. versionadded:: 4.0
"""
pass
class HTTPOutputError(Exception):
"""Exception class for errors in HTTP output.
.. versionadded:: 4.0
"""
pass
class HTTPServerConnectionDelegate(object):
"""Implement this interface to handle requests from `.HTTPServer`.
.. versionadded:: 4.0
"""
def start_request(self, server_conn, request_conn):
"""This method is called by the server when a new request has started.
:arg server_conn: is an opaque object representing the long-lived
(e.g. tcp-level) connection.
:arg request_conn: is a `.HTTPConnection` object for a single
request/response exchange.
This method should return a `.HTTPMessageDelegate`.
"""
raise NotImplementedError()
def on_close(self, server_conn):
"""This method is called when a connection has been closed.
:arg server_conn: is a server connection that has previously been
passed to ``start_request``.
"""
pass
class HTTPMessageDelegate(object):
"""Implement this interface to handle an HTTP request or response.
.. versionadded:: 4.0
"""
def headers_received(self, start_line, headers):
"""Called when the HTTP headers have been received and parsed.
:arg start_line: a `.RequestStartLine` or `.ResponseStartLine`
depending on whether this is a client or server message.
:arg headers: a `.HTTPHeaders` instance.
Some `.HTTPConnection` methods can only be called during
``headers_received``.
May return a `.Future`; if it does the body will not be read
until it is done.
"""
pass
def data_received(self, chunk):
"""Called when a chunk of data has been received.
May return a `.Future` for flow control.
"""
pass
def finish(self):
"""Called after the last chunk of data has been received."""
pass
def on_connection_close(self):
"""Called if the connection is closed without finishing the request.
If ``headers_received`` is called, either ``finish`` or
``on_connection_close`` will be called, but not both.
"""
pass
class HTTPConnection(object):
"""Applications use this interface to write their responses.
.. versionadded:: 4.0
"""
def write_headers(self, start_line, headers, chunk=None, callback=None):
"""Write an HTTP header block.
:arg start_line: a `.RequestStartLine` or `.ResponseStartLine`.
:arg headers: a `.HTTPHeaders` instance.
:arg chunk: the first (optional) chunk of data. This is an optimization
so that small responses can be written in the same call as their
headers.
:arg callback: a callback to be run when the write is complete.
Returns a `.Future` if no callback is given.
"""
raise NotImplementedError()
def write(self, chunk, callback=None):
"""Writes a chunk of body data.
The callback will be run when the write is complete. If no callback
is given, returns a Future.
"""
raise NotImplementedError()
def finish(self):
"""Indicates that the last body data has been written.
"""
raise NotImplementedError()
def url_concat(url, args):
"""Concatenate url and argument dictionary regardless of whether
url has existing query parameters.
@ -310,7 +659,7 @@ def _int_or_none(val):
return int(val)
def parse_body_arguments(content_type, body, arguments, files):
def parse_body_arguments(content_type, body, arguments, files, headers=None):
"""Parses a form request body.
Supports ``application/x-www-form-urlencoded`` and
@ -319,6 +668,10 @@ def parse_body_arguments(content_type, body, arguments, files):
and ``files`` parameters are dictionaries that will be updated
with the parsed contents.
"""
if headers and 'Content-Encoding' in headers:
gen_log.warning("Unsupported Content-Encoding: %s",
headers['Content-Encoding'])
return
if content_type.startswith("application/x-www-form-urlencoded"):
try:
uri_arguments = parse_qs_bytes(native_str(body), keep_blank_values=True)
@ -405,6 +758,48 @@ def format_timestamp(ts):
raise TypeError("unknown timestamp type: %r" % ts)
return email.utils.formatdate(ts, usegmt=True)
RequestStartLine = collections.namedtuple(
'RequestStartLine', ['method', 'path', 'version'])
def parse_request_start_line(line):
"""Returns a (method, path, version) tuple for an HTTP 1.x request line.
The response is a `collections.namedtuple`.
>>> parse_request_start_line("GET /foo HTTP/1.1")
RequestStartLine(method='GET', path='/foo', version='HTTP/1.1')
"""
try:
method, path, version = line.split(" ")
except ValueError:
raise HTTPInputError("Malformed HTTP request line")
if not version.startswith("HTTP/"):
raise HTTPInputError(
"Malformed HTTP version in HTTP Request-Line: %r" % version)
return RequestStartLine(method, path, version)
ResponseStartLine = collections.namedtuple(
'ResponseStartLine', ['version', 'code', 'reason'])
def parse_response_start_line(line):
"""Returns a (version, code, reason) tuple for an HTTP 1.x response line.
The response is a `collections.namedtuple`.
>>> parse_response_start_line("HTTP/1.1 200 OK")
ResponseStartLine(version='HTTP/1.1', code=200, reason='OK')
"""
line = native_str(line)
match = re.match("(HTTP/1.[01]) ([0-9]+) ([^\r]*)", line)
if not match:
raise HTTPInputError("Error parsing response start line")
return ResponseStartLine(match.group(1), int(match.group(2)),
match.group(3))
# _parseparam and _parse_header are copied and modified from python2.7's cgi.py
# The original 2.7 version of this code did not correctly support some
# combinations of semicolons and double quotes.

View file

@ -32,6 +32,7 @@ import datetime
import errno
import functools
import heapq
import itertools
import logging
import numbers
import os
@ -41,10 +42,11 @@ import threading
import time
import traceback
from tornado.concurrent import Future, TracebackFuture
from tornado.concurrent import TracebackFuture, is_future
from tornado.log import app_log, gen_log
from tornado import stack_context
from tornado.util import Configurable
from tornado.util import errno_from_exception
try:
import signal
@ -156,6 +158,15 @@ class IOLoop(Configurable):
assert not IOLoop.initialized()
IOLoop._instance = self
@staticmethod
def clear_instance():
"""Clear the global `IOLoop` instance.
.. versionadded:: 4.0
"""
if hasattr(IOLoop, "_instance"):
del IOLoop._instance
@staticmethod
def current():
"""Returns the current thread's `IOLoop`.
@ -244,21 +255,40 @@ class IOLoop(Configurable):
raise NotImplementedError()
def add_handler(self, fd, handler, events):
"""Registers the given handler to receive the given events for fd.
"""Registers the given handler to receive the given events for ``fd``.
The ``fd`` argument may either be an integer file descriptor or
a file-like object with a ``fileno()`` method (and optionally a
``close()`` method, which may be called when the `IOLoop` is shut
down).
The ``events`` argument is a bitwise or of the constants
``IOLoop.READ``, ``IOLoop.WRITE``, and ``IOLoop.ERROR``.
When an event occurs, ``handler(fd, events)`` will be run.
.. versionchanged:: 4.0
Added the ability to pass file-like objects in addition to
raw file descriptors.
"""
raise NotImplementedError()
def update_handler(self, fd, events):
"""Changes the events we listen for fd."""
"""Changes the events we listen for ``fd``.
.. versionchanged:: 4.0
Added the ability to pass file-like objects in addition to
raw file descriptors.
"""
raise NotImplementedError()
def remove_handler(self, fd):
"""Stop listening for events on fd."""
"""Stop listening for events on ``fd``.
.. versionchanged:: 4.0
Added the ability to pass file-like objects in addition to
raw file descriptors.
"""
raise NotImplementedError()
def set_blocking_signal_threshold(self, seconds, action):
@ -372,7 +402,7 @@ class IOLoop(Configurable):
future_cell[0] = TracebackFuture()
future_cell[0].set_exc_info(sys.exc_info())
else:
if isinstance(result, Future):
if is_future(result):
future_cell[0] = result
else:
future_cell[0] = TracebackFuture()
@ -456,6 +486,19 @@ class IOLoop(Configurable):
"""
raise NotImplementedError()
def spawn_callback(self, callback, *args, **kwargs):
"""Calls the given callback on the next IOLoop iteration.
Unlike all other callback-related methods on IOLoop,
``spawn_callback`` does not associate the callback with its caller's
``stack_context``, so it is suitable for fire-and-forget callbacks
that should not interfere with the caller.
.. versionadded:: 4.0
"""
with stack_context.NullContext():
self.add_callback(callback, *args, **kwargs)
def add_future(self, future, callback):
"""Schedules a callback on the ``IOLoop`` when the given
`.Future` is finished.
@ -463,7 +506,7 @@ class IOLoop(Configurable):
The callback is invoked with one argument, the
`.Future`.
"""
assert isinstance(future, Future)
assert is_future(future)
callback = stack_context.wrap(callback)
future.add_done_callback(
lambda future: self.add_callback(callback, future))
@ -474,7 +517,13 @@ class IOLoop(Configurable):
For use in subclasses.
"""
try:
callback()
ret = callback()
if ret is not None and is_future(ret):
# Functions that return Futures typically swallow all
# exceptions and store them in the Future. If a Future
# makes it out to the IOLoop, ensure its exception (if any)
# gets logged too.
self.add_future(ret, lambda f: f.result())
except Exception:
self.handle_callback_exception(callback)
@ -490,6 +539,47 @@ class IOLoop(Configurable):
"""
app_log.error("Exception in callback %r", callback, exc_info=True)
def split_fd(self, fd):
"""Returns an (fd, obj) pair from an ``fd`` parameter.
We accept both raw file descriptors and file-like objects as
input to `add_handler` and related methods. When a file-like
object is passed, we must retain the object itself so we can
close it correctly when the `IOLoop` shuts down, but the
poller interfaces favor file descriptors (they will accept
file-like objects and call ``fileno()`` for you, but they
always return the descriptor itself).
This method is provided for use by `IOLoop` subclasses and should
not generally be used by application code.
.. versionadded:: 4.0
"""
try:
return fd.fileno(), fd
except AttributeError:
return fd, fd
def close_fd(self, fd):
"""Utility method to close an ``fd``.
If ``fd`` is a file-like object, we close it directly; otherwise
we use `os.close`.
This method is provided for use by `IOLoop` subclasses (in
implementations of ``IOLoop.close(all_fds=True)`` and should
not generally be used by application code.
.. versionadded:: 4.0
"""
try:
try:
fd.close()
except AttributeError:
os.close(fd)
except OSError:
pass
class PollIOLoop(IOLoop):
"""Base class for IOLoops built around a select-like function.
@ -515,7 +605,8 @@ class PollIOLoop(IOLoop):
self._closing = False
self._thread_ident = None
self._blocking_signal_threshold = None
self._timeout_counter = itertools.count()
# Create a pipe that we send bogus data to when we want to wake
# the I/O loop when it is idle
self._waker = Waker()
@ -528,26 +619,24 @@ class PollIOLoop(IOLoop):
self._closing = True
self.remove_handler(self._waker.fileno())
if all_fds:
for fd in self._handlers.keys():
try:
close_method = getattr(fd, 'close', None)
if close_method is not None:
close_method()
else:
os.close(fd)
except Exception:
gen_log.debug("error closing fd %s", fd, exc_info=True)
for fd, handler in self._handlers.values():
self.close_fd(fd)
self._waker.close()
self._impl.close()
self._callbacks = None
self._timeouts = None
def add_handler(self, fd, handler, events):
self._handlers[fd] = stack_context.wrap(handler)
fd, obj = self.split_fd(fd)
self._handlers[fd] = (obj, stack_context.wrap(handler))
self._impl.register(fd, events | self.ERROR)
def update_handler(self, fd, events):
fd, obj = self.split_fd(fd)
self._impl.modify(fd, events | self.ERROR)
def remove_handler(self, fd):
fd, obj = self.split_fd(fd)
self._handlers.pop(fd, None)
self._events.pop(fd, None)
try:
@ -566,6 +655,8 @@ class PollIOLoop(IOLoop):
action if action is not None else signal.SIG_DFL)
def start(self):
if self._running:
raise RuntimeError("IOLoop is already running")
self._setup_logging()
if self._stopped:
self._stopped = False
@ -608,19 +699,16 @@ class PollIOLoop(IOLoop):
try:
while True:
poll_timeout = _POLL_TIMEOUT
# Prevent IO event starvation by delaying new callbacks
# to the next iteration of the event loop.
with self._callback_lock:
callbacks = self._callbacks
self._callbacks = []
for callback in callbacks:
self._run_callback(callback)
# Closures may be holding on to a lot of memory, so allow
# them to be freed before we go into our poll wait.
callbacks = callback = None
# Add any timeouts that have come due to the callback list.
# Do not run anything until we have determined which ones
# are ready, so timeouts that call add_timeout cannot
# schedule anything in this iteration.
if self._timeouts:
now = self.time()
while self._timeouts:
@ -630,11 +718,9 @@ class PollIOLoop(IOLoop):
self._cancellations -= 1
elif self._timeouts[0].deadline <= now:
timeout = heapq.heappop(self._timeouts)
self._run_callback(timeout.callback)
callbacks.append(timeout.callback)
del timeout
else:
seconds = self._timeouts[0].deadline - now
poll_timeout = min(seconds, poll_timeout)
break
if (self._cancellations > 512
and self._cancellations > (len(self._timeouts) >> 1)):
@ -645,10 +731,25 @@ class PollIOLoop(IOLoop):
if x.callback is not None]
heapq.heapify(self._timeouts)
for callback in callbacks:
self._run_callback(callback)
# Closures may be holding on to a lot of memory, so allow
# them to be freed before we go into our poll wait.
callbacks = callback = None
if self._callbacks:
# If any callbacks or timeouts called add_callback,
# we don't want to wait in poll() before we run them.
poll_timeout = 0.0
elif self._timeouts:
# If there are any timeouts, schedule the first one.
# Use self.time() instead of 'now' to account for time
# spent running callbacks.
poll_timeout = self._timeouts[0].deadline - self.time()
poll_timeout = max(0, min(poll_timeout, _POLL_TIMEOUT))
else:
# No timeouts and no callbacks, so use the default.
poll_timeout = _POLL_TIMEOUT
if not self._running:
break
@ -666,9 +767,7 @@ class PollIOLoop(IOLoop):
# two ways EINTR might be signaled:
# * e.errno == errno.EINTR
# * e.args is like (errno.EINTR, 'Interrupted system call')
if (getattr(e, 'errno', None) == errno.EINTR or
(isinstance(getattr(e, 'args', None), tuple) and
len(e.args) == 2 and e.args[0] == errno.EINTR)):
if errno_from_exception(e) == errno.EINTR:
continue
else:
raise
@ -685,15 +784,17 @@ class PollIOLoop(IOLoop):
while self._events:
fd, events = self._events.popitem()
try:
self._handlers[fd](fd, events)
fd_obj, handler_func = self._handlers[fd]
handler_func(fd_obj, events)
except (OSError, IOError) as e:
if e.args[0] == errno.EPIPE:
if errno_from_exception(e) == errno.EPIPE:
# Happens when the client closes the connection
pass
else:
self.handle_callback_exception(self._handlers.get(fd))
except Exception:
self.handle_callback_exception(self._handlers.get(fd))
fd_obj = handler_func = None
finally:
# reset the stopped flag so another start/stop pair can be issued
@ -765,16 +866,21 @@ class _Timeout(object):
"""An IOLoop timeout, a UNIX timestamp and a callback"""
# Reduce memory overhead when there are lots of pending callbacks
__slots__ = ['deadline', 'callback']
__slots__ = ['deadline', 'callback', 'tiebreaker']
def __init__(self, deadline, callback, io_loop):
if isinstance(deadline, numbers.Real):
self.deadline = deadline
elif isinstance(deadline, datetime.timedelta):
self.deadline = io_loop.time() + _Timeout.timedelta_to_seconds(deadline)
now = io_loop.time()
try:
self.deadline = now + deadline.total_seconds()
except AttributeError: # py2.6
self.deadline = now + _Timeout.timedelta_to_seconds(deadline)
else:
raise TypeError("Unsupported deadline %r" % deadline)
self.callback = callback
self.tiebreaker = next(io_loop._timeout_counter)
@staticmethod
def timedelta_to_seconds(td):
@ -786,12 +892,12 @@ class _Timeout(object):
# in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons
# use __lt__).
def __lt__(self, other):
return ((self.deadline, id(self)) <
(other.deadline, id(other)))
return ((self.deadline, self.tiebreaker) <
(other.deadline, other.tiebreaker))
def __le__(self, other):
return ((self.deadline, id(self)) <=
(other.deadline, id(other)))
return ((self.deadline, self.tiebreaker) <=
(other.deadline, other.tiebreaker))
class PeriodicCallback(object):

View file

@ -31,21 +31,27 @@ import errno
import numbers
import os
import socket
import ssl
import sys
import re
from tornado.concurrent import TracebackFuture
from tornado import ioloop
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError
from tornado import stack_context
from tornado.util import bytes_type
from tornado.util import bytes_type, errno_from_exception
try:
from tornado.platform.posix import _set_nonblocking
except ImportError:
_set_nonblocking = None
try:
import ssl
except ImportError:
# ssl is not available on Google App Engine
ssl = None
# These errnos indicate that a non-blocking operation must be retried
# at a later time. On most platforms they're the same value, but on
# some they differ.
@ -53,7 +59,8 @@ _ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
# These errnos indicate that a connection has been abruptly terminated.
# They should be caught and handled less noisily than other errors.
_ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE)
_ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE,
errno.ETIMEDOUT)
class StreamClosedError(IOError):
@ -66,12 +73,31 @@ class StreamClosedError(IOError):
pass
class UnsatisfiableReadError(Exception):
"""Exception raised when a read cannot be satisfied.
Raised by ``read_until`` and ``read_until_regex`` with a ``max_bytes``
argument.
"""
pass
class StreamBufferFullError(Exception):
"""Exception raised by `IOStream` methods when the buffer is full.
"""
class BaseIOStream(object):
"""A utility class to write to and read from a non-blocking file or socket.
We support a non-blocking ``write()`` and a family of ``read_*()`` methods.
All of the methods take callbacks (since writing and reading are
non-blocking and asynchronous).
All of the methods take an optional ``callback`` argument and return a
`.Future` only if no callback is given. When the operation completes,
the callback will be run or the `.Future` will resolve with the data
read (or ``None`` for ``write()``). All outstanding ``Futures`` will
resolve with a `StreamClosedError` when the stream is closed; users
of the callback interface will be notified via
`.BaseIOStream.set_close_callback` instead.
When a stream is closed due to an error, the IOStream's ``error``
attribute contains the exception object.
@ -80,24 +106,48 @@ class BaseIOStream(object):
`read_from_fd`, and optionally `get_fd_error`.
"""
def __init__(self, io_loop=None, max_buffer_size=None,
read_chunk_size=4096):
read_chunk_size=None, max_write_buffer_size=None):
"""`BaseIOStream` constructor.
:arg io_loop: The `.IOLoop` to use; defaults to `.IOLoop.current`.
:arg max_buffer_size: Maximum amount of incoming data to buffer;
defaults to 100MB.
:arg read_chunk_size: Amount of data to read at one time from the
underlying transport; defaults to 64KB.
:arg max_write_buffer_size: Amount of outgoing data to buffer;
defaults to unlimited.
.. versionchanged:: 4.0
Add the ``max_write_buffer_size`` parameter. Changed default
``read_chunk_size`` to 64KB.
"""
self.io_loop = io_loop or ioloop.IOLoop.current()
self.max_buffer_size = max_buffer_size or 104857600
self.read_chunk_size = read_chunk_size
# A chunk size that is too close to max_buffer_size can cause
# spurious failures.
self.read_chunk_size = min(read_chunk_size or 65536,
self.max_buffer_size // 2)
self.max_write_buffer_size = max_write_buffer_size
self.error = None
self._read_buffer = collections.deque()
self._write_buffer = collections.deque()
self._read_buffer_size = 0
self._write_buffer_size = 0
self._write_buffer_frozen = False
self._read_delimiter = None
self._read_regex = None
self._read_max_bytes = None
self._read_bytes = None
self._read_partial = False
self._read_until_close = False
self._read_callback = None
self._read_future = None
self._streaming_callback = None
self._write_callback = None
self._write_future = None
self._close_callback = None
self._connect_callback = None
self._connect_future = None
self._connecting = False
self._state = None
self._pending_callbacks = 0
@ -142,98 +192,162 @@ class BaseIOStream(object):
"""
return None
def read_until_regex(self, regex, callback):
"""Run ``callback`` when we read the given regex pattern.
def read_until_regex(self, regex, callback=None, max_bytes=None):
"""Asynchronously read until we have matched the given regex.
The callback will get the data read (including the data that
matched the regex and anything that came before it) as an argument.
The result includes the data that matches the regex and anything
that came before it. If a callback is given, it will be run
with the data as an argument; if not, this method returns a
`.Future`.
If ``max_bytes`` is not None, the connection will be closed
if more than ``max_bytes`` bytes have been read and the regex is
not satisfied.
.. versionchanged:: 4.0
Added the ``max_bytes`` argument. The ``callback`` argument is
now optional and a `.Future` will be returned if it is omitted.
"""
self._set_read_callback(callback)
future = self._set_read_callback(callback)
self._read_regex = re.compile(regex)
self._try_inline_read()
self._read_max_bytes = max_bytes
try:
self._try_inline_read()
except UnsatisfiableReadError as e:
# Handle this the same way as in _handle_events.
gen_log.info("Unsatisfiable read, closing connection: %s" % e)
self.close(exc_info=True)
return future
return future
def read_until(self, delimiter, callback):
"""Run ``callback`` when we read the given delimiter.
def read_until(self, delimiter, callback=None, max_bytes=None):
"""Asynchronously read until we have found the given delimiter.
The callback will get the data read (including the delimiter)
as an argument.
The result includes all the data read including the delimiter.
If a callback is given, it will be run with the data as an argument;
if not, this method returns a `.Future`.
If ``max_bytes`` is not None, the connection will be closed
if more than ``max_bytes`` bytes have been read and the delimiter
is not found.
.. versionchanged:: 4.0
Added the ``max_bytes`` argument. The ``callback`` argument is
now optional and a `.Future` will be returned if it is omitted.
"""
self._set_read_callback(callback)
future = self._set_read_callback(callback)
self._read_delimiter = delimiter
self._try_inline_read()
self._read_max_bytes = max_bytes
try:
self._try_inline_read()
except UnsatisfiableReadError as e:
# Handle this the same way as in _handle_events.
gen_log.info("Unsatisfiable read, closing connection: %s" % e)
self.close(exc_info=True)
return future
return future
def read_bytes(self, num_bytes, callback, streaming_callback=None):
"""Run callback when we read the given number of bytes.
def read_bytes(self, num_bytes, callback=None, streaming_callback=None,
partial=False):
"""Asynchronously read a number of bytes.
If a ``streaming_callback`` is given, it will be called with chunks
of data as they become available, and the argument to the final
``callback`` will be empty. Otherwise, the ``callback`` gets
the data as an argument.
of data as they become available, and the final result will be empty.
Otherwise, the result is all the data that was read.
If a callback is given, it will be run with the data as an argument;
if not, this method returns a `.Future`.
If ``partial`` is true, the callback is run as soon as we have
any bytes to return (but never more than ``num_bytes``)
.. versionchanged:: 4.0
Added the ``partial`` argument. The callback argument is now
optional and a `.Future` will be returned if it is omitted.
"""
self._set_read_callback(callback)
future = self._set_read_callback(callback)
assert isinstance(num_bytes, numbers.Integral)
self._read_bytes = num_bytes
self._read_partial = partial
self._streaming_callback = stack_context.wrap(streaming_callback)
self._try_inline_read()
return future
def read_until_close(self, callback, streaming_callback=None):
"""Reads all data from the socket until it is closed.
def read_until_close(self, callback=None, streaming_callback=None):
"""Asynchronously reads all data from the socket until it is closed.
If a ``streaming_callback`` is given, it will be called with chunks
of data as they become available, and the argument to the final
``callback`` will be empty. Otherwise, the ``callback`` gets the
data as an argument.
of data as they become available, and the final result will be empty.
Otherwise, the result is all the data that was read.
If a callback is given, it will be run with the data as an argument;
if not, this method returns a `.Future`.
Subject to ``max_buffer_size`` limit from `IOStream` constructor if
a ``streaming_callback`` is not used.
.. versionchanged:: 4.0
The callback argument is now optional and a `.Future` will
be returned if it is omitted.
"""
self._set_read_callback(callback)
future = self._set_read_callback(callback)
self._streaming_callback = stack_context.wrap(streaming_callback)
if self.closed():
if self._streaming_callback is not None:
self._run_callback(self._streaming_callback,
self._consume(self._read_buffer_size))
self._run_callback(self._read_callback,
self._consume(self._read_buffer_size))
self._streaming_callback = None
self._read_callback = None
return
self._run_read_callback(self._read_buffer_size, True)
self._run_read_callback(self._read_buffer_size, False)
return future
self._read_until_close = True
self._streaming_callback = stack_context.wrap(streaming_callback)
self._try_inline_read()
return future
def write(self, data, callback=None):
"""Write the given data to this stream.
"""Asynchronously write the given data to this stream.
If ``callback`` is given, we call it when all of the buffered write
data has been successfully written to the stream. If there was
previously buffered write data and an old write callback, that
callback is simply overwritten with this new callback.
If no ``callback`` is given, this method returns a `.Future` that
resolves (with a result of ``None``) when the write has been
completed. If `write` is called again before that `.Future` has
resolved, the previous future will be orphaned and will never resolve.
.. versionchanged:: 4.0
Now returns a `.Future` if no callback is given.
"""
assert isinstance(data, bytes_type)
self._check_closed()
# We use bool(_write_buffer) as a proxy for write_buffer_size>0,
# so never put empty strings in the buffer.
if data:
if (self.max_write_buffer_size is not None and
self._write_buffer_size + len(data) > self.max_write_buffer_size):
raise StreamBufferFullError("Reached maximum read buffer size")
# Break up large contiguous strings before inserting them in the
# write buffer, so we don't have to recopy the entire thing
# as we slice off pieces to send to the socket.
WRITE_BUFFER_CHUNK_SIZE = 128 * 1024
if len(data) > WRITE_BUFFER_CHUNK_SIZE:
for i in range(0, len(data), WRITE_BUFFER_CHUNK_SIZE):
self._write_buffer.append(data[i:i + WRITE_BUFFER_CHUNK_SIZE])
else:
self._write_buffer.append(data)
self._write_callback = stack_context.wrap(callback)
for i in range(0, len(data), WRITE_BUFFER_CHUNK_SIZE):
self._write_buffer.append(data[i:i + WRITE_BUFFER_CHUNK_SIZE])
self._write_buffer_size += len(data)
if callback is not None:
self._write_callback = stack_context.wrap(callback)
future = None
else:
future = self._write_future = TracebackFuture()
if not self._connecting:
self._handle_write()
if self._write_buffer:
self._add_io_state(self.io_loop.WRITE)
self._maybe_add_error_listener()
return future
def set_close_callback(self, callback):
"""Call the given callback when the stream is closed."""
"""Call the given callback when the stream is closed.
This is not necessary for applications that use the `.Future`
interface; all outstanding ``Futures`` will resolve with a
`StreamClosedError` when the stream is closed.
"""
self._close_callback = stack_context.wrap(callback)
self._maybe_add_error_listener()
def close(self, exc_info=False):
"""Close this stream.
@ -251,13 +365,9 @@ class BaseIOStream(object):
if self._read_until_close:
if (self._streaming_callback is not None and
self._read_buffer_size):
self._run_callback(self._streaming_callback,
self._consume(self._read_buffer_size))
callback = self._read_callback
self._read_callback = None
self._run_read_callback(self._read_buffer_size, True)
self._read_until_close = False
self._run_callback(callback,
self._consume(self._read_buffer_size))
self._run_read_callback(self._read_buffer_size, False)
if self._state is not None:
self.io_loop.remove_handler(self.fileno())
self._state = None
@ -269,6 +379,25 @@ class BaseIOStream(object):
# If there are pending callbacks, don't run the close callback
# until they're done (see _maybe_add_error_handler)
if self.closed() and self._pending_callbacks == 0:
futures = []
if self._read_future is not None:
futures.append(self._read_future)
self._read_future = None
if self._write_future is not None:
futures.append(self._write_future)
self._write_future = None
if self._connect_future is not None:
futures.append(self._connect_future)
self._connect_future = None
for future in futures:
if (isinstance(self.error, (socket.error, IOError)) and
errno_from_exception(self.error) in _ERRNO_CONNRESET):
# Treat connection resets as closed connections so
# clients only have to catch one kind of exception
# to avoid logging.
future.set_exception(StreamClosedError())
else:
future.set_exception(self.error or StreamClosedError())
if self._close_callback is not None:
cb = self._close_callback
self._close_callback = None
@ -282,7 +411,7 @@ class BaseIOStream(object):
def reading(self):
"""Returns true if we are currently reading from the stream."""
return self._read_callback is not None
return self._read_callback is not None or self._read_future is not None
def writing(self):
"""Returns true if we are currently writing to the stream."""
@ -309,16 +438,22 @@ class BaseIOStream(object):
def _handle_events(self, fd, events):
if self.closed():
gen_log.warning("Got events for closed stream %d", fd)
gen_log.warning("Got events for closed stream %s", fd)
return
try:
if self._connecting:
# Most IOLoops will report a write failed connect
# with the WRITE event, but SelectIOLoop reports a
# READ as well so we must check for connecting before
# either.
self._handle_connect()
if self.closed():
return
if events & self.io_loop.READ:
self._handle_read()
if self.closed():
return
if events & self.io_loop.WRITE:
if self._connecting:
self._handle_connect()
self._handle_write()
if self.closed():
return
@ -334,13 +469,20 @@ class BaseIOStream(object):
state |= self.io_loop.READ
if self.writing():
state |= self.io_loop.WRITE
if state == self.io_loop.ERROR:
if state == self.io_loop.ERROR and self._read_buffer_size == 0:
# If the connection is idle, listen for reads too so
# we can tell if the connection is closed. If there is
# data in the read buffer we won't run the close callback
# yet anyway, so we don't need to listen in this case.
state |= self.io_loop.READ
if state != self._state:
assert self._state is not None, \
"shouldn't happen: _handle_events without self._state"
self._state = state
self.io_loop.update_handler(self.fileno(), self._state)
except UnsatisfiableReadError as e:
gen_log.info("Unsatisfiable read, closing connection: %s" % e)
self.close(exc_info=True)
except Exception:
gen_log.error("Uncaught exception, closing connection.",
exc_info=True)
@ -381,42 +523,108 @@ class BaseIOStream(object):
self._pending_callbacks += 1
self.io_loop.add_callback(wrapper)
def _read_to_buffer_loop(self):
# This method is called from _handle_read and _try_inline_read.
try:
if self._read_bytes is not None:
target_bytes = self._read_bytes
elif self._read_max_bytes is not None:
target_bytes = self._read_max_bytes
elif self.reading():
# For read_until without max_bytes, or
# read_until_close, read as much as we can before
# scanning for the delimiter.
target_bytes = None
else:
target_bytes = 0
next_find_pos = 0
# Pretend to have a pending callback so that an EOF in
# _read_to_buffer doesn't trigger an immediate close
# callback. At the end of this method we'll either
# estabilsh a real pending callback via
# _read_from_buffer or run the close callback.
#
# We need two try statements here so that
# pending_callbacks is decremented before the `except`
# clause below (which calls `close` and does need to
# trigger the callback)
self._pending_callbacks += 1
while not self.closed():
# Read from the socket until we get EWOULDBLOCK or equivalent.
# SSL sockets do some internal buffering, and if the data is
# sitting in the SSL object's buffer select() and friends
# can't see it; the only way to find out if it's there is to
# try to read it.
if self._read_to_buffer() == 0:
break
self._run_streaming_callback()
# If we've read all the bytes we can use, break out of
# this loop. We can't just call read_from_buffer here
# because of subtle interactions with the
# pending_callback and error_listener mechanisms.
#
# If we've reached target_bytes, we know we're done.
if (target_bytes is not None and
self._read_buffer_size >= target_bytes):
break
# Otherwise, we need to call the more expensive find_read_pos.
# It's inefficient to do this on every read, so instead
# do it on the first read and whenever the read buffer
# size has doubled.
if self._read_buffer_size >= next_find_pos:
pos = self._find_read_pos()
if pos is not None:
return pos
next_find_pos = self._read_buffer_size * 2
return self._find_read_pos()
finally:
self._pending_callbacks -= 1
def _handle_read(self):
try:
try:
# Pretend to have a pending callback so that an EOF in
# _read_to_buffer doesn't trigger an immediate close
# callback. At the end of this method we'll either
# estabilsh a real pending callback via
# _read_from_buffer or run the close callback.
#
# We need two try statements here so that
# pending_callbacks is decremented before the `except`
# clause below (which calls `close` and does need to
# trigger the callback)
self._pending_callbacks += 1
while not self.closed():
# Read from the socket until we get EWOULDBLOCK or equivalent.
# SSL sockets do some internal buffering, and if the data is
# sitting in the SSL object's buffer select() and friends
# can't see it; the only way to find out if it's there is to
# try to read it.
if self._read_to_buffer() == 0:
break
finally:
self._pending_callbacks -= 1
pos = self._read_to_buffer_loop()
except UnsatisfiableReadError:
raise
except Exception:
gen_log.warning("error on read", exc_info=True)
self.close(exc_info=True)
return
if self._read_from_buffer():
if pos is not None:
self._read_from_buffer(pos)
return
else:
self._maybe_run_close_callback()
def _set_read_callback(self, callback):
assert not self._read_callback, "Already reading"
self._read_callback = stack_context.wrap(callback)
assert self._read_callback is None, "Already reading"
assert self._read_future is None, "Already reading"
if callback is not None:
self._read_callback = stack_context.wrap(callback)
else:
self._read_future = TracebackFuture()
return self._read_future
def _run_read_callback(self, size, streaming):
if streaming:
callback = self._streaming_callback
else:
callback = self._read_callback
self._read_callback = self._streaming_callback = None
if self._read_future is not None:
assert callback is None
future = self._read_future
self._read_future = None
future.set_result(self._consume(size))
if callback is not None:
assert self._read_future is None
self._run_callback(callback, self._consume(size))
else:
# If we scheduled a callback, we will add the error listener
# afterwards. If we didn't, we have to do it now.
self._maybe_add_error_listener()
def _try_inline_read(self):
"""Attempt to complete the current read operation from buffered data.
@ -426,18 +634,14 @@ class BaseIOStream(object):
listening for reads on the socket.
"""
# See if we've already got the data from a previous read
if self._read_from_buffer():
self._run_streaming_callback()
pos = self._find_read_pos()
if pos is not None:
self._read_from_buffer(pos)
return
self._check_closed()
try:
try:
# See comments in _handle_read about incrementing _pending_callbacks
self._pending_callbacks += 1
while not self.closed():
if self._read_to_buffer() == 0:
break
finally:
self._pending_callbacks -= 1
pos = self._read_to_buffer_loop()
except Exception:
# If there was an in _read_to_buffer, we called close() already,
# but couldn't run the close callback because of _pending_callbacks.
@ -445,9 +649,15 @@ class BaseIOStream(object):
# applicable.
self._maybe_run_close_callback()
raise
if self._read_from_buffer():
if pos is not None:
self._read_from_buffer(pos)
return
self._maybe_add_error_listener()
# We couldn't satisfy the read inline, so either close the stream
# or listen for new data.
if self.closed():
self._maybe_run_close_callback()
else:
self._add_io_state(ioloop.IOLoop.READ)
def _read_to_buffer(self):
"""Reads from the socket and appends the result to the read buffer.
@ -472,32 +682,42 @@ class BaseIOStream(object):
return 0
self._read_buffer.append(chunk)
self._read_buffer_size += len(chunk)
if self._read_buffer_size >= self.max_buffer_size:
if self._read_buffer_size > self.max_buffer_size:
gen_log.error("Reached maximum read buffer size")
self.close()
raise IOError("Reached maximum read buffer size")
raise StreamBufferFullError("Reached maximum read buffer size")
return len(chunk)
def _read_from_buffer(self):
"""Attempts to complete the currently-pending read from the buffer.
Returns True if the read was completed.
"""
def _run_streaming_callback(self):
if self._streaming_callback is not None and self._read_buffer_size:
bytes_to_consume = self._read_buffer_size
if self._read_bytes is not None:
bytes_to_consume = min(self._read_bytes, bytes_to_consume)
self._read_bytes -= bytes_to_consume
self._run_callback(self._streaming_callback,
self._consume(bytes_to_consume))
if self._read_bytes is not None and self._read_buffer_size >= self._read_bytes:
num_bytes = self._read_bytes
callback = self._read_callback
self._read_callback = None
self._streaming_callback = None
self._read_bytes = None
self._run_callback(callback, self._consume(num_bytes))
return True
self._run_read_callback(bytes_to_consume, True)
def _read_from_buffer(self, pos):
"""Attempts to complete the currently-pending read from the buffer.
The argument is either a position in the read buffer or None,
as returned by _find_read_pos.
"""
self._read_bytes = self._read_delimiter = self._read_regex = None
self._read_partial = False
self._run_read_callback(pos, False)
def _find_read_pos(self):
"""Attempts to find a position in the read buffer that satisfies
the currently-pending read.
Returns a position in the buffer if the current read can be satisfied,
or None if it cannot.
"""
if (self._read_bytes is not None and
(self._read_buffer_size >= self._read_bytes or
(self._read_partial and self._read_buffer_size > 0))):
num_bytes = min(self._read_bytes, self._read_buffer_size)
return num_bytes
elif self._read_delimiter is not None:
# Multi-byte delimiters (e.g. '\r\n') may straddle two
# chunks in the read buffer, so we can't easily find them
@ -506,37 +726,40 @@ class BaseIOStream(object):
# length) tend to be "line" oriented, the delimiter is likely
# to be in the first few chunks. Merge the buffer gradually
# since large merges are relatively expensive and get undone in
# consume().
# _consume().
if self._read_buffer:
while True:
loc = self._read_buffer[0].find(self._read_delimiter)
if loc != -1:
callback = self._read_callback
delimiter_len = len(self._read_delimiter)
self._read_callback = None
self._streaming_callback = None
self._read_delimiter = None
self._run_callback(callback,
self._consume(loc + delimiter_len))
return True
self._check_max_bytes(self._read_delimiter,
loc + delimiter_len)
return loc + delimiter_len
if len(self._read_buffer) == 1:
break
_double_prefix(self._read_buffer)
self._check_max_bytes(self._read_delimiter,
len(self._read_buffer[0]))
elif self._read_regex is not None:
if self._read_buffer:
while True:
m = self._read_regex.search(self._read_buffer[0])
if m is not None:
callback = self._read_callback
self._read_callback = None
self._streaming_callback = None
self._read_regex = None
self._run_callback(callback, self._consume(m.end()))
return True
self._check_max_bytes(self._read_regex, m.end())
return m.end()
if len(self._read_buffer) == 1:
break
_double_prefix(self._read_buffer)
return False
self._check_max_bytes(self._read_regex,
len(self._read_buffer[0]))
return None
def _check_max_bytes(self, delimiter, size):
if (self._read_max_bytes is not None and
size > self._read_max_bytes):
raise UnsatisfiableReadError(
"delimiter %r not found within %d bytes" % (
delimiter, self._read_max_bytes))
def _handle_write(self):
while self._write_buffer:
@ -563,6 +786,7 @@ class BaseIOStream(object):
self._write_buffer_frozen = False
_merge_prefix(self._write_buffer, num_bytes)
self._write_buffer.popleft()
self._write_buffer_size -= num_bytes
except (socket.error, IOError, OSError) as e:
if e.args[0] in _ERRNO_WOULDBLOCK:
self._write_buffer_frozen = True
@ -572,14 +796,19 @@ class BaseIOStream(object):
# Broken pipe errors are usually caused by connection
# reset, and its better to not log EPIPE errors to
# minimize log spam
gen_log.warning("Write error on %d: %s",
gen_log.warning("Write error on %s: %s",
self.fileno(), e)
self.close(exc_info=True)
return
if not self._write_buffer and self._write_callback:
callback = self._write_callback
self._write_callback = None
self._run_callback(callback)
if not self._write_buffer:
if self._write_callback:
callback = self._write_callback
self._write_callback = None
self._run_callback(callback)
if self._write_future:
future = self._write_future
self._write_future = None
future.set_result(None)
def _consume(self, loc):
if loc == 0:
@ -593,10 +822,19 @@ class BaseIOStream(object):
raise StreamClosedError("Stream is closed")
def _maybe_add_error_listener(self):
if self._state is None and self._pending_callbacks == 0:
# This method is part of an optimization: to detect a connection that
# is closed when we're not actively reading or writing, we must listen
# for read events. However, it is inefficient to do this when the
# connection is first established because we are going to read or write
# immediately anyway. Instead, we insert checks at various times to
# see if the connection is idle and add the read listener then.
if self._pending_callbacks != 0:
return
if self._state is None or self._state == ioloop.IOLoop.ERROR:
if self.closed():
self._maybe_run_close_callback()
else:
elif (self._read_buffer_size == 0 and
self._close_callback is not None):
self._add_io_state(ioloop.IOLoop.READ)
def _add_io_state(self, state):
@ -680,7 +918,7 @@ class IOStream(BaseIOStream):
super(IOStream, self).__init__(*args, **kwargs)
def fileno(self):
return self.socket.fileno()
return self.socket
def close_fd(self):
self.socket.close()
@ -712,9 +950,19 @@ class IOStream(BaseIOStream):
May only be called if the socket passed to the constructor was
not previously connected. The address parameter is in the
same format as for `socket.connect <socket.socket.connect>`,
i.e. a ``(host, port)`` tuple. If ``callback`` is specified,
it will be called when the connection is completed.
same format as for `socket.connect <socket.socket.connect>` for
the type of socket passed to the IOStream constructor,
e.g. an ``(ip, port)`` tuple. Hostnames are accepted here,
but will be resolved synchronously and block the IOLoop.
If you have a hostname instead of an IP address, the `.TCPClient`
class is recommended instead of calling this method directly.
`.TCPClient` will do asynchronous DNS resolution and handle
both IPv4 and IPv6.
If ``callback`` is specified, it will be called with no
arguments when the connection is completed; if not this method
returns a `.Future` (whose result after a successful
connection will be the stream itself).
If specified, the ``server_hostname`` parameter will be used
in SSL connections for certificate validation (if requested in
@ -726,6 +974,10 @@ class IOStream(BaseIOStream):
which case the data will be written as soon as the connection
is ready. Calling `IOStream` read methods before the socket is
connected works on some platforms but is non-portable.
.. versionchanged:: 4.0
If no callback is given, returns a `.Future`.
"""
self._connecting = True
try:
@ -738,14 +990,83 @@ class IOStream(BaseIOStream):
# returned immediately when attempting to connect to
# localhost, so handle them the same way as an error
# reported later in _handle_connect.
if (e.args[0] != errno.EINPROGRESS and
e.args[0] not in _ERRNO_WOULDBLOCK):
gen_log.warning("Connect error on fd %d: %s",
if (errno_from_exception(e) != errno.EINPROGRESS and
errno_from_exception(e) not in _ERRNO_WOULDBLOCK):
gen_log.warning("Connect error on fd %s: %s",
self.socket.fileno(), e)
self.close(exc_info=True)
return
self._connect_callback = stack_context.wrap(callback)
if callback is not None:
self._connect_callback = stack_context.wrap(callback)
future = None
else:
future = self._connect_future = TracebackFuture()
self._add_io_state(self.io_loop.WRITE)
return future
def start_tls(self, server_side, ssl_options=None, server_hostname=None):
"""Convert this `IOStream` to an `SSLIOStream`.
This enables protocols that begin in clear-text mode and
switch to SSL after some initial negotiation (such as the
``STARTTLS`` extension to SMTP and IMAP).
This method cannot be used if there are outstanding reads
or writes on the stream, or if there is any data in the
IOStream's buffer (data in the operating system's socket
buffer is allowed). This means it must generally be used
immediately after reading or writing the last clear-text
data. It can also be used immediately after connecting,
before any reads or writes.
The ``ssl_options`` argument may be either a dictionary
of options or an `ssl.SSLContext`. If a ``server_hostname``
is given, it will be used for certificate verification
(as configured in the ``ssl_options``).
This method returns a `.Future` whose result is the new
`SSLIOStream`. After this method has been called,
any other operation on the original stream is undefined.
If a close callback is defined on this stream, it will be
transferred to the new stream.
.. versionadded:: 4.0
"""
if (self._read_callback or self._read_future or
self._write_callback or self._write_future or
self._connect_callback or self._connect_future or
self._pending_callbacks or self._closed or
self._read_buffer or self._write_buffer):
raise ValueError("IOStream is not idle; cannot convert to SSL")
if ssl_options is None:
ssl_options = {}
socket = self.socket
self.io_loop.remove_handler(socket)
self.socket = None
socket = ssl_wrap_socket(socket, ssl_options, server_side=server_side,
do_handshake_on_connect=False)
orig_close_callback = self._close_callback
self._close_callback = None
future = TracebackFuture()
ssl_stream = SSLIOStream(socket, ssl_options=ssl_options,
io_loop=self.io_loop)
# Wrap the original close callback so we can fail our Future as well.
# If we had an "unwrap" counterpart to this method we would need
# to restore the original callback after our Future resolves
# so that repeated wrap/unwrap calls don't build up layers.
def close_callback():
if not future.done():
future.set_exception(ssl_stream.error or StreamClosedError())
if orig_close_callback is not None:
orig_close_callback()
ssl_stream.set_close_callback(close_callback)
ssl_stream._ssl_connect_callback = lambda: future.set_result(ssl_stream)
ssl_stream.max_buffer_size = self.max_buffer_size
ssl_stream.read_chunk_size = self.read_chunk_size
return future
def _handle_connect(self):
err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
@ -755,14 +1076,19 @@ class IOStream(BaseIOStream):
# an error state before the socket becomes writable, so
# in that case a connection failure would be handled by the
# error path in _handle_events instead of here.
gen_log.warning("Connect error on fd %d: %s",
self.socket.fileno(), errno.errorcode[err])
if self._connect_future is None:
gen_log.warning("Connect error on fd %s: %s",
self.socket.fileno(), errno.errorcode[err])
self.close()
return
if self._connect_callback is not None:
callback = self._connect_callback
self._connect_callback = None
self._run_callback(callback)
if self._connect_future is not None:
future = self._connect_future
self._connect_future = None
future.set_result(self)
self._connecting = False
def set_nodelay(self, value):
@ -841,7 +1167,7 @@ class SSLIOStream(IOStream):
peer = self.socket.getpeername()
except Exception:
peer = '(not connected)'
gen_log.warning("SSL Error on %d %s: %s",
gen_log.warning("SSL Error on %s %s: %s",
self.socket.fileno(), peer, err)
return self.close(exc_info=True)
raise
@ -907,19 +1233,33 @@ class SSLIOStream(IOStream):
# has completed.
self._ssl_connect_callback = stack_context.wrap(callback)
self._server_hostname = server_hostname
super(SSLIOStream, self).connect(address, callback=None)
# Note: Since we don't pass our callback argument along to
# super.connect(), this will always return a Future.
# This is harmless, but a bit less efficient than it could be.
return super(SSLIOStream, self).connect(address, callback=None)
def _handle_connect(self):
# Call the superclass method to check for errors.
super(SSLIOStream, self)._handle_connect()
if self.closed():
return
# When the connection is complete, wrap the socket for SSL
# traffic. Note that we do this by overriding _handle_connect
# instead of by passing a callback to super().connect because
# user callbacks are enqueued asynchronously on the IOLoop,
# but since _handle_events calls _handle_connect immediately
# followed by _handle_write we need this to be synchronous.
#
# The IOLoop will get confused if we swap out self.socket while the
# fd is registered, so remove it now and re-register after
# wrap_socket().
self.io_loop.remove_handler(self.socket)
old_state = self._state
self._state = None
self.socket = ssl_wrap_socket(self.socket, self._ssl_options,
server_hostname=self._server_hostname,
do_handshake_on_connect=False)
super(SSLIOStream, self)._handle_connect()
self._add_io_state(old_state)
def read_from_fd(self):
if self._ssl_accepting:
@ -978,9 +1318,9 @@ class PipeIOStream(BaseIOStream):
try:
chunk = os.read(self.fd, self.read_chunk_size)
except (IOError, OSError) as e:
if e.args[0] in _ERRNO_WOULDBLOCK:
if errno_from_exception(e) in _ERRNO_WOULDBLOCK:
return None
elif e.args[0] == errno.EBADF:
elif errno_from_exception(e) == errno.EBADF:
# If the writing half of a pipe is closed, select will
# report it as readable but reads will fail with EBADF.
self.close(exc_info=True)

View file

@ -83,10 +83,10 @@ class LogFormatter(logging.Formatter):
DEFAULT_FORMAT = '%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s'
DEFAULT_DATE_FORMAT = '%y%m%d %H:%M:%S'
DEFAULT_COLORS = {
logging.DEBUG: 4, # Blue
logging.INFO: 2, # Green
logging.WARNING: 3, # Yellow
logging.ERROR: 1, # Red
logging.DEBUG: 4, # Blue
logging.INFO: 2, # Green
logging.WARNING: 3, # Yellow
logging.ERROR: 1, # Red
}
def __init__(self, color=True, fmt=DEFAULT_FORMAT,
@ -184,7 +184,7 @@ def enable_pretty_logging(options=None, logger=None):
"""
if options is None:
from tornado.options import options
if options.logging == 'none':
if options.logging is None or options.logging.lower() == 'none':
return
if logger is None:
logger = logging.getLogger()

View file

@ -20,18 +20,26 @@ from __future__ import absolute_import, division, print_function, with_statement
import errno
import os
import platform
import socket
import ssl
import stat
from tornado.concurrent import dummy_executor, run_on_executor
from tornado.ioloop import IOLoop
from tornado.platform.auto import set_close_exec
from tornado.util import u, Configurable
from tornado.util import u, Configurable, errno_from_exception
try:
import ssl
except ImportError:
# ssl is not available on Google App Engine
ssl = None
if hasattr(ssl, 'match_hostname') and hasattr(ssl, 'CertificateError'): # python 3.2+
ssl_match_hostname = ssl.match_hostname
SSLCertificateError = ssl.CertificateError
elif ssl is None:
ssl_match_hostname = SSLCertificateError = None
else:
import backports.ssl_match_hostname
ssl_match_hostname = backports.ssl_match_hostname.match_hostname
@ -44,6 +52,11 @@ else:
# thread now.
u('foo').encode('idna')
# These errnos indicate that a non-blocking operation must be retried
# at a later time. On most platforms they're the same value, but on
# some they differ.
_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags=None):
"""Creates listening sockets bound to the given port and address.
@ -77,13 +90,23 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags
family = socket.AF_INET
if flags is None:
flags = socket.AI_PASSIVE
bound_port = None
for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM,
0, flags)):
af, socktype, proto, canonname, sockaddr = res
if (platform.system() == 'Darwin' and address == 'localhost' and
af == socket.AF_INET6 and sockaddr[3] != 0):
# Mac OS X includes a link-local address fe80::1%lo0 in the
# getaddrinfo results for 'localhost'. However, the firewall
# doesn't understand that this is a local address and will
# prompt for access (often repeatedly, due to an apparent
# bug in its ability to remember granting access to an
# application). Skip these addresses.
continue
try:
sock = socket.socket(af, socktype, proto)
except socket.error as e:
if e.args[0] == errno.EAFNOSUPPORT:
if errno_from_exception(e) == errno.EAFNOSUPPORT:
continue
raise
set_close_exec(sock.fileno())
@ -100,8 +123,16 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags
# Python 2.x on windows doesn't have IPPROTO_IPV6.
if hasattr(socket, "IPPROTO_IPV6"):
sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
# automatic port allocation with port=None
# should bind on the same port on IPv4 and IPv6
host, requested_port = sockaddr[:2]
if requested_port == 0 and bound_port is not None:
sockaddr = tuple([host, bound_port] + list(sockaddr[2:]))
sock.setblocking(0)
sock.bind(sockaddr)
bound_port = sock.getsockname()[1]
sock.listen(backlog)
sockets.append(sock)
return sockets
@ -124,7 +155,7 @@ if hasattr(socket, 'AF_UNIX'):
try:
st = os.stat(file)
except OSError as err:
if err.errno != errno.ENOENT:
if errno_from_exception(err) != errno.ENOENT:
raise
else:
if stat.S_ISSOCK(st.st_mode):
@ -154,18 +185,18 @@ def add_accept_handler(sock, callback, io_loop=None):
try:
connection, address = sock.accept()
except socket.error as e:
# EWOULDBLOCK and EAGAIN indicate we have accepted every
# _ERRNO_WOULDBLOCK indicate we have accepted every
# connection that is available.
if e.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
if errno_from_exception(e) in _ERRNO_WOULDBLOCK:
return
# ECONNABORTED indicates that there was a connection
# but it was closed while still in the accept queue.
# (observed on FreeBSD).
if e.args[0] == errno.ECONNABORTED:
if errno_from_exception(e) == errno.ECONNABORTED:
continue
raise
callback(connection, address)
io_loop.add_handler(sock.fileno(), accept_handler, IOLoop.READ)
io_loop.add_handler(sock, accept_handler, IOLoop.READ)
def is_valid_ip(ip):
@ -381,6 +412,10 @@ def ssl_options_to_context(ssl_options):
context.load_verify_locations(ssl_options['ca_certs'])
if 'ciphers' in ssl_options:
context.set_ciphers(ssl_options['ciphers'])
if hasattr(ssl, 'OP_NO_COMPRESSION'):
# Disable TLS compression to avoid CRIME and related attacks.
# This constant wasn't added until python 3.3.
context.options |= ssl.OP_NO_COMPRESSION
return context

View file

@ -56,6 +56,18 @@ We support `datetimes <datetime.datetime>`, `timedeltas
the top-level functions in this module (`define`, `parse_command_line`, etc)
simply call methods on it. You may create additional `OptionParser`
instances to define isolated sets of options, such as for subcommands.
.. note::
By default, several options are defined that will configure the
standard `logging` module when `parse_command_line` or `parse_config_file`
are called. If you want Tornado to leave the logging configuration
alone so you can manage it yourself, either pass ``--logging=none``
on the command line or do the following to disable it in code::
from tornado.options import options, parse_command_line
options.logging = None
parse_command_line()
"""
from __future__ import absolute_import, division, print_function, with_statement
@ -360,6 +372,8 @@ class _Mockable(object):
class _Option(object):
UNSET = object()
def __init__(self, name, default=None, type=basestring_type, help=None,
metavar=None, multiple=False, file_name=None, group_name=None,
callback=None):
@ -374,10 +388,10 @@ class _Option(object):
self.group_name = group_name
self.callback = callback
self.default = default
self._value = None
self._value = _Option.UNSET
def value(self):
return self.default if self._value is None else self._value
return self.default if self._value is _Option.UNSET else self._value
def parse(self, value):
_parse = {

View file

@ -12,9 +12,9 @@ unfinished callbacks on the event loop that fail when it resumes)
from __future__ import absolute_import, division, print_function, with_statement
import datetime
import functools
import os
from tornado.ioloop import IOLoop
# _Timeout is used for its timedelta_to_seconds method for py26 compatibility.
from tornado.ioloop import IOLoop, _Timeout
from tornado import stack_context
try:
@ -34,7 +34,7 @@ class BaseAsyncIOLoop(IOLoop):
self.asyncio_loop = asyncio_loop
self.close_loop = close_loop
self.asyncio_loop.call_soon(self.make_current)
# Maps fd to handler function (as in IOLoop.add_handler)
# Maps fd to (fileobj, handler function) pair (as in IOLoop.add_handler)
self.handlers = {}
# Set of fds listening for reads/writes
self.readers = set()
@ -44,19 +44,18 @@ class BaseAsyncIOLoop(IOLoop):
def close(self, all_fds=False):
self.closing = True
for fd in list(self.handlers):
fileobj, handler_func = self.handlers[fd]
self.remove_handler(fd)
if all_fds:
try:
os.close(fd)
except OSError:
pass
self.close_fd(fileobj)
if self.close_loop:
self.asyncio_loop.close()
def add_handler(self, fd, handler, events):
fd, fileobj = self.split_fd(fd)
if fd in self.handlers:
raise ValueError("fd %d added twice" % fd)
self.handlers[fd] = stack_context.wrap(handler)
raise ValueError("fd %s added twice" % fd)
self.handlers[fd] = (fileobj, stack_context.wrap(handler))
if events & IOLoop.READ:
self.asyncio_loop.add_reader(
fd, self._handle_events, fd, IOLoop.READ)
@ -67,6 +66,7 @@ class BaseAsyncIOLoop(IOLoop):
self.writers.add(fd)
def update_handler(self, fd, events):
fd, fileobj = self.split_fd(fd)
if events & IOLoop.READ:
if fd not in self.readers:
self.asyncio_loop.add_reader(
@ -87,6 +87,7 @@ class BaseAsyncIOLoop(IOLoop):
self.writers.remove(fd)
def remove_handler(self, fd):
fd, fileobj = self.split_fd(fd)
if fd not in self.handlers:
return
if fd in self.readers:
@ -98,7 +99,8 @@ class BaseAsyncIOLoop(IOLoop):
del self.handlers[fd]
def _handle_events(self, fd, events):
self.handlers[fd](fd, events)
fileobj, handler_func = self.handlers[fd]
handler_func(fileobj, events)
def start(self):
self._setup_logging()
@ -107,17 +109,11 @@ class BaseAsyncIOLoop(IOLoop):
def stop(self):
self.asyncio_loop.stop()
def _run_callback(self, callback, *args, **kwargs):
try:
callback(*args, **kwargs)
except Exception:
self.handle_callback_exception(callback)
def add_timeout(self, deadline, callback):
if isinstance(deadline, (int, float)):
delay = max(deadline - self.time(), 0)
elif isinstance(deadline, datetime.timedelta):
delay = deadline.total_seconds()
delay = _Timeout.timedelta_to_seconds(deadline)
else:
raise TypeError("Unsupported deadline %r", deadline)
return self.asyncio_loop.call_later(delay, self._run_callback,
@ -129,13 +125,9 @@ class BaseAsyncIOLoop(IOLoop):
def add_callback(self, callback, *args, **kwargs):
if self.closing:
raise RuntimeError("IOLoop is closing")
if kwargs:
self.asyncio_loop.call_soon_threadsafe(functools.partial(
self._run_callback, stack_context.wrap(callback),
*args, **kwargs))
else:
self.asyncio_loop.call_soon_threadsafe(
self._run_callback, stack_context.wrap(callback), *args)
self.asyncio_loop.call_soon_threadsafe(
self._run_callback,
functools.partial(stack_context.wrap(callback), *args, **kwargs))
add_callback_from_signal = add_callback

View file

@ -30,6 +30,10 @@ import os
if os.name == 'nt':
from tornado.platform.common import Waker
from tornado.platform.windows import set_close_exec
elif 'APPENGINE_RUNTIME' in os.environ:
from tornado.platform.common import Waker
def set_close_exec(fd):
pass
else:
from tornado.platform.posix import set_close_exec, Waker

View file

@ -15,7 +15,8 @@ class Waker(interface.Waker):
and Jython.
"""
def __init__(self):
# Based on Zope async.py: http://svn.zope.org/zc.ngi/trunk/src/zc/ngi/async.py
# Based on Zope select_trigger.py:
# https://github.com/zopefoundation/Zope/blob/master/src/ZServer/medusa/thread/select_trigger.py
self.writer = socket.socket()
# Disable buffering -- pulling the trigger sends 1 byte,

View file

@ -37,7 +37,7 @@ class _KQueue(object):
def register(self, fd, events):
if fd in self._active:
raise IOError("fd %d already registered" % fd)
raise IOError("fd %s already registered" % fd)
self._control(fd, events, select.KQ_EV_ADD)
self._active[fd] = events

View file

@ -37,7 +37,7 @@ class _Select(object):
def register(self, fd, events):
if fd in self.read_fds or fd in self.write_fds or fd in self.error_fds:
raise IOError("fd %d already registered" % fd)
raise IOError("fd %s already registered" % fd)
if events & IOLoop.READ:
self.read_fds.add(fd)
if events & IOLoop.WRITE:

View file

@ -91,6 +91,11 @@ from tornado.netutil import Resolver
from tornado.stack_context import NullContext, wrap
from tornado.ioloop import IOLoop
try:
long # py2
except NameError:
long = int # py3
@implementer(IDelayedCall)
class TornadoDelayedCall(object):
@ -365,8 +370,9 @@ def install(io_loop=None):
@implementer(IReadDescriptor, IWriteDescriptor)
class _FD(object):
def __init__(self, fd, handler):
def __init__(self, fd, fileobj, handler):
self.fd = fd
self.fileobj = fileobj
self.handler = handler
self.reading = False
self.writing = False
@ -377,15 +383,15 @@ class _FD(object):
def doRead(self):
if not self.lost:
self.handler(self.fd, tornado.ioloop.IOLoop.READ)
self.handler(self.fileobj, tornado.ioloop.IOLoop.READ)
def doWrite(self):
if not self.lost:
self.handler(self.fd, tornado.ioloop.IOLoop.WRITE)
self.handler(self.fileobj, tornado.ioloop.IOLoop.WRITE)
def connectionLost(self, reason):
if not self.lost:
self.handler(self.fd, tornado.ioloop.IOLoop.ERROR)
self.handler(self.fileobj, tornado.ioloop.IOLoop.ERROR)
self.lost = True
def logPrefix(self):
@ -412,14 +418,19 @@ class TwistedIOLoop(tornado.ioloop.IOLoop):
self.reactor.callWhenRunning(self.make_current)
def close(self, all_fds=False):
fds = self.fds
self.reactor.removeAll()
for c in self.reactor.getDelayedCalls():
c.cancel()
if all_fds:
for fd in fds.values():
self.close_fd(fd.fileobj)
def add_handler(self, fd, handler, events):
if fd in self.fds:
raise ValueError('fd %d added twice' % fd)
self.fds[fd] = _FD(fd, wrap(handler))
raise ValueError('fd %s added twice' % fd)
fd, fileobj = self.split_fd(fd)
self.fds[fd] = _FD(fd, fileobj, wrap(handler))
if events & tornado.ioloop.IOLoop.READ:
self.fds[fd].reading = True
self.reactor.addReader(self.fds[fd])
@ -428,6 +439,7 @@ class TwistedIOLoop(tornado.ioloop.IOLoop):
self.reactor.addWriter(self.fds[fd])
def update_handler(self, fd, events):
fd, fileobj = self.split_fd(fd)
if events & tornado.ioloop.IOLoop.READ:
if not self.fds[fd].reading:
self.fds[fd].reading = True
@ -446,6 +458,7 @@ class TwistedIOLoop(tornado.ioloop.IOLoop):
self.reactor.removeWriter(self.fds[fd])
def remove_handler(self, fd):
fd, fileobj = self.split_fd(fd)
if fd not in self.fds:
return
self.fds[fd].lost = True
@ -462,12 +475,6 @@ class TwistedIOLoop(tornado.ioloop.IOLoop):
def stop(self):
self.reactor.crash()
def _run_callback(self, callback, *args, **kwargs):
try:
callback(*args, **kwargs)
except Exception:
self.handle_callback_exception(callback)
def add_timeout(self, deadline, callback):
if isinstance(deadline, (int, long, float)):
delay = max(deadline - self.time(), 0)
@ -482,8 +489,9 @@ class TwistedIOLoop(tornado.ioloop.IOLoop):
timeout.cancel()
def add_callback(self, callback, *args, **kwargs):
self.reactor.callFromThread(self._run_callback,
wrap(callback), *args, **kwargs)
self.reactor.callFromThread(
self._run_callback,
functools.partial(wrap(callback), *args, **kwargs))
def add_callback_from_signal(self, callback, *args, **kwargs):
self.add_callback(callback, *args, **kwargs)

View file

@ -21,7 +21,6 @@ the server into multiple processes and managing subprocesses.
from __future__ import absolute_import, division, print_function, with_statement
import errno
import multiprocessing
import os
import signal
import subprocess
@ -35,6 +34,13 @@ from tornado.iostream import PipeIOStream
from tornado.log import gen_log
from tornado.platform.auto import set_close_exec
from tornado import stack_context
from tornado.util import errno_from_exception
try:
import multiprocessing
except ImportError:
# Multiprocessing is not availble on Google App Engine.
multiprocessing = None
try:
long # py2
@ -44,6 +50,8 @@ except NameError:
def cpu_count():
"""Returns the number of processors on this machine."""
if multiprocessing is None:
return 1
try:
return multiprocessing.cpu_count()
except NotImplementedError:
@ -136,7 +144,7 @@ def fork_processes(num_processes, max_restarts=100):
try:
pid, status = os.wait()
except OSError as e:
if e.errno == errno.EINTR:
if errno_from_exception(e) == errno.EINTR:
continue
raise
if pid not in children:
@ -283,7 +291,7 @@ class Subprocess(object):
try:
ret_pid, status = os.waitpid(pid, os.WNOHANG)
except OSError as e:
if e.args[0] == errno.ECHILD:
if errno_from_exception(e) == errno.ECHILD:
return
if ret_pid == 0:
return

View file

@ -1,23 +1,23 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
from tornado.escape import utf8, _unicode, native_str
from tornado.concurrent import is_future
from tornado.escape import utf8, _unicode
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy
from tornado.httputil import HTTPHeaders
from tornado.iostream import IOStream, SSLIOStream
from tornado import httputil
from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters
from tornado.iostream import StreamClosedError
from tornado.netutil import Resolver, OverrideResolver
from tornado.log import gen_log
from tornado import stack_context
from tornado.util import GzipDecompressor
from tornado.tcpclient import TCPClient
import base64
import collections
import copy
import functools
import os.path
import re
import socket
import ssl
import sys
try:
@ -30,7 +30,23 @@ try:
except ImportError:
import urllib.parse as urlparse # py3
_DEFAULT_CA_CERTS = os.path.dirname(__file__) + '/ca-certificates.crt'
try:
import ssl
except ImportError:
# ssl is not available on Google App Engine.
ssl = None
try:
import certifi
except ImportError:
certifi = None
def _default_ca_certs():
if certifi is None:
raise Exception("The 'certifi' package is required to use https "
"in simple_httpclient")
return certifi.where()
class SimpleAsyncHTTPClient(AsyncHTTPClient):
@ -47,7 +63,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
"""
def initialize(self, io_loop, max_clients=10,
hostname_mapping=None, max_buffer_size=104857600,
resolver=None, defaults=None):
resolver=None, defaults=None, max_header_size=None):
"""Creates a AsyncHTTPClient.
Only a single AsyncHTTPClient instance exists per IOLoop
@ -74,6 +90,9 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
self.active = {}
self.waiting = {}
self.max_buffer_size = max_buffer_size
self.max_header_size = max_header_size
# TCPClient could create a Resolver for us, but we have to do it
# ourselves to support hostname_mapping.
if resolver:
self.resolver = resolver
self.own_resolver = False
@ -83,11 +102,13 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
if hostname_mapping is not None:
self.resolver = OverrideResolver(resolver=self.resolver,
mapping=hostname_mapping)
self.tcp_client = TCPClient(resolver=self.resolver, io_loop=io_loop)
def close(self):
super(SimpleAsyncHTTPClient, self).close()
if self.own_resolver:
self.resolver.close()
self.tcp_client.close()
def fetch_impl(self, request, callback):
key = object()
@ -119,7 +140,8 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
def _handle_request(self, request, release_callback, final_callback):
_HTTPConnection(self.io_loop, self, request, release_callback,
final_callback, self.max_buffer_size, self.resolver)
final_callback, self.max_buffer_size, self.tcp_client,
self.max_header_size)
def _release_fetch(self, key):
del self.active[key]
@ -142,11 +164,12 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
del self.waiting[key]
class _HTTPConnection(object):
class _HTTPConnection(httputil.HTTPMessageDelegate):
_SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
def __init__(self, io_loop, client, request, release_callback,
final_callback, max_buffer_size, resolver):
final_callback, max_buffer_size, tcp_client,
max_header_size):
self.start_time = io_loop.time()
self.io_loop = io_loop
self.client = client
@ -154,13 +177,15 @@ class _HTTPConnection(object):
self.release_callback = release_callback
self.final_callback = final_callback
self.max_buffer_size = max_buffer_size
self.resolver = resolver
self.tcp_client = tcp_client
self.max_header_size = max_header_size
self.code = None
self.headers = None
self.chunks = None
self.chunks = []
self._decompressor = None
# Timeout handle returned by IOLoop.add_timeout
self._timeout = None
self._sockaddr = None
with stack_context.ExceptionStackContext(self._handle_exception):
self.parsed = urlparse.urlsplit(_unicode(self.request.url))
if self.parsed.scheme not in ("http", "https"):
@ -183,42 +208,31 @@ class _HTTPConnection(object):
host = host[1:-1]
self.parsed_hostname = host # save final host for _on_connect
if request.allow_ipv6:
af = socket.AF_UNSPEC
else:
# We only try the first IP we get from getaddrinfo,
# so restrict to ipv4 by default.
if request.allow_ipv6 is False:
af = socket.AF_INET
else:
af = socket.AF_UNSPEC
ssl_options = self._get_ssl_options(self.parsed.scheme)
timeout = min(self.request.connect_timeout, self.request.request_timeout)
if timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + timeout,
stack_context.wrap(self._on_timeout))
self.resolver.resolve(host, port, af, callback=self._on_resolve)
self.tcp_client.connect(host, port, af=af,
ssl_options=ssl_options,
callback=self._on_connect)
def _on_resolve(self, addrinfo):
if self.final_callback is None:
# final_callback is cleared if we've hit our timeout
return
self.stream = self._create_stream(addrinfo)
self.stream.set_close_callback(self._on_close)
# ipv6 addresses are broken (in self.parsed.hostname) until
# 2.7, here is correctly parsed value calculated in __init__
sockaddr = addrinfo[0][1]
self.stream.connect(sockaddr, self._on_connect,
server_hostname=self.parsed_hostname)
def _create_stream(self, addrinfo):
af = addrinfo[0][0]
if self.parsed.scheme == "https":
def _get_ssl_options(self, scheme):
if scheme == "https":
ssl_options = {}
if self.request.validate_cert:
ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
if self.request.ca_certs is not None:
ssl_options["ca_certs"] = self.request.ca_certs
else:
ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
ssl_options["ca_certs"] = _default_ca_certs()
if self.request.client_key is not None:
ssl_options["keyfile"] = self.request.client_key
if self.request.client_cert is not None:
@ -236,21 +250,16 @@ class _HTTPConnection(object):
# but nearly all servers support both SSLv3 and TLSv1:
# http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
if sys.version_info >= (2, 7):
ssl_options["ciphers"] = "DEFAULT:!SSLv2"
# In addition to disabling SSLv2, we also exclude certain
# classes of insecure ciphers.
ssl_options["ciphers"] = "DEFAULT:!SSLv2:!EXPORT:!DES"
else:
# This is really only necessary for pre-1.0 versions
# of openssl, but python 2.6 doesn't expose version
# information.
ssl_options["ssl_version"] = ssl.PROTOCOL_TLSv1
return SSLIOStream(socket.socket(af),
io_loop=self.io_loop,
ssl_options=ssl_options,
max_buffer_size=self.max_buffer_size)
else:
return IOStream(socket.socket(af),
io_loop=self.io_loop,
max_buffer_size=self.max_buffer_size)
return ssl_options
return None
def _on_timeout(self):
self._timeout = None
@ -262,7 +271,13 @@ class _HTTPConnection(object):
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
def _on_connect(self):
def _on_connect(self, stream):
if self.final_callback is None:
# final_callback is cleared if we've hit our timeout.
stream.close()
return
self.stream = stream
self.stream.set_close_callback(self._on_close)
self._remove_timeout()
if self.final_callback is None:
return
@ -302,16 +317,22 @@ class _HTTPConnection(object):
self.request.headers["User-Agent"] = self.request.user_agent
if not self.request.allow_nonstandard_methods:
if self.request.method in ("POST", "PATCH", "PUT"):
if self.request.body is None:
if (self.request.body is None and
self.request.body_producer is None):
raise AssertionError(
'Body must not be empty for "%s" request'
% self.request.method)
else:
if self.request.body is not None:
if (self.request.body is not None or
self.request.body_producer is not None):
raise AssertionError(
'Body must be empty for "%s" request'
% self.request.method)
if self.request.expect_100_continue:
self.request.headers["Expect"] = "100-continue"
if self.request.body is not None:
# When body_producer is used the caller is responsible for
# setting Content-Length (or else chunked encoding will be used).
self.request.headers["Content-Length"] = str(len(
self.request.body))
if (self.request.method == "POST" and
@ -320,20 +341,47 @@ class _HTTPConnection(object):
if self.request.use_gzip:
self.request.headers["Accept-Encoding"] = "gzip"
req_path = ((self.parsed.path or '/') +
(('?' + self.parsed.query) if self.parsed.query else ''))
request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method,
req_path))]
for k, v in self.request.headers.get_all():
line = utf8(k) + b": " + utf8(v)
if b'\n' in line:
raise ValueError('Newline in header: ' + repr(line))
request_lines.append(line)
request_str = b"\r\n".join(request_lines) + b"\r\n\r\n"
if self.request.body is not None:
request_str += self.request.body
(('?' + self.parsed.query) if self.parsed.query else ''))
self.stream.set_nodelay(True)
self.stream.write(request_str)
self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)
self.connection = HTTP1Connection(
self.stream, True,
HTTP1ConnectionParameters(
no_keep_alive=True,
max_header_size=self.max_header_size,
use_gzip=self.request.use_gzip),
self._sockaddr)
start_line = httputil.RequestStartLine(self.request.method,
req_path, 'HTTP/1.1')
self.connection.write_headers(start_line, self.request.headers)
if self.request.expect_100_continue:
self._read_response()
else:
self._write_body(True)
def _write_body(self, start_read):
if self.request.body is not None:
self.connection.write(self.request.body)
self.connection.finish()
elif self.request.body_producer is not None:
fut = self.request.body_producer(self.connection.write)
if is_future(fut):
def on_body_written(fut):
fut.result()
self.connection.finish()
if start_read:
self._read_response()
self.io_loop.add_future(fut, on_body_written)
return
self.connection.finish()
if start_read:
self._read_response()
def _read_response(self):
# Ensure that any exception raised in read_response ends up in our
# stack context.
self.io_loop.add_future(
self.connection.read_response(self),
lambda f: f.result())
def _release(self):
if self.release_callback is not None:
@ -351,43 +399,39 @@ class _HTTPConnection(object):
def _handle_exception(self, typ, value, tb):
if self.final_callback:
self._remove_timeout()
if isinstance(value, StreamClosedError):
value = HTTPError(599, "Stream closed")
self._run_callback(HTTPResponse(self.request, 599, error=value,
request_time=self.io_loop.time() - self.start_time,
))
if hasattr(self, "stream"):
# TODO: this may cause a StreamClosedError to be raised
# by the connection's Future. Should we cancel the
# connection more gracefully?
self.stream.close()
return True
else:
# If our callback has already been called, we are probably
# catching an exception that is not caused by us but rather
# some child of our callback. Rather than drop it on the floor,
# pass it along.
return False
# pass it along, unless it's just the stream being closed.
return isinstance(value, StreamClosedError)
def _on_close(self):
if self.final_callback is not None:
message = "Connection closed"
if self.stream.error:
message = str(self.stream.error)
raise self.stream.error
raise HTTPError(599, message)
def _handle_1xx(self, code):
self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)
def _on_headers(self, data):
data = native_str(data.decode("latin1"))
first_line, _, header_data = data.partition("\n")
match = re.match("HTTP/1.[01] ([0-9]+) ([^\r]*)", first_line)
assert match
code = int(match.group(1))
self.headers = HTTPHeaders.parse(header_data)
if 100 <= code < 200:
self._handle_1xx(code)
def headers_received(self, first_line, headers):
if self.request.expect_100_continue and first_line.code == 100:
self._write_body(False)
return
else:
self.code = code
self.reason = match.group(2)
self.headers = headers
self.code = first_line.code
self.reason = first_line.reason
if "Content-Length" in self.headers:
if "," in self.headers["Content-Length"]:
@ -404,17 +448,12 @@ class _HTTPConnection(object):
content_length = None
if self.request.header_callback is not None:
# re-attach the newline we split on earlier
self.request.header_callback(first_line + _)
# Reassemble the start line.
self.request.header_callback('%s %s %s\r\n' % first_line)
for k, v in self.headers.get_all():
self.request.header_callback("%s: %s\r\n" % (k, v))
self.request.header_callback('\r\n')
if self.request.method == "HEAD" or self.code == 304:
# HEAD requests and 304 responses never have content, even
# though they may have content-length headers
self._on_body(b"")
return
if 100 <= self.code < 200 or self.code == 204:
# These response codes never have bodies
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
@ -422,21 +461,9 @@ class _HTTPConnection(object):
content_length not in (None, 0)):
raise ValueError("Response with code %d should not have body" %
self.code)
self._on_body(b"")
return
if (self.request.use_gzip and
self.headers.get("Content-Encoding") == "gzip"):
self._decompressor = GzipDecompressor()
if self.headers.get("Transfer-Encoding") == "chunked":
self.chunks = []
self.stream.read_until(b"\r\n", self._on_chunk_length)
elif content_length is not None:
self.stream.read_bytes(content_length, self._on_body)
else:
self.stream.read_until_close(self._on_body)
def _on_body(self, data):
def finish(self):
data = b''.join(self.chunks)
self._remove_timeout()
original_request = getattr(self.request, "original_request",
self.request)
@ -472,19 +499,12 @@ class _HTTPConnection(object):
self.client.fetch(new_request, final_callback)
self._on_end_request()
return
if self._decompressor:
data = (self._decompressor.decompress(data) +
self._decompressor.flush())
if self.request.streaming_callback:
if self.chunks is None:
# if chunks is not None, we already called streaming_callback
# in _on_chunk_data
self.request.streaming_callback(data)
buffer = BytesIO()
else:
buffer = BytesIO(data) # TODO: don't require one big string?
response = HTTPResponse(original_request,
self.code, reason=self.reason,
self.code, reason=getattr(self, 'reason', None),
headers=self.headers,
request_time=self.io_loop.time() - self.start_time,
buffer=buffer,
@ -495,40 +515,11 @@ class _HTTPConnection(object):
def _on_end_request(self):
self.stream.close()
def _on_chunk_length(self, data):
# TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
length = int(data.strip(), 16)
if length == 0:
if self._decompressor is not None:
tail = self._decompressor.flush()
if tail:
# I believe the tail will always be empty (i.e.
# decompress will return all it can). The purpose
# of the flush call is to detect errors such
# as truncated input. But in case it ever returns
# anything, treat it as an extra chunk
if self.request.streaming_callback is not None:
self.request.streaming_callback(tail)
else:
self.chunks.append(tail)
# all the data has been decompressed, so we don't need to
# decompress again in _on_body
self._decompressor = None
self._on_body(b''.join(self.chunks))
else:
self.stream.read_bytes(length + 2, # chunk ends with \r\n
self._on_chunk_data)
def _on_chunk_data(self, data):
assert data[-2:] == b"\r\n"
chunk = data[:-2]
if self._decompressor:
chunk = self._decompressor.decompress(chunk)
def data_received(self, chunk):
if self.request.streaming_callback is not None:
self.request.streaming_callback(chunk)
else:
self.chunks.append(chunk)
self.stream.read_until(b"\r\n", self._on_chunk_length)
if __name__ == "__main__":

View file

@ -266,6 +266,18 @@ def wrap(fn):
# TODO: Any other better way to store contexts and update them in wrapped function?
cap_contexts = [_state.contexts]
if not cap_contexts[0][0] and not cap_contexts[0][1]:
# Fast path when there are no active contexts.
def null_wrapper(*args, **kwargs):
try:
current_state = _state.contexts
_state.contexts = cap_contexts[0]
return fn(*args, **kwargs)
finally:
_state.contexts = current_state
null_wrapper._wrapped = True
return null_wrapper
def wrapped(*args, **kwargs):
ret = None
try:

179
tornado/tcpclient.py Normal file
View file

@ -0,0 +1,179 @@
#!/usr/bin/env python
#
# Copyright 2014 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A non-blocking TCP connection factory.
"""
from __future__ import absolute_import, division, print_function, with_statement
import functools
import socket
from tornado.concurrent import Future
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado import gen
from tornado.netutil import Resolver
_INITIAL_CONNECT_TIMEOUT = 0.3
class _Connector(object):
"""A stateless implementation of the "Happy Eyeballs" algorithm.
"Happy Eyeballs" is documented in RFC6555 as the recommended practice
for when both IPv4 and IPv6 addresses are available.
In this implementation, we partition the addresses by family, and
make the first connection attempt to whichever address was
returned first by ``getaddrinfo``. If that connection fails or
times out, we begin a connection in parallel to the first address
of the other family. If there are additional failures we retry
with other addresses, keeping one connection attempt per family
in flight at a time.
http://tools.ietf.org/html/rfc6555
"""
def __init__(self, addrinfo, io_loop, connect):
self.io_loop = io_loop
self.connect = connect
self.future = Future()
self.timeout = None
self.last_error = None
self.remaining = len(addrinfo)
self.primary_addrs, self.secondary_addrs = self.split(addrinfo)
@staticmethod
def split(addrinfo):
"""Partition the ``addrinfo`` list by address family.
Returns two lists. The first list contains the first entry from
``addrinfo`` and all others with the same family, and the
second list contains all other addresses (normally one list will
be AF_INET and the other AF_INET6, although non-standard resolvers
may return additional families).
"""
primary = []
secondary = []
primary_af = addrinfo[0][0]
for af, addr in addrinfo:
if af == primary_af:
primary.append((af, addr))
else:
secondary.append((af, addr))
return primary, secondary
def start(self, timeout=_INITIAL_CONNECT_TIMEOUT):
self.try_connect(iter(self.primary_addrs))
self.set_timout(timeout)
return self.future
def try_connect(self, addrs):
try:
af, addr = next(addrs)
except StopIteration:
# We've reached the end of our queue, but the other queue
# might still be working. Send a final error on the future
# only when both queues are finished.
if self.remaining == 0 and not self.future.done():
self.future.set_exception(self.last_error or
IOError("connection failed"))
return
future = self.connect(af, addr)
future.add_done_callback(functools.partial(self.on_connect_done,
addrs, af, addr))
def on_connect_done(self, addrs, af, addr, future):
self.remaining -= 1
try:
stream = future.result()
except Exception as e:
if self.future.done():
return
# Error: try again (but remember what happened so we have an
# error to raise in the end)
self.last_error = e
self.try_connect(addrs)
if self.timeout is not None:
# If the first attempt failed, don't wait for the
# timeout to try an address from the secondary queue.
self.on_timeout()
return
self.clear_timeout()
if self.future.done():
# This is a late arrival; just drop it.
stream.close()
else:
self.future.set_result((af, addr, stream))
def set_timout(self, timeout):
self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout,
self.on_timeout)
def on_timeout(self):
self.timeout = None
self.try_connect(iter(self.secondary_addrs))
def clear_timeout(self):
if self.timeout is not None:
self.io_loop.remove_timeout(self.timeout)
class TCPClient(object):
"""A non-blocking TCP connection factory.
"""
def __init__(self, resolver=None, io_loop=None):
self.io_loop = io_loop or IOLoop.current()
if resolver is not None:
self.resolver = resolver
self._own_resolver = False
else:
self.resolver = Resolver(io_loop=io_loop)
self._own_resolver = True
def close(self):
if self._own_resolver:
self.resolver.close()
@gen.coroutine
def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
max_buffer_size=None):
"""Connect to the given host and port.
Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
``ssl_options`` is not None).
"""
addrinfo = yield self.resolver.resolve(host, port, af)
connector = _Connector(
addrinfo, self.io_loop,
functools.partial(self._create_stream, max_buffer_size))
af, addr, stream = yield connector.start()
# TODO: For better performance we could cache the (af, addr)
# information here and re-use it on sbusequent connections to
# the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
if ssl_options is not None:
stream = yield stream.start_tls(False, ssl_options=ssl_options,
server_hostname=host)
raise gen.Return(stream)
def _create_stream(self, max_buffer_size, af, addr):
# Always connect in plaintext; we'll convert to ssl if necessary
# after one connection has completed.
stream = IOStream(socket.socket(af),
io_loop=self.io_loop,
max_buffer_size=max_buffer_size)
return stream.connect(addr)

View file

@ -20,13 +20,19 @@ from __future__ import absolute_import, division, print_function, with_statement
import errno
import os
import socket
import ssl
from tornado.log import app_log
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream, SSLIOStream
from tornado.netutil import bind_sockets, add_accept_handler, ssl_wrap_socket
from tornado import process
from tornado.util import errno_from_exception
try:
import ssl
except ImportError:
# ssl is not available on Google App Engine.
ssl = None
class TCPServer(object):
@ -81,13 +87,15 @@ class TCPServer(object):
.. versionadded:: 3.1
The ``max_buffer_size`` argument.
"""
def __init__(self, io_loop=None, ssl_options=None, max_buffer_size=None):
def __init__(self, io_loop=None, ssl_options=None, max_buffer_size=None,
read_chunk_size=None):
self.io_loop = io_loop
self.ssl_options = ssl_options
self._sockets = {} # fd -> socket object
self._pending_sockets = []
self._started = False
self.max_buffer_size = max_buffer_size
self.read_chunk_size = None
# Verify the SSL options. Otherwise we don't get errors until clients
# connect. This doesn't verify that the keys are legitimate, but
@ -231,15 +239,19 @@ class TCPServer(object):
# SSLIOStream._do_ssl_handshake).
# To test this behavior, try nmap with the -sT flag.
# https://github.com/tornadoweb/tornado/pull/750
if err.args[0] in (errno.ECONNABORTED, errno.EINVAL):
if errno_from_exception(err) in (errno.ECONNABORTED, errno.EINVAL):
return connection.close()
else:
raise
try:
if self.ssl_options is not None:
stream = SSLIOStream(connection, io_loop=self.io_loop, max_buffer_size=self.max_buffer_size)
stream = SSLIOStream(connection, io_loop=self.io_loop,
max_buffer_size=self.max_buffer_size,
read_chunk_size=self.read_chunk_size)
else:
stream = IOStream(connection, io_loop=self.io_loop, max_buffer_size=self.max_buffer_size)
stream = IOStream(connection, io_loop=self.io_loop,
max_buffer_size=self.max_buffer_size,
read_chunk_size=self.read_chunk_size)
self.handle_stream(stream, address)
except Exception:
app_log.error("Error in connection callback", exc_info=True)

View file

@ -180,7 +180,7 @@ with ``{# ... #}``.
``{% set *x* = *y* %}``
Sets a local variable.
``{% try %}...{% except %}...{% finally %}...{% else %}...{% end %}``
``{% try %}...{% except %}...{% else %}...{% finally %}...{% end %}``
Same as the python ``try`` statement.
``{% while *condition* %}... {% end %}``
@ -367,10 +367,9 @@ class Loader(BaseLoader):
def _create_template(self, name):
path = os.path.join(self.root, name)
f = open(path, "rb")
template = Template(f.read(), name=name, loader=self)
f.close()
return template
with open(path, "rb") as f:
template = Template(f.read(), name=name, loader=self)
return template
class DictLoader(BaseLoader):
@ -785,7 +784,7 @@ def _parse(reader, template, in_block=None, in_loop=None):
if allowed_parents is not None:
if not in_block:
raise ParseError("%s outside %s block" %
(operator, allowed_parents))
(operator, allowed_parents))
if in_block not in allowed_parents:
raise ParseError("%s block cannot be attached to %s block" % (operator, in_block))
body.chunks.append(_IntermediateControlBlock(contents, line))

14
tornado/test/__main__.py Normal file
View file

@ -0,0 +1,14 @@
"""Shim to allow python -m tornado.test.
This only works in python 2.7+.
"""
from __future__ import absolute_import, division, print_function, with_statement
from tornado.test.runtests import all, main
# tornado.testing.main autodiscovery relies on 'all' being present in
# the main module, so import it here even though it is not used directly.
# The following line prevents a pyflakes warning.
all = all
main()

View file

@ -67,11 +67,29 @@ class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
self.finish(user)
def _oauth_get_user(self, access_token, callback):
if self.get_argument('fail_in_get_user', None):
raise Exception("failing in get_user")
if access_token != dict(key='uiop', secret='5678'):
raise Exception("incorrect access token %r" % access_token)
callback(dict(email='foo@example.com'))
class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler):
"""Replaces OAuth1ClientLoginCoroutineHandler's get() with a coroutine."""
@gen.coroutine
def get(self):
if self.get_argument('oauth_token', None):
# Ensure that any exceptions are set on the returned Future,
# not simply thrown into the surrounding StackContext.
try:
yield self.get_authenticated_user()
except Exception as e:
self.set_status(503)
self.write("got exception: %s" % e)
else:
yield self.authorize_redirect()
class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin):
def initialize(self, version):
self._OAUTH_VERSION = version
@ -255,6 +273,9 @@ class AuthTest(AsyncHTTPTestCase):
dict(version='1.0')),
('/oauth10a/client/login', OAuth1ClientLoginHandler,
dict(test=self, version='1.0a')),
('/oauth10a/client/login_coroutine',
OAuth1ClientLoginCoroutineHandler,
dict(test=self, version='1.0a')),
('/oauth10a/client/request_params',
OAuth1ClientRequestParametersHandler,
dict(version='1.0a')),
@ -348,6 +369,12 @@ class AuthTest(AsyncHTTPTestCase):
self.assertTrue('oauth_nonce' in parsed)
self.assertTrue('oauth_signature' in parsed)
def test_oauth10a_get_user_coroutine_exception(self):
response = self.fetch(
'/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
self.assertEqual(response.code, 503)
def test_oauth2_redirect(self):
response = self.fetch('/oauth2/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)

View file

@ -28,7 +28,6 @@ from tornado.iostream import IOStream
from tornado import stack_context
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, LogTrapTestCase, bind_unused_port, gen_test
from tornado.test.util import unittest
try:
@ -113,13 +112,6 @@ class ReturnFutureTest(AsyncTestCase):
self.assertIs(future, future2)
self.assertEqual(future.result(), 42)
@unittest.skipIf(futures is None, "futures module not present")
def test_timeout_future(self):
with self.assertRaises(futures.TimeoutError):
future = self.async_future()
# Do not call self.wait()
future.result(timeout=.1)
@gen_test
def test_async_future_gen(self):
result = yield self.async_future()

View file

@ -68,6 +68,16 @@ class DigestAuthHandler(RequestHandler):
(realm, nonce, opaque))
class CustomReasonHandler(RequestHandler):
def get(self):
self.set_status(200, "Custom reason")
class CustomFailReasonHandler(RequestHandler):
def get(self):
self.set_status(400, "Custom reason")
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientTestCase(AsyncHTTPTestCase):
def setUp(self):
@ -78,6 +88,8 @@ class CurlHTTPClientTestCase(AsyncHTTPTestCase):
def get_app(self):
return Application([
('/digest', DigestAuthHandler),
('/custom_reason', CustomReasonHandler),
('/custom_fail_reason', CustomFailReasonHandler),
])
def test_prepare_curl_callback_stack_context(self):
@ -100,3 +112,11 @@ class CurlHTTPClientTestCase(AsyncHTTPTestCase):
response = self.fetch('/digest', auth_mode='digest',
auth_username='foo', auth_password='bar')
self.assertEqual(response.body, b'ok')
def test_custom_reason(self):
response = self.fetch('/custom_reason')
self.assertEqual(response.reason, "Custom reason")
def test_fail_custom_reason(self):
response = self.fetch('/custom_fail_reason')
self.assertEqual(str(response.error), "HTTP 400: Custom reason")

View file

@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function, with_statement
import contextlib
import datetime
import functools
import sys
import textwrap
@ -8,7 +9,7 @@ import time
import platform
import weakref
from tornado.concurrent import return_future
from tornado.concurrent import return_future, Future
from tornado.escape import url_escape
from tornado.httpclient import AsyncHTTPClient
from tornado.ioloop import IOLoop
@ -20,6 +21,10 @@ from tornado.web import Application, RequestHandler, asynchronous, HTTPError
from tornado import gen
try:
from concurrent import futures
except ImportError:
futures = None
skipBefore33 = unittest.skipIf(sys.version_info < (3, 3), 'PEP 380 not available')
skipNotCPython = unittest.skipIf(platform.python_implementation() != 'CPython',
@ -291,26 +296,53 @@ class GenEngineTest(AsyncTestCase):
self.stop()
self.run_gen(f)
def test_multi_delayed(self):
# The following tests explicitly run with both gen.Multi
# and gen.multi_future (Task returns a Future, so it can be used
# with either).
def test_multi_yieldpoint_delayed(self):
@gen.engine
def f():
# callbacks run at different times
responses = yield [
responses = yield gen.Multi([
gen.Task(self.delay_callback, 3, arg="v1"),
gen.Task(self.delay_callback, 1, arg="v2"),
]
])
self.assertEqual(responses, ["v1", "v2"])
self.stop()
self.run_gen(f)
def test_multi_dict_delayed(self):
def test_multi_yieldpoint_dict_delayed(self):
@gen.engine
def f():
# callbacks run at different times
responses = yield dict(
responses = yield gen.Multi(dict(
foo=gen.Task(self.delay_callback, 3, arg="v1"),
bar=gen.Task(self.delay_callback, 1, arg="v2"),
)
))
self.assertEqual(responses, dict(foo="v1", bar="v2"))
self.stop()
self.run_gen(f)
def test_multi_future_delayed(self):
@gen.engine
def f():
# callbacks run at different times
responses = yield gen.multi_future([
gen.Task(self.delay_callback, 3, arg="v1"),
gen.Task(self.delay_callback, 1, arg="v2"),
])
self.assertEqual(responses, ["v1", "v2"])
self.stop()
self.run_gen(f)
def test_multi_future_dict_delayed(self):
@gen.engine
def f():
# callbacks run at different times
responses = yield gen.multi_future(dict(
foo=gen.Task(self.delay_callback, 3, arg="v1"),
bar=gen.Task(self.delay_callback, 1, arg="v2"),
))
self.assertEqual(responses, dict(foo="v1", bar="v2"))
self.stop()
self.run_gen(f)
@ -334,6 +366,15 @@ class GenEngineTest(AsyncTestCase):
y = yield {}
self.assertTrue(isinstance(y, dict))
@gen_test
def test_multi_mixed_types(self):
# A YieldPoint (Wait) and Future (Task) can be combined
# (and use the YieldPoint codepath)
(yield gen.Callback("k1"))("v1")
responses = yield [gen.Wait("k1"),
gen.Task(self.delay_callback, 3, arg="v2")]
self.assertEqual(responses, ["v1", "v2"])
@gen_test
def test_future(self):
result = yield self.async_future(1)
@ -733,8 +774,14 @@ class GenCoroutineTest(AsyncTestCase):
def test_replace_context_exception(self):
# Test exception handling: exceptions thrown into the stack context
# can be caught and replaced.
# Note that this test and the following are for behavior that is
# not really supported any more: coroutines no longer create a
# stack context automatically; but one is created after the first
# YieldPoint (i.e. not a Future).
@gen.coroutine
def f2():
(yield gen.Callback(1))()
yield gen.Wait(1)
self.io_loop.add_callback(lambda: 1 / 0)
try:
yield gen.Task(self.io_loop.add_timeout,
@ -753,6 +800,8 @@ class GenCoroutineTest(AsyncTestCase):
# can be caught and ignored.
@gen.coroutine
def f2():
(yield gen.Callback(1))()
yield gen.Wait(1)
self.io_loop.add_callback(lambda: 1 / 0)
try:
yield gen.Task(self.io_loop.add_timeout,
@ -764,6 +813,31 @@ class GenCoroutineTest(AsyncTestCase):
self.assertEqual(result, 42)
self.finished = True
@gen_test
def test_moment(self):
calls = []
@gen.coroutine
def f(name, yieldable):
for i in range(5):
calls.append(name)
yield yieldable
# First, confirm the behavior without moment: each coroutine
# monopolizes the event loop until it finishes.
immediate = Future()
immediate.set_result(None)
yield [f('a', immediate), f('b', immediate)]
self.assertEqual(''.join(calls), 'aaaaabbbbb')
# With moment, they take turns.
calls = []
yield [f('a', gen.moment), f('b', gen.moment)]
self.assertEqual(''.join(calls), 'ababababab')
self.finished = True
calls = []
yield [f('a', gen.moment), f('b', immediate)]
self.assertEqual(''.join(calls), 'abbbbbaaaa')
class GenSequenceHandler(RequestHandler):
@asynchronous
@ -943,5 +1017,55 @@ class GenWebTest(AsyncHTTPTestCase):
response = self.fetch('/async_prepare_error')
self.assertEqual(response.code, 403)
class WithTimeoutTest(AsyncTestCase):
@gen_test
def test_timeout(self):
with self.assertRaises(gen.TimeoutError):
yield gen.with_timeout(datetime.timedelta(seconds=0.1),
Future())
@gen_test
def test_completes_before_timeout(self):
future = Future()
self.io_loop.add_timeout(datetime.timedelta(seconds=0.1),
lambda: future.set_result('asdf'))
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
future)
self.assertEqual(result, 'asdf')
@gen_test
def test_fails_before_timeout(self):
future = Future()
self.io_loop.add_timeout(
datetime.timedelta(seconds=0.1),
lambda: future.set_exception(ZeroDivisionError))
with self.assertRaises(ZeroDivisionError):
yield gen.with_timeout(datetime.timedelta(seconds=3600), future)
@gen_test
def test_already_resolved(self):
future = Future()
future.set_result('asdf')
result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
future)
self.assertEqual(result, 'asdf')
@unittest.skipIf(futures is None, 'futures module not present')
@gen_test
def test_timeout_concurrent_future(self):
with futures.ThreadPoolExecutor(1) as executor:
with self.assertRaises(gen.TimeoutError):
yield gen.with_timeout(self.io_loop.time(),
executor.submit(time.sleep, 0.1))
@unittest.skipIf(futures is None, 'futures module not present')
@gen_test
def test_completed_concurrent_future(self):
with futures.ThreadPoolExecutor(1) as executor:
yield gen.with_timeout(datetime.timedelta(seconds=3600),
executor.submit(lambda: None))
if __name__ == '__main__':
unittest.main()

View file

@ -0,0 +1,11 @@
# Dummy source file to allow creation of the initial .po file in the
# same way as a real project. I'm not entirely sure about the real
# workflow here, but this seems to work.
#
# 1) xgettext --language=Python --keyword=_:1,2 -d tornado_test extract_me.py -o tornado_test.po
# 2) Edit tornado_test.po, setting CHARSET and setting msgstr
# 3) msgfmt tornado_test.po -o tornado_test.mo
# 4) Put the file in the proper location: $LANG/LC_MESSAGES
from __future__ import absolute_import, division, print_function, with_statement
_("school")

View file

@ -8,7 +8,6 @@ from contextlib import closing
import functools
import sys
import threading
import time
from tornado.escape import utf8
from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient
@ -19,7 +18,7 @@ from tornado.log import gen_log
from tornado import netutil
from tornado.stack_context import ExceptionStackContext, NullContext
from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog
from tornado.test.util import unittest
from tornado.test.util import unittest, skipOnTravis
from tornado.util import u, bytes_type
from tornado.web import Application, RequestHandler, url
@ -111,6 +110,7 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
url("/all_methods", AllMethodsHandler),
], gzip=True)
@skipOnTravis
def test_hello_world(self):
response = self.fetch("/hello")
self.assertEqual(response.code, 200)
@ -356,11 +356,10 @@ Transfer-Encoding: chunked
@gen_test
def test_future_http_error(self):
try:
with self.assertRaises(HTTPError) as context:
yield self.http_client.fetch(self.get_url('/notfound'))
except HTTPError as e:
self.assertEqual(e.code, 404)
self.assertEqual(e.response.code, 404)
self.assertEqual(context.exception.code, 404)
self.assertEqual(context.exception.response.code, 404)
@gen_test
def test_reuse_request_from_response(self):

View file

@ -2,20 +2,23 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado import httpclient, simple_httpclient, netutil
from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str
from tornado import netutil
from tornado.escape import json_decode, json_encode, utf8, _unicode, recursive_unicode, native_str
from tornado import gen
from tornado.http1connection import HTTP1Connection
from tornado.httpserver import HTTPServer
from tornado.httputil import HTTPHeaders
from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine
from tornado.iostream import IOStream
from tornado.log import gen_log
from tornado.netutil import ssl_options_to_context, Resolver
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_options_to_context
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
from tornado.test.util import unittest
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test
from tornado.test.util import unittest, skipOnTravis
from tornado.util import u, bytes_type
from tornado.web import Application, RequestHandler, asynchronous
from tornado.web import Application, RequestHandler, asynchronous, stream_request_body
from contextlib import closing
import datetime
import gzip
import os
import shutil
import socket
@ -23,6 +26,28 @@ import ssl
import sys
import tempfile
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
def read_stream_body(stream, callback):
"""Reads an HTTP response from `stream` and runs callback with its
headers and body."""
chunks = []
class Delegate(HTTPMessageDelegate):
def headers_received(self, start_line, headers):
self.headers = headers
def data_received(self, chunk):
chunks.append(chunk)
def finish(self):
callback((self.headers, b''.join(chunks)))
conn = HTTP1Connection(stream, True)
conn.read_response(Delegate())
class HandlerBaseTestCase(AsyncHTTPTestCase):
def get_app(self):
@ -86,11 +111,13 @@ class SSLTestMixin(object):
# connection, rather than waiting for a timeout or otherwise
# misbehaving.
with ExpectLog(gen_log, '(SSL Error|uncaught exception)'):
self.http_client.fetch(self.get_url("/").replace('https:', 'http:'),
self.stop,
request_timeout=3600,
connect_timeout=3600)
response = self.wait()
with ExpectLog(gen_log, 'Uncaught exception', required=False):
self.http_client.fetch(
self.get_url("/").replace('https:', 'http:'),
self.stop,
request_timeout=3600,
connect_timeout=3600)
response = self.wait()
self.assertEqual(response.code, 599)
# Python's SSL implementation differs significantly between versions.
@ -163,18 +190,7 @@ class MultipartTestHandler(RequestHandler):
})
class RawRequestHTTPConnection(simple_httpclient._HTTPConnection):
def set_request(self, request):
self.__next_request = request
def _on_connect(self):
self.stream.write(self.__next_request)
self.__next_request = None
self.stream.read_until(b"\r\n\r\n", self._on_headers)
# This test is also called from wsgi_test
class HTTPConnectionTest(AsyncHTTPTestCase):
def get_handlers(self):
return [("/multipart", MultipartTestHandler),
@ -184,23 +200,16 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
return Application(self.get_handlers())
def raw_fetch(self, headers, body):
with closing(Resolver(io_loop=self.io_loop)) as resolver:
with closing(SimpleAsyncHTTPClient(self.io_loop,
resolver=resolver)) as client:
conn = RawRequestHTTPConnection(
self.io_loop, client,
httpclient._RequestProxy(
httpclient.HTTPRequest(self.get_url("/")),
dict(httpclient.HTTPRequest._DEFAULTS)),
None, self.stop,
1024 * 1024, resolver)
conn.set_request(
b"\r\n".join(headers +
[utf8("Content-Length: %d\r\n" % len(body))]) +
b"\r\n" + body)
response = self.wait()
response.rethrow()
return response
with closing(IOStream(socket.socket())) as stream:
stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
self.wait()
stream.write(
b"\r\n".join(headers +
[utf8("Content-Length: %d\r\n" % len(body))]) +
b"\r\n" + body)
read_stream_body(stream, self.stop)
headers, body = self.wait()
return body
def test_multipart_form(self):
# Encodings here are tricky: Headers are latin1, bodies can be
@ -221,7 +230,7 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
b"--1234567890--",
b"",
]))
data = json_decode(response.body)
data = json_decode(response)
self.assertEqual(u("\u00e9"), data["header"])
self.assertEqual(u("\u00e1"), data["argument"])
self.assertEqual(u("\u00f3"), data["filename"])
@ -397,6 +406,25 @@ class HTTPServerRawTest(AsyncHTTPTestCase):
self.stop)
self.wait()
def test_chunked_request_body(self):
# Chunked requests are not widely supported and we don't have a way
# to generate them in AsyncHTTPClient, but HTTPServer will read them.
self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: chunked
Content-Type: application/x-www-form-urlencoded
4
foo=
3
bar
0
""".replace(b"\n", b"\r\n"))
read_stream_body(self.stream, self.stop)
headers, response = self.wait()
self.assertEqual(json_decode(response), {u('foo'): [u('bar')]})
class XHeaderTest(HandlerBaseTestCase):
class Handler(RequestHandler):
@ -541,7 +569,7 @@ class UnixSocketTest(AsyncTestCase):
def test_unix_socket_bad_request(self):
# Unix sockets don't have remote addresses so they just return an
# empty string.
with ExpectLog(gen_log, "Malformed HTTP request from"):
with ExpectLog(gen_log, "Malformed HTTP message from"):
self.stream.write(b"garbage\r\n\r\n")
self.stream.read_until_close(self.stop)
response = self.wait()
@ -610,8 +638,8 @@ class KeepAliveTest(AsyncHTTPTestCase):
return headers
def read_response(self):
headers = self.read_headers()
self.stream.read_bytes(int(headers['Content-Length']), self.stop)
self.headers = self.read_headers()
self.stream.read_bytes(int(self.headers['Content-Length']), self.stop)
body = self.wait()
self.assertEqual(b'Hello world', body)
@ -645,6 +673,7 @@ class KeepAliveTest(AsyncHTTPTestCase):
self.stream.read_until_close(callback=self.stop)
data = self.wait()
self.assertTrue(not data)
self.assertTrue('Connection' not in self.headers)
self.close()
def test_http10_keepalive(self):
@ -652,8 +681,10 @@ class KeepAliveTest(AsyncHTTPTestCase):
self.connect()
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.close()
def test_pipelined_requests(self):
@ -683,3 +714,322 @@ class KeepAliveTest(AsyncHTTPTestCase):
self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
self.read_headers()
self.close()
class GzipBaseTest(object):
def get_app(self):
return Application([('/', EchoHandler)])
def post_gzip(self, body):
bytesio = BytesIO()
gzip_file = gzip.GzipFile(mode='w', fileobj=bytesio)
gzip_file.write(utf8(body))
gzip_file.close()
compressed_body = bytesio.getvalue()
return self.fetch('/', method='POST', body=compressed_body,
headers={'Content-Encoding': 'gzip'})
def test_uncompressed(self):
response = self.fetch('/', method='POST', body='foo=bar')
self.assertEquals(json_decode(response.body), {u('foo'): [u('bar')]})
class GzipTest(GzipBaseTest, AsyncHTTPTestCase):
def get_httpserver_options(self):
return dict(gzip=True)
def test_gzip(self):
response = self.post_gzip('foo=bar')
self.assertEquals(json_decode(response.body), {u('foo'): [u('bar')]})
class GzipUnsupportedTest(GzipBaseTest, AsyncHTTPTestCase):
def test_gzip_unsupported(self):
# Gzip support is opt-in; without it the server fails to parse
# the body (but parsing form bodies is currently just a log message,
# not a fatal error).
with ExpectLog(gen_log, "Unsupported Content-Encoding"):
response = self.post_gzip('foo=bar')
self.assertEquals(json_decode(response.body), {})
class StreamingChunkSizeTest(AsyncHTTPTestCase):
# 50 characters long, and repetitive so it can be compressed.
BODY = b'01234567890123456789012345678901234567890123456789'
CHUNK_SIZE = 16
def get_http_client(self):
# body_producer doesn't work on curl_httpclient, so override the
# configured AsyncHTTPClient implementation.
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
def get_httpserver_options(self):
return dict(chunk_size=self.CHUNK_SIZE, gzip=True)
class MessageDelegate(HTTPMessageDelegate):
def __init__(self, connection):
self.connection = connection
def headers_received(self, start_line, headers):
self.chunk_lengths = []
def data_received(self, chunk):
self.chunk_lengths.append(len(chunk))
def finish(self):
response_body = utf8(json_encode(self.chunk_lengths))
self.connection.write_headers(
ResponseStartLine('HTTP/1.1', 200, 'OK'),
HTTPHeaders({'Content-Length': str(len(response_body))}))
self.connection.write(response_body)
self.connection.finish()
def get_app(self):
class App(HTTPServerConnectionDelegate):
def start_request(self, connection):
return StreamingChunkSizeTest.MessageDelegate(connection)
return App()
def fetch_chunk_sizes(self, **kwargs):
response = self.fetch('/', method='POST', **kwargs)
response.rethrow()
chunks = json_decode(response.body)
self.assertEqual(len(self.BODY), sum(chunks))
for chunk_size in chunks:
self.assertLessEqual(chunk_size, self.CHUNK_SIZE,
'oversized chunk: ' + str(chunks))
self.assertGreater(chunk_size, 0,
'empty chunk: ' + str(chunks))
return chunks
def compress(self, body):
bytesio = BytesIO()
gzfile = gzip.GzipFile(mode='w', fileobj=bytesio)
gzfile.write(body)
gzfile.close()
compressed = bytesio.getvalue()
if len(compressed) >= len(body):
raise Exception("body did not shrink when compressed")
return compressed
def test_regular_body(self):
chunks = self.fetch_chunk_sizes(body=self.BODY)
# Without compression we know exactly what to expect.
self.assertEqual([16, 16, 16, 2], chunks)
def test_compressed_body(self):
self.fetch_chunk_sizes(body=self.compress(self.BODY),
headers={'Content-Encoding': 'gzip'})
# Compression creates irregular boundaries so the assertions
# in fetch_chunk_sizes are as specific as we can get.
def test_chunked_body(self):
def body_producer(write):
write(self.BODY[:20])
write(self.BODY[20:])
chunks = self.fetch_chunk_sizes(body_producer=body_producer)
# HTTP chunk boundaries translate to application-visible breaks
self.assertEqual([16, 4, 16, 14], chunks)
def test_chunked_compressed(self):
compressed = self.compress(self.BODY)
self.assertGreater(len(compressed), 20)
def body_producer(write):
write(compressed[:20])
write(compressed[20:])
self.fetch_chunk_sizes(body_producer=body_producer,
headers={'Content-Encoding': 'gzip'})
class MaxHeaderSizeTest(AsyncHTTPTestCase):
def get_app(self):
return Application([('/', HelloWorldRequestHandler)])
def get_httpserver_options(self):
return dict(max_header_size=1024)
def test_small_headers(self):
response = self.fetch("/", headers={'X-Filler': 'a' * 100})
response.rethrow()
self.assertEqual(response.body, b"Hello world")
def test_large_headers(self):
with ExpectLog(gen_log, "Unsatisfiable read"):
response = self.fetch("/", headers={'X-Filler': 'a' * 1000})
self.assertEqual(response.code, 599)
@skipOnTravis
class IdleTimeoutTest(AsyncHTTPTestCase):
def get_app(self):
return Application([('/', HelloWorldRequestHandler)])
def get_httpserver_options(self):
return dict(idle_connection_timeout=0.1)
def setUp(self):
super(IdleTimeoutTest, self).setUp()
self.streams = []
def tearDown(self):
super(IdleTimeoutTest, self).tearDown()
for stream in self.streams:
stream.close()
def connect(self):
stream = IOStream(socket.socket())
stream.connect(('localhost', self.get_http_port()), self.stop)
self.wait()
self.streams.append(stream)
return stream
def test_unused_connection(self):
stream = self.connect()
stream.set_close_callback(self.stop)
self.wait()
def test_idle_after_use(self):
stream = self.connect()
stream.set_close_callback(lambda: self.stop("closed"))
# Use the connection twice to make sure keep-alives are working
for i in range(2):
stream.write(b"GET / HTTP/1.1\r\n\r\n")
stream.read_until(b"\r\n\r\n", self.stop)
self.wait()
stream.read_bytes(11, self.stop)
data = self.wait()
self.assertEqual(data, b"Hello world")
# Now let the timeout trigger and close the connection.
data = self.wait()
self.assertEqual(data, "closed")
class BodyLimitsTest(AsyncHTTPTestCase):
def get_app(self):
class BufferedHandler(RequestHandler):
def put(self):
self.write(str(len(self.request.body)))
@stream_request_body
class StreamingHandler(RequestHandler):
def initialize(self):
self.bytes_read = 0
def prepare(self):
if 'expected_size' in self.request.arguments:
self.request.connection.set_max_body_size(
int(self.get_argument('expected_size')))
if 'body_timeout' in self.request.arguments:
self.request.connection.set_body_timeout(
float(self.get_argument('body_timeout')))
def data_received(self, data):
self.bytes_read += len(data)
def put(self):
self.write(str(self.bytes_read))
return Application([('/buffered', BufferedHandler),
('/streaming', StreamingHandler)])
def get_httpserver_options(self):
return dict(body_timeout=3600, max_body_size=4096)
def get_http_client(self):
# body_producer doesn't work on curl_httpclient, so override the
# configured AsyncHTTPClient implementation.
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
def test_small_body(self):
response = self.fetch('/buffered', method='PUT', body=b'a' * 4096)
self.assertEqual(response.body, b'4096')
response = self.fetch('/streaming', method='PUT', body=b'a' * 4096)
self.assertEqual(response.body, b'4096')
def test_large_body_buffered(self):
with ExpectLog(gen_log, '.*Content-Length too long'):
response = self.fetch('/buffered', method='PUT', body=b'a' * 10240)
self.assertEqual(response.code, 599)
def test_large_body_buffered_chunked(self):
with ExpectLog(gen_log, '.*chunked body too large'):
response = self.fetch('/buffered', method='PUT',
body_producer=lambda write: write(b'a' * 10240))
self.assertEqual(response.code, 599)
def test_large_body_streaming(self):
with ExpectLog(gen_log, '.*Content-Length too long'):
response = self.fetch('/streaming', method='PUT', body=b'a' * 10240)
self.assertEqual(response.code, 599)
def test_large_body_streaming_chunked(self):
with ExpectLog(gen_log, '.*chunked body too large'):
response = self.fetch('/streaming', method='PUT',
body_producer=lambda write: write(b'a' * 10240))
self.assertEqual(response.code, 599)
def test_large_body_streaming_override(self):
response = self.fetch('/streaming?expected_size=10240', method='PUT',
body=b'a' * 10240)
self.assertEqual(response.body, b'10240')
def test_large_body_streaming_chunked_override(self):
response = self.fetch('/streaming?expected_size=10240', method='PUT',
body_producer=lambda write: write(b'a' * 10240))
self.assertEqual(response.body, b'10240')
@gen_test
def test_timeout(self):
stream = IOStream(socket.socket())
try:
yield stream.connect(('127.0.0.1', self.get_http_port()))
# Use a raw stream because AsyncHTTPClient won't let us read a
# response without finishing a body.
stream.write(b'PUT /streaming?body_timeout=0.1 HTTP/1.0\r\n'
b'Content-Length: 42\r\n\r\n')
with ExpectLog(gen_log, 'Timeout reading body'):
response = yield stream.read_until_close()
self.assertEqual(response, b'')
finally:
stream.close()
@gen_test
def test_body_size_override_reset(self):
# The max_body_size override is reset between requests.
stream = IOStream(socket.socket())
try:
yield stream.connect(('127.0.0.1', self.get_http_port()))
# Use a raw stream so we can make sure it's all on one connection.
stream.write(b'PUT /streaming?expected_size=10240 HTTP/1.1\r\n'
b'Content-Length: 10240\r\n\r\n')
stream.write(b'a' * 10240)
headers, response = yield gen.Task(read_stream_body, stream)
self.assertEqual(response, b'10240')
# Without the ?expected_size parameter, we get the old default value
stream.write(b'PUT /streaming HTTP/1.1\r\n'
b'Content-Length: 10240\r\n\r\n')
with ExpectLog(gen_log, '.*Content-Length too long'):
data = yield stream.read_until_close()
self.assertEqual(data, b'')
finally:
stream.close()
class LegacyInterfaceTest(AsyncHTTPTestCase):
def get_app(self):
# The old request_callback interface does not implement the
# delegate interface, and writes its response via request.write
# instead of request.connection.write_headers.
def handle_request(request):
message = b"Hello world"
request.write(utf8("HTTP/1.1 200 OK\r\n"
"Content-Length: %d\r\n\r\n" % len(message)))
request.write(message)
request.finish()
return handle_request
def test_legacy_interface(self):
response = self.fetch('/')
self.assertEqual(response.body, b"Hello world")

View file

@ -13,6 +13,7 @@ class ImportTest(unittest.TestCase):
# import tornado.curl_httpclient # depends on pycurl
import tornado.escape
import tornado.gen
import tornado.http1connection
import tornado.httpclient
import tornado.httpserver
import tornado.httputil

View file

@ -5,16 +5,16 @@ from __future__ import absolute_import, division, print_function, with_statement
import contextlib
import datetime
import functools
import logging
import socket
import sys
import threading
import time
from tornado import gen
from tornado.ioloop import IOLoop, PollIOLoop, TimeoutError
from tornado.ioloop import IOLoop, TimeoutError
from tornado.log import app_log
from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext
from tornado.testing import AsyncTestCase, bind_unused_port
from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog
from tornado.test.util import unittest, skipIfNonUnix, skipOnTravis
try:
@ -52,7 +52,8 @@ class TestIOLoop(AsyncTestCase):
thread = threading.Thread(target=target)
self.io_loop.add_callback(thread.start)
self.wait()
self.assertAlmostEqual(time.time(), self.stop_time, places=2)
delta = time.time() - self.stop_time
self.assertLess(delta, 0.1)
thread.join()
def test_add_timeout_timedelta(self):
@ -172,6 +173,119 @@ class TestIOLoop(AsyncTestCase):
self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop))
self.wait()
def test_close_file_object(self):
"""When a file object is used instead of a numeric file descriptor,
the object should be closed (by IOLoop.close(all_fds=True),
not just the fd.
"""
# Use a socket since they are supported by IOLoop on all platforms.
# Unfortunately, sockets don't support the .closed attribute for
# inspecting their close status, so we must use a wrapper.
class SocketWrapper(object):
def __init__(self, sockobj):
self.sockobj = sockobj
self.closed = False
def fileno(self):
return self.sockobj.fileno()
def close(self):
self.closed = True
self.sockobj.close()
sockobj, port = bind_unused_port()
socket_wrapper = SocketWrapper(sockobj)
io_loop = IOLoop()
io_loop.add_handler(socket_wrapper, lambda fd, events: None,
IOLoop.READ)
io_loop.close(all_fds=True)
self.assertTrue(socket_wrapper.closed)
def test_handler_callback_file_object(self):
"""The handler callback receives the same fd object it passed in."""
server_sock, port = bind_unused_port()
fds = []
def handle_connection(fd, events):
fds.append(fd)
conn, addr = server_sock.accept()
conn.close()
self.stop()
self.io_loop.add_handler(server_sock, handle_connection, IOLoop.READ)
with contextlib.closing(socket.socket()) as client_sock:
client_sock.connect(('127.0.0.1', port))
self.wait()
self.io_loop.remove_handler(server_sock)
self.io_loop.add_handler(server_sock.fileno(), handle_connection,
IOLoop.READ)
with contextlib.closing(socket.socket()) as client_sock:
client_sock.connect(('127.0.0.1', port))
self.wait()
self.assertIs(fds[0], server_sock)
self.assertEqual(fds[1], server_sock.fileno())
self.io_loop.remove_handler(server_sock.fileno())
server_sock.close()
def test_mixed_fd_fileobj(self):
server_sock, port = bind_unused_port()
def f(fd, events):
pass
self.io_loop.add_handler(server_sock, f, IOLoop.READ)
with self.assertRaises(Exception):
# The exact error is unspecified - some implementations use
# IOError, others use ValueError.
self.io_loop.add_handler(server_sock.fileno(), f, IOLoop.READ)
self.io_loop.remove_handler(server_sock.fileno())
server_sock.close()
def test_reentrant(self):
"""Calling start() twice should raise an error, not deadlock."""
returned_from_start = [False]
got_exception = [False]
def callback():
try:
self.io_loop.start()
returned_from_start[0] = True
except Exception:
got_exception[0] = True
self.stop()
self.io_loop.add_callback(callback)
self.wait()
self.assertTrue(got_exception[0])
self.assertFalse(returned_from_start[0])
def test_exception_logging(self):
"""Uncaught exceptions get logged by the IOLoop."""
# Use a NullContext to keep the exception from being caught by
# AsyncTestCase.
with NullContext():
self.io_loop.add_callback(lambda: 1/0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_exception_logging_future(self):
"""The IOLoop examines exceptions from Futures and logs them."""
with NullContext():
@gen.coroutine
def callback():
self.io_loop.add_callback(self.stop)
1/0
self.io_loop.add_callback(callback)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_spawn_callback(self):
# An added callback runs in the test's stack_context, so will be
# re-arised in wait().
self.io_loop.add_callback(lambda: 1/0)
with self.assertRaises(ZeroDivisionError):
self.wait()
# A spawned callback is run directly on the IOLoop, so it will be
# logged without stopping the test.
self.io_loop.spawn_callback(lambda: 1/0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
# Deliberately not a subclass of AsyncTestCase so the IOLoop isn't
# automatically set as current.

View file

@ -1,13 +1,16 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado.concurrent import Future
from tornado import gen
from tornado import netutil
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream, SSLIOStream, PipeIOStream
from tornado.iostream import IOStream, SSLIOStream, PipeIOStream, StreamClosedError
from tornado.httputil import HTTPHeaders
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket
from tornado.stack_context import NullContext
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
from tornado.test.util import unittest, skipIfNonUnix
from tornado.web import RequestHandler, Application
import certifi
import errno
import logging
import os
@ -17,6 +20,13 @@ import ssl
import sys
def _server_ssl_options():
return dict(
certfile=os.path.join(os.path.dirname(__file__), 'test.crt'),
keyfile=os.path.join(os.path.dirname(__file__), 'test.key'),
)
class HelloHandler(RequestHandler):
def get(self):
self.write("Hello")
@ -106,6 +116,48 @@ class TestIOStreamWebMixin(object):
stream.close()
@gen_test
def test_future_interface(self):
"""Basic test of IOStream's ability to return Futures."""
stream = self._make_client_iostream()
connect_result = yield stream.connect(
("localhost", self.get_http_port()))
self.assertIs(connect_result, stream)
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
first_line = yield stream.read_until(b"\r\n")
self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n")
# callback=None is equivalent to no callback.
header_data = yield stream.read_until(b"\r\n\r\n", callback=None)
headers = HTTPHeaders.parse(header_data.decode('latin1'))
content_length = int(headers['Content-Length'])
body = yield stream.read_bytes(content_length)
self.assertEqual(body, b'Hello')
stream.close()
@gen_test
def test_future_close_while_reading(self):
stream = self._make_client_iostream()
yield stream.connect(("localhost", self.get_http_port()))
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
with self.assertRaises(StreamClosedError):
yield stream.read_bytes(1024 * 1024)
stream.close()
@gen_test
def test_future_read_until_close(self):
# Ensure that the data comes through before the StreamClosedError.
stream = self._make_client_iostream()
yield stream.connect(("localhost", self.get_http_port()))
yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n")
yield stream.read_until(b"\r\n\r\n")
body = yield stream.read_until_close()
self.assertEqual(body, b"Hello")
# Nothing else to read; the error comes immediately without waiting
# for yield.
with self.assertRaises(StreamClosedError):
stream.read_bytes(1)
class TestIOStreamMixin(object):
def _make_server_iostream(self, connection, **kwargs):
@ -158,9 +210,6 @@ class TestIOStreamMixin(object):
server, client = self.make_iostream_pair()
server.write(b'', callback=self.stop)
self.wait()
# As a side effect, the stream is now listening for connection
# close (if it wasn't already), but is not listening for writes
self.assertEqual(server._state, IOLoop.READ | IOLoop.ERROR)
server.close()
client.close()
@ -298,6 +347,25 @@ class TestIOStreamMixin(object):
server.close()
client.close()
def test_future_delayed_close_callback(self):
# Same as test_delayed_close_callback, but with the future interface.
server, client = self.make_iostream_pair()
# We can't call make_iostream_pair inside a gen_test function
# because the ioloop is not reentrant.
@gen_test
def f(self):
server.write(b"12")
chunks = []
chunks.append((yield client.read_bytes(1)))
server.close()
chunks.append((yield client.read_bytes(1)))
self.assertEqual(chunks, [b"1", b"2"])
try:
f(self)
finally:
server.close()
client.close()
def test_close_buffered_data(self):
# Similar to the previous test, but with data stored in the OS's
# socket buffers instead of the IOStream's read buffer. Out-of-band
@ -330,14 +398,18 @@ class TestIOStreamMixin(object):
# Similar to test_delayed_close_callback, but read_until_close takes
# a separate code path so test it separately.
server, client = self.make_iostream_pair()
client.set_close_callback(self.stop)
try:
server.write(b"1234")
server.close()
self.wait()
# Read one byte to make sure the client has received the data.
# It won't run the close callback as long as there is more buffered
# data that could satisfy a later read.
client.read_bytes(1, self.stop)
data = self.wait()
self.assertEqual(data, b"1")
client.read_until_close(self.stop)
data = self.wait()
self.assertEqual(data, b"1234")
self.assertEqual(data, b"234")
finally:
server.close()
client.close()
@ -347,17 +419,18 @@ class TestIOStreamMixin(object):
# All data should go through the streaming callback,
# and the final read callback just gets an empty string.
server, client = self.make_iostream_pair()
client.set_close_callback(self.stop)
try:
server.write(b"1234")
server.close()
self.wait()
client.read_bytes(1, self.stop)
data = self.wait()
self.assertEqual(data, b"1")
streaming_data = []
client.read_until_close(self.stop,
streaming_callback=streaming_data.append)
data = self.wait()
self.assertEqual(b'', data)
self.assertEqual(b''.join(streaming_data), b"1234")
self.assertEqual(b''.join(streaming_data), b"234")
finally:
server.close()
client.close()
@ -451,6 +524,203 @@ class TestIOStreamMixin(object):
server.close()
client.close()
def test_future_close_callback(self):
# Regression test for interaction between the Future read interfaces
# and IOStream._maybe_add_error_listener.
server, client = self.make_iostream_pair()
closed = [False]
def close_callback():
closed[0] = True
self.stop()
server.set_close_callback(close_callback)
try:
client.write(b'a')
future = server.read_bytes(1)
self.io_loop.add_future(future, self.stop)
self.assertEqual(self.wait().result(), b'a')
self.assertFalse(closed[0])
client.close()
self.wait()
self.assertTrue(closed[0])
finally:
server.close()
client.close()
def test_read_bytes_partial(self):
server, client = self.make_iostream_pair()
try:
# Ask for more than is available with partial=True
client.read_bytes(50, self.stop, partial=True)
server.write(b"hello")
data = self.wait()
self.assertEqual(data, b"hello")
# Ask for less than what is available; num_bytes is still
# respected.
client.read_bytes(3, self.stop, partial=True)
server.write(b"world")
data = self.wait()
self.assertEqual(data, b"wor")
# Partial reads won't return an empty string, but read_bytes(0)
# will.
client.read_bytes(0, self.stop, partial=True)
data = self.wait()
self.assertEqual(data, b'')
finally:
server.close()
client.close()
def test_read_until_max_bytes(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Extra room under the limit
client.read_until(b"def", self.stop, max_bytes=50)
server.write(b"abcdef")
data = self.wait()
self.assertEqual(data, b"abcdef")
# Just enough space
client.read_until(b"def", self.stop, max_bytes=6)
server.write(b"abcdef")
data = self.wait()
self.assertEqual(data, b"abcdef")
# Not enough space, but we don't know it until all we can do is
# log a warning and close the connection.
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until(b"def", self.stop, max_bytes=5)
server.write(b"123456")
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_read_until_max_bytes_inline(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Similar to the error case in the previous test, but the
# server writes first so client reads are satisfied
# inline. For consistency with the out-of-line case, we
# do not raise the error synchronously.
server.write(b"123456")
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until(b"def", self.stop, max_bytes=5)
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_read_until_max_bytes_ignores_extra(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Even though data that matches arrives the same packet that
# puts us over the limit, we fail the request because it was not
# found within the limit.
server.write(b"abcdef")
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until(b"def", self.stop, max_bytes=5)
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_read_until_regex_max_bytes(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Extra room under the limit
client.read_until_regex(b"def", self.stop, max_bytes=50)
server.write(b"abcdef")
data = self.wait()
self.assertEqual(data, b"abcdef")
# Just enough space
client.read_until_regex(b"def", self.stop, max_bytes=6)
server.write(b"abcdef")
data = self.wait()
self.assertEqual(data, b"abcdef")
# Not enough space, but we don't know it until all we can do is
# log a warning and close the connection.
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until_regex(b"def", self.stop, max_bytes=5)
server.write(b"123456")
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_read_until_regex_max_bytes_inline(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Similar to the error case in the previous test, but the
# server writes first so client reads are satisfied
# inline. For consistency with the out-of-line case, we
# do not raise the error synchronously.
server.write(b"123456")
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until_regex(b"def", self.stop, max_bytes=5)
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_read_until_regex_max_bytes_ignores_extra(self):
server, client = self.make_iostream_pair()
client.set_close_callback(lambda: self.stop("closed"))
try:
# Even though data that matches arrives the same packet that
# puts us over the limit, we fail the request because it was not
# found within the limit.
server.write(b"abcdef")
with ExpectLog(gen_log, "Unsatisfiable read"):
client.read_until_regex(b"def", self.stop, max_bytes=5)
data = self.wait()
self.assertEqual(data, "closed")
finally:
server.close()
client.close()
def test_small_reads_from_large_buffer(self):
# 10KB buffer size, 100KB available to read.
# Read 1KB at a time and make sure that the buffer is not eagerly
# filled.
server, client = self.make_iostream_pair(max_buffer_size=10 * 1024)
try:
server.write(b"a" * 1024 * 100)
for i in range(100):
client.read_bytes(1024, self.stop)
data = self.wait()
self.assertEqual(data, b"a" * 1024)
finally:
server.close()
client.close()
def test_small_read_untils_from_large_buffer(self):
# 10KB buffer size, 100KB available to read.
# Read 1KB at a time and make sure that the buffer is not eagerly
# filled.
server, client = self.make_iostream_pair(max_buffer_size=10 * 1024)
try:
server.write((b"a" * 1023 + b"\n") * 100)
for i in range(100):
client.read_until(b"\n", self.stop, max_bytes=4096)
data = self.wait()
self.assertEqual(data, b"a" * 1023 + b"\n")
finally:
server.close()
client.close()
class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
def _make_client_iostream(self):
@ -472,14 +742,10 @@ class TestIOStream(TestIOStreamMixin, AsyncTestCase):
class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase):
def _make_server_iostream(self, connection, **kwargs):
ssl_options = dict(
certfile=os.path.join(os.path.dirname(__file__), 'test.crt'),
keyfile=os.path.join(os.path.dirname(__file__), 'test.key'),
)
connection = ssl.wrap_socket(connection,
server_side=True,
do_handshake_on_connect=False,
**ssl_options)
**_server_ssl_options())
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
def _make_client_iostream(self, connection, **kwargs):
@ -507,6 +773,91 @@ class TestIOStreamSSLContext(TestIOStreamMixin, AsyncTestCase):
ssl_options=context, **kwargs)
class TestIOStreamStartTLS(AsyncTestCase):
def setUp(self):
try:
super(TestIOStreamStartTLS, self).setUp()
self.listener, self.port = bind_unused_port()
self.server_stream = None
self.server_accepted = Future()
netutil.add_accept_handler(self.listener, self.accept)
self.client_stream = IOStream(socket.socket())
self.io_loop.add_future(self.client_stream.connect(
('127.0.0.1', self.port)), self.stop)
self.wait()
self.io_loop.add_future(self.server_accepted, self.stop)
self.wait()
except Exception as e:
print(e)
raise
def tearDown(self):
if self.server_stream is not None:
self.server_stream.close()
if self.client_stream is not None:
self.client_stream.close()
self.listener.close()
super(TestIOStreamStartTLS, self).tearDown()
def accept(self, connection, address):
if self.server_stream is not None:
self.fail("should only get one connection")
self.server_stream = IOStream(connection)
self.server_accepted.set_result(None)
@gen.coroutine
def client_send_line(self, line):
self.client_stream.write(line)
recv_line = yield self.server_stream.read_until(b"\r\n")
self.assertEqual(line, recv_line)
@gen.coroutine
def server_send_line(self, line):
self.server_stream.write(line)
recv_line = yield self.client_stream.read_until(b"\r\n")
self.assertEqual(line, recv_line)
def client_start_tls(self, ssl_options=None):
client_stream = self.client_stream
self.client_stream = None
return client_stream.start_tls(False, ssl_options)
def server_start_tls(self, ssl_options=None):
server_stream = self.server_stream
self.server_stream = None
return server_stream.start_tls(True, ssl_options)
@gen_test
def test_start_tls_smtp(self):
# This flow is simplified from RFC 3207 section 5.
# We don't really need all of this, but it helps to make sure
# that after realistic back-and-forth traffic the buffers end up
# in a sane state.
yield self.server_send_line(b"220 mail.example.com ready\r\n")
yield self.client_send_line(b"EHLO mail.example.com\r\n")
yield self.server_send_line(b"250-mail.example.com welcome\r\n")
yield self.server_send_line(b"250 STARTTLS\r\n")
yield self.client_send_line(b"STARTTLS\r\n")
yield self.server_send_line(b"220 Go ahead\r\n")
client_future = self.client_start_tls()
server_future = self.server_start_tls(_server_ssl_options())
self.client_stream = yield client_future
self.server_stream = yield server_future
self.assertTrue(isinstance(self.client_stream, SSLIOStream))
self.assertTrue(isinstance(self.server_stream, SSLIOStream))
yield self.client_send_line(b"EHLO mail.example.com\r\n")
yield self.server_send_line(b"250 mail.example.com welcome\r\n")
@gen_test
def test_handshake_fail(self):
self.server_start_tls(_server_ssl_options())
client_future = self.client_start_tls(
dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where()))
with ExpectLog(gen_log, "SSL Error"):
with self.assertRaises(ssl.SSLError):
yield client_future
@skipIfNonUnix
class TestPipeIOStream(AsyncTestCase):
def test_pipe_iostream(self):

View file

@ -20,6 +20,8 @@ import glob
import logging
import os
import re
import subprocess
import sys
import tempfile
import warnings
@ -156,3 +158,50 @@ class EnablePrettyLoggingTest(unittest.TestCase):
for filename in glob.glob(tmpdir + '/test_log*'):
os.unlink(filename)
os.rmdir(tmpdir)
class LoggingOptionTest(unittest.TestCase):
"""Test the ability to enable and disable Tornado's logging hooks."""
def logs_present(self, statement, args=None):
# Each test may manipulate and/or parse the options and then logs
# a line at the 'info' level. This level is ignored in the
# logging module by default, but Tornado turns it on by default
# so it is the easiest way to tell whether tornado's logging hooks
# ran.
IMPORT = 'from tornado.options import options, parse_command_line'
LOG_INFO = 'import logging; logging.info("hello")'
program = ';'.join([IMPORT, statement, LOG_INFO])
proc = subprocess.Popen(
[sys.executable, '-c', program] + (args or []),
stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
stdout, stderr = proc.communicate()
self.assertEqual(proc.returncode, 0, 'process failed: %r' % stdout)
return b'hello' in stdout
def test_default(self):
self.assertFalse(self.logs_present('pass'))
def test_tornado_default(self):
self.assertTrue(self.logs_present('parse_command_line()'))
def test_disable_command_line(self):
self.assertFalse(self.logs_present('parse_command_line()',
['--logging=none']))
def test_disable_command_line_case_insensitive(self):
self.assertFalse(self.logs_present('parse_command_line()',
['--logging=None']))
def test_disable_code_string(self):
self.assertFalse(self.logs_present(
'options.logging = "none"; parse_command_line()'))
def test_disable_code_none(self):
self.assertFalse(self.logs_present(
'options.logging = None; parse_command_line()'))
def test_disable_override(self):
# command line trumps code defaults
self.assertTrue(self.logs_present(
'options.logging = None; parse_command_line()',
['--logging=info']))

View file

@ -1,15 +1,16 @@
from __future__ import absolute_import, division, print_function, with_statement
import os
import signal
import socket
from subprocess import Popen
import sys
import time
from tornado.netutil import BlockingResolver, ThreadedResolver, is_valid_ip
from tornado.netutil import BlockingResolver, ThreadedResolver, is_valid_ip, bind_sockets
from tornado.stack_context import ExceptionStackContext
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import unittest
from tornado.test.util import unittest, skipIfNoNetwork
try:
from concurrent import futures
@ -25,6 +26,7 @@ else:
try:
import twisted
import twisted.names
except ImportError:
twisted = None
else:
@ -73,12 +75,14 @@ class _ResolverTestMixin(object):
socket.AF_UNSPEC)
@skipIfNoNetwork
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(BlockingResolverTest, self).setUp()
self.resolver = BlockingResolver(io_loop=self.io_loop)
@skipIfNoNetwork
@unittest.skipIf(futures is None, "futures module not present")
class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
@ -90,7 +94,9 @@ class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
super(ThreadedResolverTest, self).tearDown()
@skipIfNoNetwork
@unittest.skipIf(futures is None, "futures module not present")
@unittest.skipIf(sys.platform == 'win32', "preexec_fn not available on win32")
class ThreadedResolverImportTest(unittest.TestCase):
def test_import(self):
TIMEOUT = 5
@ -115,6 +121,7 @@ class ThreadedResolverImportTest(unittest.TestCase):
self.fail("import timed out")
@skipIfNoNetwork
@unittest.skipIf(pycares is None, "pycares module not present")
class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
@ -122,6 +129,7 @@ class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
self.resolver = CaresResolver(io_loop=self.io_loop)
@skipIfNoNetwork
@unittest.skipIf(twisted is None, "twisted module not present")
@unittest.skipIf(getattr(twisted, '__version__', '0.0') < "12.1", "old version of twisted")
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin):
@ -144,3 +152,17 @@ class IsValidIPTest(unittest.TestCase):
self.assertTrue(not is_valid_ip(' '))
self.assertTrue(not is_valid_ip('\n'))
self.assertTrue(not is_valid_ip('\x00'))
class TestPortAllocation(unittest.TestCase):
def test_same_port_allocation(self):
if 'TRAVIS' in os.environ:
self.skipTest("dual-stack servers often have port conflicts on travis")
sockets = bind_sockets(None, 'localhost')
try:
port = sockets[0].getsockname()[1]
self.assertTrue(all(s.getsockname()[1] == port
for s in sockets[1:]))
finally:
for sock in sockets:
sock.close()

View file

@ -40,6 +40,7 @@ TEST_MODULES = [
'tornado.test.process_test',
'tornado.test.simple_httpclient_test',
'tornado.test.stack_context_test',
'tornado.test.tcpclient_test',
'tornado.test.template_test',
'tornado.test.testing_test',
'tornado.test.twisted_test',
@ -65,7 +66,8 @@ class TornadoTextTestRunner(unittest.TextTestRunner):
self.stream.write("\n")
return result
if __name__ == '__main__':
def main():
# The -W command-line option does not work in a virtualenv with
# python 3 (as of virtualenv 1.7), so configure warnings
# programmatically instead.
@ -82,6 +84,9 @@ if __name__ == '__main__':
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("error", category=DeprecationWarning,
module=r"tornado\..*")
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
warnings.filterwarnings("error", category=PendingDeprecationWarning,
module=r"tornado\..*")
# The unittest module is aggressive about deprecating redundant methods,
# leaving some without non-deprecated spellings that work on both
# 2.7 and 3.2
@ -127,3 +132,6 @@ if __name__ == '__main__':
kwargs['warnings'] = False
kwargs['testRunner'] = TornadoTextTestRunner
tornado.testing.main(**kwargs)
if __name__ == '__main__':
main()

View file

@ -10,17 +10,18 @@ import re
import socket
import sys
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httputil import HTTPHeaders
from tornado.ioloop import IOLoop
from tornado.log import gen_log
from tornado.netutil import Resolver
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _DEFAULT_CA_CERTS
from tornado.log import gen_log, app_log
from tornado.netutil import Resolver, bind_sockets
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler
from tornado.test import httpclient_test
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
from tornado.test.util import unittest, skipOnTravis
from tornado.web import RequestHandler, Application, asynchronous, url
from tornado.test.util import skipOnTravis, skipIfNoIPv6
from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body
class SimpleHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
@ -70,7 +71,8 @@ class OptionsHandler(RequestHandler):
class NoContentHandler(RequestHandler):
def get(self):
if self.get_argument("error", None):
self.set_header("Content-Length", "7")
self.set_header("Content-Length", "5")
self.write("hello")
self.set_status(204)
@ -94,6 +96,30 @@ class HostEchoHandler(RequestHandler):
self.write(self.request.headers["Host"])
class NoContentLengthHandler(RequestHandler):
@gen.coroutine
def get(self):
# Emulate the old HTTP/1.0 behavior of returning a body with no
# content-length. Tornado handles content-length at the framework
# level so we have to go around it.
stream = self.request.connection.stream
yield stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
b"hello")
stream.close()
class EchoPostHandler(RequestHandler):
def post(self):
self.write(self.request.body)
@stream_request_body
class RespondInPrepareHandler(RequestHandler):
def prepare(self):
self.set_status(403)
self.finish("forbidden")
class SimpleHTTPClientTestMixin(object):
def get_app(self):
# callable objects to finish pending /trigger requests
@ -112,6 +138,9 @@ class SimpleHTTPClientTestMixin(object):
url("/see_other_post", SeeOtherPostHandler),
url("/see_other_get", SeeOtherGetHandler),
url("/host_echo", HostEchoHandler),
url("/no_content_length", NoContentLengthHandler),
url("/echo_post", EchoPostHandler),
url("/respond_in_prepare", RespondInPrepareHandler),
], gzip=True)
def test_singleton(self):
@ -163,7 +192,7 @@ class SimpleHTTPClientTestMixin(object):
response.rethrow()
def test_default_certificates_exist(self):
open(_DEFAULT_CA_CERTS).close()
open(_default_ca_certs()).close()
def test_gzip(self):
# All the tests in this file should be using gzip, but this test
@ -213,28 +242,30 @@ class SimpleHTTPClientTestMixin(object):
# trigger the hanging request to let it clean up after itself
self.triggers.popleft()()
@unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present')
@skipIfNoIPv6
def test_ipv6(self):
try:
self.http_server.listen(self.get_http_port(), address='::1')
[sock] = bind_sockets(None, '::1', family=socket.AF_INET6)
port = sock.getsockname()[1]
self.http_server.add_socket(sock)
except socket.gaierror as e:
if e.args[0] == socket.EAI_ADDRFAMILY:
# python supports ipv6, but it's not configured on the network
# interface, so skip this test.
return
raise
url = self.get_url("/hello").replace("localhost", "[::1]")
url = '%s://[::1]:%d/hello' % (self.get_protocol(), port)
# ipv6 is currently disabled by default and must be explicitly requested
self.http_client.fetch(url, self.stop)
# ipv6 is currently enabled by default but can be disabled
self.http_client.fetch(url, self.stop, allow_ipv6=False)
response = self.wait()
self.assertEqual(response.code, 599)
self.http_client.fetch(url, self.stop, allow_ipv6=True)
self.http_client.fetch(url, self.stop)
response = self.wait()
self.assertEqual(response.body, b"Hello world!")
def test_multiple_content_length_accepted(self):
def xtest_multiple_content_length_accepted(self):
response = self.fetch("/content_length?value=2,2")
self.assertEqual(response.body, b"ok")
response = self.fetch("/content_length?value=2,%202,2")
@ -266,7 +297,8 @@ class SimpleHTTPClientTestMixin(object):
self.assertEqual(response.headers["Content-length"], "0")
# 204 status with non-zero content length is malformed
response = self.fetch("/no_content?error=1")
with ExpectLog(app_log, "Uncaught exception"):
response = self.fetch("/no_content?error=1")
self.assertEqual(response.code, 599)
def test_host_header(self):
@ -313,6 +345,60 @@ class SimpleHTTPClientTestMixin(object):
self.triggers.popleft()()
self.wait()
def test_no_content_length(self):
response = self.fetch("/no_content_length")
self.assertEquals(b"hello", response.body)
def sync_body_producer(self, write):
write(b'1234')
write(b'5678')
@gen.coroutine
def async_body_producer(self, write):
yield write(b'1234')
yield gen.Task(IOLoop.current().add_callback)
yield write(b'5678')
def test_sync_body_producer_chunked(self):
response = self.fetch("/echo_post", method="POST",
body_producer=self.sync_body_producer)
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_sync_body_producer_content_length(self):
response = self.fetch("/echo_post", method="POST",
body_producer=self.sync_body_producer,
headers={'Content-Length': '8'})
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_async_body_producer_chunked(self):
response = self.fetch("/echo_post", method="POST",
body_producer=self.async_body_producer)
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_async_body_producer_content_length(self):
response = self.fetch("/echo_post", method="POST",
body_producer=self.async_body_producer,
headers={'Content-Length': '8'})
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_100_continue(self):
response = self.fetch("/echo_post", method="POST",
body=b"1234",
expect_100_continue=True)
self.assertEqual(response.body, b"1234")
def test_100_continue_early_response(self):
def body_producer(write):
raise Exception("should not be called")
response = self.fetch("/respond_in_prepare", method="POST",
body_producer=body_producer,
expect_100_continue=True)
self.assertEqual(response.code, 403)
class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
def setUp(self):
@ -433,3 +519,32 @@ class ResolveTimeoutTestCase(AsyncHTTPTestCase):
def test_resolve_timeout(self):
response = self.fetch('/hello', connect_timeout=0.1)
self.assertEqual(response.code, 599)
class MaxHeaderSizeTest(AsyncHTTPTestCase):
def get_app(self):
class SmallHeaders(RequestHandler):
def get(self):
self.set_header("X-Filler", "a" * 100)
self.write("ok")
class LargeHeaders(RequestHandler):
def get(self):
self.set_header("X-Filler", "a" * 1000)
self.write("ok")
return Application([('/small', SmallHeaders),
('/large', LargeHeaders)])
def get_http_client(self):
return SimpleAsyncHTTPClient(io_loop=self.io_loop, max_header_size=1024)
def test_small_headers(self):
response = self.fetch('/small')
response.rethrow()
self.assertEqual(response.body, b'ok')
def test_large_headers(self):
with ExpectLog(gen_log, "Unsatisfiable read"):
response = self.fetch('/large')
self.assertEqual(response.code, 599)

View file

@ -219,22 +219,13 @@ class StackContextTest(AsyncTestCase):
def test_yield_in_with(self):
@gen.engine
def f():
try:
self.callback = yield gen.Callback('a')
with StackContext(functools.partial(self.context, 'c1')):
# This yield is a problem: the generator will be suspended
# and the StackContext's __exit__ is not called yet, so
# the context will be left on _state.contexts for anything
# that runs before the yield resolves.
yield gen.Wait('a')
except StackContextInconsistentError:
# In python <= 3.3, this suspended generator is never garbage
# collected, so it remains suspended in the 'yield' forever.
# Starting in 3.4, it is made collectable by raising
# a GeneratorExit exception from the yield, which gets
# converted into a StackContextInconsistentError by the
# exit of the 'with' block.
pass
self.callback = yield gen.Callback('a')
with StackContext(functools.partial(self.context, 'c1')):
# This yield is a problem: the generator will be suspended
# and the StackContext's __exit__ is not called yet, so
# the context will be left on _state.contexts for anything
# that runs before the yield resolves.
yield gen.Wait('a')
with self.assertRaises(StackContextInconsistentError):
f()
@ -257,11 +248,8 @@ class StackContextTest(AsyncTestCase):
# As above, but with ExceptionStackContext instead of StackContext.
@gen.engine
def f():
try:
with ExceptionStackContext(lambda t, v, tb: False):
yield gen.Task(self.io_loop.add_callback)
except StackContextInconsistentError:
pass
with ExceptionStackContext(lambda t, v, tb: False):
yield gen.Task(self.io_loop.add_callback)
with self.assertRaises(StackContextInconsistentError):
f()

View file

@ -0,0 +1,278 @@
#!/usr/bin/env python
#
# Copyright 2014 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
from contextlib import closing
import os
import socket
from tornado.concurrent import Future
from tornado.netutil import bind_sockets, Resolver
from tornado.tcpclient import TCPClient, _Connector
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
from tornado.test.util import skipIfNoIPv6, unittest
# Fake address families for testing. Used in place of AF_INET
# and AF_INET6 because some installations do not have AF_INET6.
AF1, AF2 = 1, 2
class TestTCPServer(TCPServer):
def __init__(self, family):
super(TestTCPServer, self).__init__()
self.streams = []
sockets = bind_sockets(None, 'localhost', family)
self.add_sockets(sockets)
self.port = sockets[0].getsockname()[1]
def handle_stream(self, stream, address):
self.streams.append(stream)
def stop(self):
super(TestTCPServer, self).stop()
for stream in self.streams:
stream.close()
class TCPClientTest(AsyncTestCase):
def setUp(self):
super(TCPClientTest, self).setUp()
self.server = None
self.client = TCPClient()
def start_server(self, family):
if family == socket.AF_UNSPEC and 'TRAVIS' in os.environ:
self.skipTest("dual-stack servers often have port conflicts on travis")
self.server = TestTCPServer(family)
return self.server.port
def stop_server(self):
if self.server is not None:
self.server.stop()
self.server = None
def tearDown(self):
self.client.close()
self.stop_server()
super(TCPClientTest, self).tearDown()
def skipIfLocalhostV4(self):
Resolver().resolve('localhost', 0, callback=self.stop)
addrinfo = self.wait()
families = set(addr[0] for addr in addrinfo)
if socket.AF_INET6 not in families:
self.skipTest("localhost does not resolve to ipv6")
@gen_test
def do_test_connect(self, family, host):
port = self.start_server(family)
stream = yield self.client.connect(host, port)
with closing(stream):
stream.write(b"hello")
data = yield self.server.streams[0].read_bytes(5)
self.assertEqual(data, b"hello")
def test_connect_ipv4_ipv4(self):
self.do_test_connect(socket.AF_INET, '127.0.0.1')
def test_connect_ipv4_dual(self):
self.do_test_connect(socket.AF_INET, 'localhost')
@skipIfNoIPv6
def test_connect_ipv6_ipv6(self):
self.skipIfLocalhostV4()
self.do_test_connect(socket.AF_INET6, '::1')
@skipIfNoIPv6
def test_connect_ipv6_dual(self):
self.skipIfLocalhostV4()
if Resolver.configured_class().__name__.endswith('TwistedResolver'):
self.skipTest('TwistedResolver does not support multiple addresses')
self.do_test_connect(socket.AF_INET6, 'localhost')
def test_connect_unspec_ipv4(self):
self.do_test_connect(socket.AF_UNSPEC, '127.0.0.1')
@skipIfNoIPv6
def test_connect_unspec_ipv6(self):
self.skipIfLocalhostV4()
self.do_test_connect(socket.AF_UNSPEC, '::1')
def test_connect_unspec_dual(self):
self.do_test_connect(socket.AF_UNSPEC, 'localhost')
@gen_test
def test_refused_ipv4(self):
sock, port = bind_unused_port()
sock.close()
with self.assertRaises(IOError):
yield self.client.connect('127.0.0.1', port)
class TestConnectorSplit(unittest.TestCase):
def test_one_family(self):
# These addresses aren't in the right format, but split doesn't care.
primary, secondary = _Connector.split(
[(AF1, 'a'),
(AF1, 'b')])
self.assertEqual(primary, [(AF1, 'a'),
(AF1, 'b')])
self.assertEqual(secondary, [])
def test_mixed(self):
primary, secondary = _Connector.split(
[(AF1, 'a'),
(AF2, 'b'),
(AF1, 'c'),
(AF2, 'd')])
self.assertEqual(primary, [(AF1, 'a'), (AF1, 'c')])
self.assertEqual(secondary, [(AF2, 'b'), (AF2, 'd')])
class ConnectorTest(AsyncTestCase):
class FakeStream(object):
def __init__(self):
self.closed = False
def close(self):
self.closed = True
def setUp(self):
super(ConnectorTest, self).setUp()
self.connect_futures = {}
self.streams = {}
self.addrinfo = [(AF1, 'a'), (AF1, 'b'),
(AF2, 'c'), (AF2, 'd')]
def tearDown(self):
# Unless explicitly checked (and popped) in the test, we shouldn't
# be closing any streams
for stream in self.streams.values():
self.assertFalse(stream.closed)
super(ConnectorTest, self).tearDown()
def create_stream(self, af, addr):
future = Future()
self.connect_futures[(af, addr)] = future
return future
def assert_pending(self, *keys):
self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys))
def resolve_connect(self, af, addr, success):
future = self.connect_futures.pop((af, addr))
if success:
self.streams[addr] = ConnectorTest.FakeStream()
future.set_result(self.streams[addr])
else:
future.set_exception(IOError())
def start_connect(self, addrinfo):
conn = _Connector(addrinfo, self.io_loop, self.create_stream)
# Give it a huge timeout; we'll trigger timeouts manually.
future = conn.start(3600)
return conn, future
def test_immediate_success(self):
conn, future = self.start_connect(self.addrinfo)
self.assertEqual(list(self.connect_futures.keys()),
[(AF1, 'a')])
self.resolve_connect(AF1, 'a', True)
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
def test_immediate_failure(self):
# Fail with just one address.
conn, future = self.start_connect([(AF1, 'a')])
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assertRaises(IOError, future.result)
def test_one_family_second_try(self):
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.resolve_connect(AF1, 'b', True)
self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
def test_one_family_second_try_failure(self):
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.resolve_connect(AF1, 'b', False)
self.assertRaises(IOError, future.result)
def test_one_family_second_try_timeout(self):
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
self.assert_pending((AF1, 'a'))
# trigger the timeout while the first lookup is pending;
# nothing happens.
conn.on_timeout()
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.resolve_connect(AF1, 'b', True)
self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
def test_two_families_immediate_failure(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'), (AF2, 'c'))
self.resolve_connect(AF1, 'b', False)
self.resolve_connect(AF2, 'c', True)
self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
def test_two_families_timeout(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
conn.on_timeout()
self.assert_pending((AF1, 'a'), (AF2, 'c'))
self.resolve_connect(AF2, 'c', True)
self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
# resolving 'a' after the connection has completed doesn't start 'b'
self.resolve_connect(AF1, 'a', False)
self.assert_pending()
def test_success_after_timeout(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
conn.on_timeout()
self.assert_pending((AF1, 'a'), (AF2, 'c'))
self.resolve_connect(AF1, 'a', True)
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
# resolving 'c' after completion closes the connection.
self.resolve_connect(AF2, 'c', True)
self.assertTrue(self.streams.pop('c').closed)
def test_all_fail(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
conn.on_timeout()
self.assert_pending((AF1, 'a'), (AF2, 'c'))
self.resolve_connect(AF2, 'c', False)
self.assert_pending((AF1, 'a'), (AF2, 'd'))
self.resolve_connect(AF2, 'd', False)
# one queue is now empty
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.assertFalse(future.done())
self.resolve_connect(AF1, 'b', False)
self.assertRaises(IOError, future.result)

View file

@ -182,6 +182,7 @@ three
"""})
try:
loader.load("test.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# test.html:2" in traceback.format_exc())
@ -192,6 +193,7 @@ three{%end%}
"""})
try:
loader.load("test.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# test.html:2" in traceback.format_exc())
@ -202,6 +204,7 @@ three{%end%}
}, namespace={"_tt_modules": ObjectDict({"Template": lambda path, **kwargs: loader.load(path).generate(**kwargs)})})
try:
loader.load("base.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
exc_stack = traceback.format_exc()
self.assertTrue('# base.html:1' in exc_stack)
@ -214,6 +217,7 @@ three{%end%}
})
try:
loader.load("base.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# sub.html:1 (via base.html:1)" in
traceback.format_exc())
@ -225,6 +229,7 @@ three{%end%}
})
try:
loader.load("sub.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
exc_stack = traceback.format_exc()
self.assertTrue("# base.html:1" in exc_stack)
@ -240,6 +245,7 @@ three{%end%}
"""})
try:
loader.load("sub.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# sub.html:4 (via base.html:1)" in
traceback.format_exc())
@ -252,6 +258,7 @@ three{%end%}
})
try:
loader.load("a.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# c.html:1 (via b.html:1, a.html:1)" in
traceback.format_exc())
@ -380,6 +387,20 @@ raw: {% raw name %}""",
self.assertEqual(render("foo.py", ["not a string"]),
b"""s = "['not a string']"\n""")
def test_minimize_whitespace(self):
# Whitespace including newlines is allowed within template tags
# and directives, and this is one way to avoid long lines while
# keeping extra whitespace out of the rendered output.
loader = DictLoader({'foo.txt': """\
{% for i in items
%}{% if i > 0 %}, {% end %}{#
#}{{i
}}{% end
%}""",
})
self.assertEqual(loader.load("foo.txt").generate(items=range(5)),
b"0, 1, 2, 3, 4")
class TemplateLoaderTest(unittest.TestCase):
def setUp(self):

View file

@ -8,6 +8,7 @@ from tornado.test.util import unittest
import contextlib
import os
import traceback
@contextlib.contextmanager
@ -62,6 +63,39 @@ class AsyncTestCaseTest(AsyncTestCase):
self.wait(timeout=0.15)
class AsyncTestCaseWrapperTest(unittest.TestCase):
def test_undecorated_generator(self):
class Test(AsyncTestCase):
def test_gen(self):
yield
test = Test('test_gen')
result = unittest.TestResult()
test.run(result)
self.assertEqual(len(result.errors), 1)
self.assertIn("should be decorated", result.errors[0][1])
def test_undecorated_generator_with_skip(self):
class Test(AsyncTestCase):
@unittest.skip("don't run this")
def test_gen(self):
yield
test = Test('test_gen')
result = unittest.TestResult()
test.run(result)
self.assertEqual(len(result.errors), 0)
self.assertEqual(len(result.skipped), 1)
def test_other_return(self):
class Test(AsyncTestCase):
def test_other_return(self):
return 42
test = Test('test_other_return')
result = unittest.TestResult()
test.run(result)
self.assertEqual(len(result.errors), 1)
self.assertIn("Return value from test method ignored", result.errors[0][1])
class SetUpTearDownTest(unittest.TestCase):
def test_set_up_tear_down(self):
"""
@ -115,8 +149,17 @@ class GenTest(AsyncTestCase):
def test(self):
yield gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)
with self.assertRaises(ioloop.TimeoutError):
# This can't use assertRaises because we need to inspect the
# exc_info triple (and not just the exception object)
try:
test(self)
self.fail("did not get expected exception")
except ioloop.TimeoutError:
# The stack trace should blame the add_timeout line, not just
# unrelated IOLoop/testing internals.
self.assertIn(
"gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)",
traceback.format_exc())
self.finished = True
@ -155,5 +198,23 @@ class GenTest(AsyncTestCase):
self.finished = True
def test_with_method_args(self):
@gen_test
def test_with_args(self, *args):
self.assertEqual(args, ('test',))
yield gen.Task(self.io_loop.add_callback)
test_with_args(self, 'test')
self.finished = True
def test_with_method_kwargs(self):
@gen_test
def test_with_kwargs(self, **kwargs):
self.assertDictEqual(kwargs, {'test': 'test'})
yield gen.Task(self.io_loop.add_callback)
test_with_kwargs(self, test='test')
self.finished = True
if __name__ == '__main__':
unittest.main()

View file

@ -1,14 +1,18 @@
from __future__ import absolute_import, division, print_function, with_statement
import os
import socket
import sys
# Encapsulate the choice of unittest or unittest2 here.
# To be used as 'from tornado.test.util import unittest'.
if sys.version_info >= (2, 7):
import unittest
else:
if sys.version_info < (2, 7):
# In py26, we must always use unittest2.
import unittest2 as unittest
else:
# Otherwise, use whichever version of unittest was imported in
# tornado.testing.
from tornado.testing import unittest
skipIfNonUnix = unittest.skipIf(os.name != 'posix' or sys.platform == 'cygwin',
"non-unix platform")
@ -17,3 +21,10 @@ skipIfNonUnix = unittest.skipIf(os.name != 'posix' or sys.platform == 'cygwin',
# timing-related tests unreliable.
skipOnTravis = unittest.skipIf('TRAVIS' in os.environ,
'timing tests unreliable on travis')
# Set the environment variable NO_NETWORK=1 to disable any tests that
# depend on an external network.
skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ,
'network access disabled')
skipIfNoIPv6 = unittest.skipIf(not socket.has_ipv6, 'ipv6 support not present')

View file

@ -151,14 +151,22 @@ class ArgReplacerTest(unittest.TestCase):
self.replacer = ArgReplacer(function, 'callback')
def test_omitted(self):
self.assertEqual(self.replacer.replace('new', (1, 2), dict()),
args = (1, 2)
kwargs = dict()
self.assertIs(self.replacer.get_old_value(args, kwargs), None)
self.assertEqual(self.replacer.replace('new', args, kwargs),
(None, (1, 2), dict(callback='new')))
def test_position(self):
self.assertEqual(self.replacer.replace('new', (1, 2, 'old', 3), dict()),
args = (1, 2, 'old', 3)
kwargs = dict()
self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
self.assertEqual(self.replacer.replace('new', args, kwargs),
('old', [1, 2, 'new', 3], dict()))
def test_keyword(self):
self.assertEqual(self.replacer.replace('new', (1,),
dict(y=2, callback='old', z=3)),
args = (1,)
kwargs = dict(y=2, callback='old', z=3)
self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
self.assertEqual(self.replacer.replace('new', args, kwargs),
('old', (1,), dict(y=2, callback='new', z=3)))

View file

@ -1,4 +1,5 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado.concurrent import Future
from tornado import gen
from tornado.escape import json_decode, utf8, to_unicode, recursive_unicode, native_str, to_basestring
from tornado.httputil import format_timestamp
@ -6,14 +7,16 @@ from tornado.iostream import IOStream
from tornado.log import app_log, gen_log
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.template import DictLoader
from tornado.testing import AsyncHTTPTestCase, ExpectLog
from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test
from tornado.test.util import unittest
from tornado.util import u, bytes_type, ObjectDict, unicode_type
from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError
from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body
import binascii
import contextlib
import datetime
import email.utils
import itertools
import logging
import os
import re
@ -100,14 +103,14 @@ class SecureCookieV1Test(unittest.TestCase):
sig = match.group(2)
self.assertEqual(
_create_signature_v1(handler.application.settings["cookie_secret"],
'foo', '12345678', timestamp),
'foo', '12345678', timestamp),
sig)
# shifting digits from payload to timestamp doesn't alter signature
# (this is not desirable behavior, just confirming that that's how it
# works)
self.assertEqual(
_create_signature_v1(handler.application.settings["cookie_secret"],
'foo', '1234', b'5678' + timestamp),
'foo', '1234', b'5678' + timestamp),
sig)
# tamper with the cookie
handler._cookies['foo'] = utf8('1234|5678%s|%s' % (
@ -471,12 +474,13 @@ class EmptyFlushCallbackHandler(RequestHandler):
@asynchronous
def get(self):
# Ensure that the flush callback is run whether or not there
# was any output.
# was any output. The gen.Task and direct yield forms are
# equivalent.
yield gen.Task(self.flush) # "empty" flush, but writes headers
yield gen.Task(self.flush) # empty flush
self.write("o")
yield gen.Task(self.flush) # flushes the "o"
yield gen.Task(self.flush) # empty flush
yield self.flush() # flushes the "o"
yield self.flush() # empty flush
self.finish("k")
@ -575,8 +579,8 @@ class WSGISafeWebTest(WebTestCase):
"/decode_arg/%E9?foo=%E9&encoding=latin1",
"/decode_arg_kw/%E9?foo=%E9&encoding=latin1",
]
for url in urls:
response = self.fetch(url)
for req_url in urls:
response = self.fetch(req_url)
response.rethrow()
data = json_decode(response.body)
self.assertEqual(data, {u('path'): [u('unicode'), u('\u00e9')],
@ -602,8 +606,8 @@ class WSGISafeWebTest(WebTestCase):
# These urls are all equivalent.
urls = ["/decode_arg/1%20%2B%201?foo=1%20%2B%201&encoding=utf-8",
"/decode_arg/1%20+%201?foo=1+%2B+1&encoding=utf-8"]
for url in urls:
response = self.fetch(url)
for req_url in urls:
response = self.fetch(req_url)
response.rethrow()
data = json_decode(response.body)
self.assertEqual(data, {u('path'): [u('unicode'), u('1 + 1')],
@ -915,17 +919,37 @@ class StaticFileTest(WebTestCase):
response = self.fetch(path % int(include_host))
self.assertEqual(response.body, utf8(str(True)))
def get_and_head(self, *args, **kwargs):
"""Performs a GET and HEAD request and returns the GET response.
Fails if any ``Content-*`` headers returned by the two requests
differ.
"""
head_response = self.fetch(*args, method="HEAD", **kwargs)
get_response = self.fetch(*args, method="GET", **kwargs)
content_headers = set()
for h in itertools.chain(head_response.headers, get_response.headers):
if h.startswith('Content-'):
content_headers.add(h)
for h in content_headers:
self.assertEqual(head_response.headers.get(h),
get_response.headers.get(h),
"%s differs between GET (%s) and HEAD (%s)" %
(h, head_response.headers.get(h),
get_response.headers.get(h)))
return get_response
def test_static_304_if_modified_since(self):
response1 = self.fetch("/static/robots.txt")
response2 = self.fetch("/static/robots.txt", headers={
response1 = self.get_and_head("/static/robots.txt")
response2 = self.get_and_head("/static/robots.txt", headers={
'If-Modified-Since': response1.headers['Last-Modified']})
self.assertEqual(response2.code, 304)
self.assertTrue('Content-Length' not in response2.headers)
self.assertTrue('Last-Modified' not in response2.headers)
def test_static_304_if_none_match(self):
response1 = self.fetch("/static/robots.txt")
response2 = self.fetch("/static/robots.txt", headers={
response1 = self.get_and_head("/static/robots.txt")
response2 = self.get_and_head("/static/robots.txt", headers={
'If-None-Match': response1.headers['Etag']})
self.assertEqual(response2.code, 304)
@ -933,7 +957,7 @@ class StaticFileTest(WebTestCase):
# On windows, the functions that work with time_t do not accept
# negative values, and at least one client (processing.js) seems
# to use if-modified-since 1/1/1960 as a cache-busting technique.
response = self.fetch("/static/robots.txt", headers={
response = self.get_and_head("/static/robots.txt", headers={
'If-Modified-Since': 'Fri, 01 Jan 1960 00:00:00 GMT'})
self.assertEqual(response.code, 200)
@ -944,20 +968,20 @@ class StaticFileTest(WebTestCase):
# when parsing If-Modified-Since.
stat = os.stat(relpath('static/robots.txt'))
response = self.fetch('/static/robots.txt', headers={
response = self.get_and_head('/static/robots.txt', headers={
'If-Modified-Since': format_timestamp(stat.st_mtime - 1)})
self.assertEqual(response.code, 200)
response = self.fetch('/static/robots.txt', headers={
response = self.get_and_head('/static/robots.txt', headers={
'If-Modified-Since': format_timestamp(stat.st_mtime + 1)})
self.assertEqual(response.code, 304)
def test_static_etag(self):
response = self.fetch('/static/robots.txt')
response = self.get_and_head('/static/robots.txt')
self.assertEqual(utf8(response.headers.get("Etag")),
b'"' + self.robots_txt_hash + b'"')
def test_static_with_range(self):
response = self.fetch('/static/robots.txt', headers={
response = self.get_and_head('/static/robots.txt', headers={
'Range': 'bytes=0-9'})
self.assertEqual(response.code, 206)
self.assertEqual(response.body, b"User-agent")
@ -968,7 +992,7 @@ class StaticFileTest(WebTestCase):
"bytes 0-9/26")
def test_static_with_range_full_file(self):
response = self.fetch('/static/robots.txt', headers={
response = self.get_and_head('/static/robots.txt', headers={
'Range': 'bytes=0-'})
# Note: Chrome refuses to play audio if it gets an HTTP 206 in response
# to ``Range: bytes=0-`` :(
@ -980,7 +1004,7 @@ class StaticFileTest(WebTestCase):
self.assertEqual(response.headers.get("Content-Range"), None)
def test_static_with_range_full_past_end(self):
response = self.fetch('/static/robots.txt', headers={
response = self.get_and_head('/static/robots.txt', headers={
'Range': 'bytes=0-10000000'})
self.assertEqual(response.code, 200)
robots_file_path = os.path.join(self.static_dir, "robots.txt")
@ -990,7 +1014,7 @@ class StaticFileTest(WebTestCase):
self.assertEqual(response.headers.get("Content-Range"), None)
def test_static_with_range_partial_past_end(self):
response = self.fetch('/static/robots.txt', headers={
response = self.get_and_head('/static/robots.txt', headers={
'Range': 'bytes=1-10000000'})
self.assertEqual(response.code, 206)
robots_file_path = os.path.join(self.static_dir, "robots.txt")
@ -1000,7 +1024,7 @@ class StaticFileTest(WebTestCase):
self.assertEqual(response.headers.get("Content-Range"), "bytes 1-25/26")
def test_static_with_range_end_edge(self):
response = self.fetch('/static/robots.txt', headers={
response = self.get_and_head('/static/robots.txt', headers={
'Range': 'bytes=22-'})
self.assertEqual(response.body, b": /\n")
self.assertEqual(response.headers.get("Content-Length"), "4")
@ -1008,7 +1032,7 @@ class StaticFileTest(WebTestCase):
"bytes 22-25/26")
def test_static_with_range_neg_end(self):
response = self.fetch('/static/robots.txt', headers={
response = self.get_and_head('/static/robots.txt', headers={
'Range': 'bytes=-4'})
self.assertEqual(response.body, b": /\n")
self.assertEqual(response.headers.get("Content-Length"), "4")
@ -1016,19 +1040,19 @@ class StaticFileTest(WebTestCase):
"bytes 22-25/26")
def test_static_invalid_range(self):
response = self.fetch('/static/robots.txt', headers={
response = self.get_and_head('/static/robots.txt', headers={
'Range': 'asdf'})
self.assertEqual(response.code, 200)
def test_static_unsatisfiable_range_zero_suffix(self):
response = self.fetch('/static/robots.txt', headers={
response = self.get_and_head('/static/robots.txt', headers={
'Range': 'bytes=-0'})
self.assertEqual(response.headers.get("Content-Range"),
"bytes */26")
self.assertEqual(response.code, 416)
def test_static_unsatisfiable_range_invalid_start(self):
response = self.fetch('/static/robots.txt', headers={
response = self.get_and_head('/static/robots.txt', headers={
'Range': 'bytes=26'})
self.assertEqual(response.code, 416)
self.assertEqual(response.headers.get("Content-Range"),
@ -1053,7 +1077,7 @@ class StaticFileTest(WebTestCase):
b'"' + self.robots_txt_hash + b'"')
def test_static_range_if_none_match(self):
response = self.fetch('/static/robots.txt', headers={
response = self.get_and_head('/static/robots.txt', headers={
'Range': 'bytes=1-4',
'If-None-Match': b'"' + self.robots_txt_hash + b'"'})
self.assertEqual(response.code, 304)
@ -1063,7 +1087,7 @@ class StaticFileTest(WebTestCase):
b'"' + self.robots_txt_hash + b'"')
def test_static_404(self):
response = self.fetch('/static/blarg')
response = self.get_and_head('/static/blarg')
self.assertEqual(response.code, 404)
@ -1136,6 +1160,11 @@ class CustomStaticFileTest(WebTestCase):
return b'bar'
raise Exception("unexpected path %r" % path)
def get_content_size(self):
if self.absolute_path == 'CustomStaticFileTest:foo.txt':
return 3
raise Exception("unexpected path %r" % self.absolute_path)
def get_modified_time(self):
return None
@ -1335,6 +1364,7 @@ class ErrorHandlerXSRFTest(WebTestCase):
self.assertEqual(response.code, 404)
@wsgi_safe
class GzipTestCase(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
@ -1347,7 +1377,13 @@ class GzipTestCase(SimpleHandlerTestCase):
def test_gzip(self):
response = self.fetch('/')
self.assertEqual(response.headers['Content-Encoding'], 'gzip')
# simple_httpclient renames the content-encoding header;
# curl_httpclient doesn't.
self.assertEqual(
response.headers.get(
'Content-Encoding',
response.headers.get('X-Consumed-Content-Encoding')),
'gzip')
self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
def test_gzip_not_requested(self):
@ -1799,6 +1835,227 @@ class HandlerByNameTest(WebTestCase):
self.assertEqual(resp.body, b'hello')
class StreamingRequestBodyTest(WebTestCase):
def get_handlers(self):
@stream_request_body
class StreamingBodyHandler(RequestHandler):
def initialize(self, test):
self.test = test
def prepare(self):
self.test.prepared.set_result(None)
def data_received(self, data):
self.test.data.set_result(data)
def get(self):
self.test.finished.set_result(None)
self.write({})
@stream_request_body
class EarlyReturnHandler(RequestHandler):
def prepare(self):
# If we finish the response in prepare, it won't continue to
# the (non-existent) data_received.
raise HTTPError(401)
@stream_request_body
class CloseDetectionHandler(RequestHandler):
def initialize(self, test):
self.test = test
def on_connection_close(self):
super(CloseDetectionHandler, self).on_connection_close()
self.test.close_future.set_result(None)
return [('/stream_body', StreamingBodyHandler, dict(test=self)),
('/early_return', EarlyReturnHandler),
('/close_detection', CloseDetectionHandler, dict(test=self))]
def connect(self, url, connection_close):
# Use a raw connection so we can control the sending of data.
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
s.connect(("localhost", self.get_http_port()))
stream = IOStream(s, io_loop=self.io_loop)
stream.write(b"GET " + url + b" HTTP/1.1\r\n")
if connection_close:
stream.write(b"Connection: close\r\n")
stream.write(b"Transfer-Encoding: chunked\r\n\r\n")
return stream
@gen_test
def test_streaming_body(self):
self.prepared = Future()
self.data = Future()
self.finished = Future()
stream = self.connect(b"/stream_body", connection_close=True)
yield self.prepared
stream.write(b"4\r\nasdf\r\n")
# Ensure the first chunk is received before we send the second.
data = yield self.data
self.assertEqual(data, b"asdf")
self.data = Future()
stream.write(b"4\r\nqwer\r\n")
data = yield self.data
self.assertEquals(data, b"qwer")
stream.write(b"0\r\n")
yield self.finished
data = yield gen.Task(stream.read_until_close)
# This would ideally use an HTTP1Connection to read the response.
self.assertTrue(data.endswith(b"{}"))
stream.close()
@gen_test
def test_early_return(self):
stream = self.connect(b"/early_return", connection_close=False)
data = yield gen.Task(stream.read_until_close)
self.assertTrue(data.startswith(b"HTTP/1.1 401"))
@gen_test
def test_early_return_with_data(self):
stream = self.connect(b"/early_return", connection_close=False)
stream.write(b"4\r\nasdf\r\n")
data = yield gen.Task(stream.read_until_close)
self.assertTrue(data.startswith(b"HTTP/1.1 401"))
@gen_test
def test_close_during_upload(self):
self.close_future = Future()
stream = self.connect(b"/close_detection", connection_close=False)
stream.close()
yield self.close_future
class StreamingRequestFlowControlTest(WebTestCase):
def get_handlers(self):
from tornado.ioloop import IOLoop
# Each method in this handler returns a Future and yields to the
# IOLoop so the future is not immediately ready. Ensure that the
# Futures are respected and no method is called before the previous
# one has completed.
@stream_request_body
class FlowControlHandler(RequestHandler):
def initialize(self, test):
self.test = test
self.method = None
self.methods = []
@contextlib.contextmanager
def in_method(self, method):
if self.method is not None:
self.test.fail("entered method %s while in %s" %
(method, self.method))
self.method = method
self.methods.append(method)
try:
yield
finally:
self.method = None
@gen.coroutine
def prepare(self):
with self.in_method('prepare'):
yield gen.Task(IOLoop.current().add_callback)
@gen.coroutine
def data_received(self, data):
with self.in_method('data_received'):
yield gen.Task(IOLoop.current().add_callback)
@gen.coroutine
def post(self):
with self.in_method('post'):
yield gen.Task(IOLoop.current().add_callback)
self.write(dict(methods=self.methods))
return [('/', FlowControlHandler, dict(test=self))]
def get_httpserver_options(self):
# Use a small chunk size so flow control is relevant even though
# all the data arrives at once.
return dict(chunk_size=10)
def test_flow_control(self):
response = self.fetch('/', body='abcdefghijklmnopqrstuvwxyz',
method='POST')
response.rethrow()
self.assertEqual(json_decode(response.body),
dict(methods=['prepare', 'data_received',
'data_received', 'data_received',
'post']))
@wsgi_safe
class IncorrectContentLengthTest(SimpleHandlerTestCase):
def get_handlers(self):
test = self
self.server_error = None
# Manually set a content-length that doesn't match the actual content.
class TooHigh(RequestHandler):
def get(self):
self.set_header("Content-Length", "42")
try:
self.finish("ok")
except Exception as e:
test.server_error = e
raise
class TooLow(RequestHandler):
def get(self):
self.set_header("Content-Length", "2")
try:
self.finish("hello")
except Exception as e:
test.server_error = e
raise
return [('/high', TooHigh),
('/low', TooLow)]
def test_content_length_too_high(self):
# When the content-length is too high, the connection is simply
# closed without completing the response. An error is logged on
# the server.
with ExpectLog(app_log, "Uncaught exception"):
with ExpectLog(gen_log,
"Cannot send error response after headers written"):
response = self.fetch("/high")
self.assertEqual(response.code, 599)
self.assertEqual(str(self.server_error),
"Tried to write 40 bytes less than Content-Length")
def test_content_length_too_low(self):
# When the content-length is too low, the connection is closed
# without writing the last chunk, so the client never sees the request
# complete (which would be a framing error).
with ExpectLog(app_log, "Uncaught exception"):
with ExpectLog(gen_log,
"Cannot send error response after headers written"):
response = self.fetch("/low")
self.assertEqual(response.code, 599)
self.assertEqual(str(self.server_error),
"Tried to write more data than Content-Length")
class ClientCloseTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
# Simulate a connection closed by the client during
# request processing. The client will see an error, but the
# server should respond gracefully (without logging errors
# because we were unable to write out as many bytes as
# Content-Length said we would)
self.request.connection.stream.close()
self.write('hello')
def test_client_close(self):
response = self.fetch('/')
self.assertEqual(response.code, 599)
class SignedValueTest(unittest.TestCase):
SECRET = "It's a secret to everybody"

View file

@ -6,7 +6,7 @@ from tornado.concurrent import Future
from tornado.httpclient import HTTPError, HTTPRequest
from tornado.log import gen_log
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
from tornado.test.util import unittest, skipOnTravis
from tornado.test.util import unittest
from tornado.web import Application, RequestHandler
try:
@ -37,7 +37,7 @@ class TestWebSocketHandler(WebSocketHandler):
self.close_future = close_future
def on_close(self):
self.close_future.set_result(None)
self.close_future.set_result((self.close_code, self.close_reason))
class EchoHandler(TestWebSocketHandler):
@ -47,6 +47,13 @@ class EchoHandler(TestWebSocketHandler):
class HeaderHandler(TestWebSocketHandler):
def open(self):
try:
# In a websocket context, many RequestHandler methods
# raise RuntimeErrors.
self.set_status(503)
raise Exception("did not get expected exception")
except RuntimeError:
pass
self.write_message(self.request.headers.get('X-Test', ''))
@ -55,6 +62,11 @@ class NonWebSocketHandler(RequestHandler):
self.write('ok')
class CloseReasonHandler(TestWebSocketHandler):
def open(self):
self.close(1001, "goodbye")
class WebSocketTest(AsyncHTTPTestCase):
def get_app(self):
self.close_future = Future()
@ -62,8 +74,15 @@ class WebSocketTest(AsyncHTTPTestCase):
('/echo', EchoHandler, dict(close_future=self.close_future)),
('/non_ws', NonWebSocketHandler),
('/header', HeaderHandler, dict(close_future=self.close_future)),
('/close_reason', CloseReasonHandler,
dict(close_future=self.close_future)),
])
def test_http_request(self):
# WS server, HTTP client.
response = self.fetch('/echo')
self.assertEqual(response.code, 400)
@gen_test
def test_websocket_gen(self):
ws = yield websocket_connect(
@ -84,8 +103,9 @@ class WebSocketTest(AsyncHTTPTestCase):
ws.read_message(self.stop)
response = self.wait().result()
self.assertEqual(response, 'hello')
self.close_future.add_done_callback(lambda f: self.stop())
ws.close()
yield self.close_future
self.wait()
@gen_test
def test_websocket_http_fail(self):
@ -102,30 +122,16 @@ class WebSocketTest(AsyncHTTPTestCase):
'ws://localhost:%d/non_ws' % self.get_http_port(),
io_loop=self.io_loop)
@skipOnTravis
@gen_test
def test_websocket_network_timeout(self):
sock, port = bind_unused_port()
sock.close()
with self.assertRaises(HTTPError) as cm:
with ExpectLog(gen_log, ".*"):
yield websocket_connect(
'ws://localhost:%d/' % port,
io_loop=self.io_loop,
connect_timeout=0.01)
self.assertEqual(cm.exception.code, 599)
@gen_test
def test_websocket_network_fail(self):
sock, port = bind_unused_port()
sock.close()
with self.assertRaises(HTTPError) as cm:
with self.assertRaises(IOError):
with ExpectLog(gen_log, ".*"):
yield websocket_connect(
'ws://localhost:%d/' % port,
io_loop=self.io_loop,
connect_timeout=3600)
self.assertEqual(cm.exception.code, 599)
@gen_test
def test_websocket_close_buffered_data(self):
@ -147,6 +153,97 @@ class WebSocketTest(AsyncHTTPTestCase):
ws.close()
yield self.close_future
@gen_test
def test_server_close_reason(self):
ws = yield websocket_connect(
'ws://localhost:%d/close_reason' % self.get_http_port())
msg = yield ws.read_message()
# A message of None means the other side closed the connection.
self.assertIs(msg, None)
self.assertEqual(ws.close_code, 1001)
self.assertEqual(ws.close_reason, "goodbye")
@gen_test
def test_client_close_reason(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port())
ws.close(1001, 'goodbye')
code, reason = yield self.close_future
self.assertEqual(code, 1001)
self.assertEqual(reason, 'goodbye')
@gen_test
def test_check_origin_valid_no_path(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
headers = {'Origin': 'http://localhost:%d' % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
@gen_test
def test_check_origin_valid_with_path(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
headers = {'Origin': 'http://localhost:%d/something' % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
@gen_test
def test_check_origin_invalid_partial_url(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
headers = {'Origin': 'localhost:%d' % port}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
self.assertEqual(cm.exception.code, 403)
@gen_test
def test_check_origin_invalid(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
# Host is localhost, which should not be accessible from some other
# domain
headers = {'Origin': 'http://somewhereelse.com'}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
self.assertEqual(cm.exception.code, 403)
@gen_test
def test_check_origin_invalid_subdomains(self):
port = self.get_http_port()
url = 'ws://localhost:%d/echo' % port
# Subdomains should be disallowed by default. If we could pass a
# resolver to websocket_connect we could test sibling domains as well.
headers = {'Origin': 'http://subtenant.localhost'}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers),
io_loop=self.io_loop)
self.assertEqual(cm.exception.code, 403)
class MaskFunctionMixin(object):
# Subclasses should define self.mask(mask, data)

View file

@ -5,8 +5,8 @@ from tornado.escape import json_decode
from tornado.test.httpserver_test import TypeCheckHandler
from tornado.testing import AsyncHTTPTestCase
from tornado.util import u
from tornado.web import RequestHandler
from tornado.wsgi import WSGIApplication, WSGIContainer
from tornado.web import RequestHandler, Application
from tornado.wsgi import WSGIApplication, WSGIContainer, WSGIAdapter
class WSGIContainerTest(AsyncHTTPTestCase):
@ -74,14 +74,27 @@ class WSGIConnectionTest(httpserver_test.HTTPConnectionTest):
return WSGIContainer(validator(WSGIApplication(self.get_handlers())))
def wrap_web_tests():
def wrap_web_tests_application():
result = {}
for cls in web_test.wsgi_safe_tests:
class WSGIWrappedTest(cls):
class WSGIApplicationWrappedTest(cls):
def get_app(self):
self.app = WSGIApplication(self.get_handlers(),
**self.get_app_kwargs())
return WSGIContainer(validator(self.app))
result["WSGIWrapped_" + cls.__name__] = WSGIWrappedTest
result["WSGIApplication_" + cls.__name__] = WSGIApplicationWrappedTest
return result
globals().update(wrap_web_tests())
globals().update(wrap_web_tests_application())
def wrap_web_tests_adapter():
result = {}
for cls in web_test.wsgi_safe_tests:
class WSGIAdapterWrappedTest(cls):
def get_app(self):
self.app = Application(self.get_handlers(),
**self.get_app_kwargs())
return WSGIContainer(validator(WSGIAdapter(self.app)))
result["WSGIAdapter_" + cls.__name__] = WSGIAdapterWrappedTest
return result
globals().update(wrap_web_tests_adapter())

View file

@ -17,7 +17,7 @@ try:
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.ioloop import IOLoop
from tornado.ioloop import IOLoop, TimeoutError
from tornado import netutil
except ImportError:
# These modules are not importable on app engine. Parts of this module
@ -38,6 +38,7 @@ import re
import signal
import socket
import sys
import types
try:
from cStringIO import StringIO # py2
@ -48,10 +49,16 @@ except ImportError:
# (either py27+ or unittest2) so tornado.test.util enforces
# this requirement, but for other users of tornado.testing we want
# to allow the older version if unitest2 is not available.
try:
import unittest2 as unittest
except ImportError:
if sys.version_info >= (3,):
# On python 3, mixing unittest2 and unittest (including doctest)
# doesn't seem to work, so always use unittest.
import unittest
else:
# On python 2, prefer unittest2 when available.
try:
import unittest2 as unittest
except ImportError:
import unittest
_next_port = 10000
@ -95,6 +102,36 @@ def get_async_test_timeout():
return 5
class _TestMethodWrapper(object):
"""Wraps a test method to raise an error if it returns a value.
This is mainly used to detect undecorated generators (if a test
method yields it must use a decorator to consume the generator),
but will also detect other kinds of return values (these are not
necessarily errors, but we alert anyway since there is no good
reason to return a value from a test.
"""
def __init__(self, orig_method):
self.orig_method = orig_method
def __call__(self):
result = self.orig_method()
if isinstance(result, types.GeneratorType):
raise TypeError("Generator test methods should be decorated with "
"tornado.testing.gen_test")
elif result is not None:
raise ValueError("Return value from test method ignored: %r" %
result)
def __getattr__(self, name):
"""Proxy all unknown attributes to the original method.
This is important for some of the decorators in the `unittest`
module, such as `unittest.skipIf`.
"""
return getattr(self.orig_method, name)
class AsyncTestCase(unittest.TestCase):
"""`~unittest.TestCase` subclass for testing `.IOLoop`-based
asynchronous code.
@ -157,14 +194,20 @@ class AsyncTestCase(unittest.TestCase):
self.assertIn("FriendFeed", response.body)
self.stop()
"""
def __init__(self, *args, **kwargs):
super(AsyncTestCase, self).__init__(*args, **kwargs)
def __init__(self, methodName='runTest', **kwargs):
super(AsyncTestCase, self).__init__(methodName, **kwargs)
self.__stopped = False
self.__running = False
self.__failure = None
self.__stop_args = None
self.__timeout = None
# It's easy to forget the @gen_test decorator, but if you do
# the test will silently be ignored because nothing will consume
# the generator. Replace the test method with a wrapper that will
# make sure it's not an undecorated generator.
setattr(self, methodName, _TestMethodWrapper(getattr(self, methodName)))
def setUp(self):
super(AsyncTestCase, self).setUp()
self.io_loop = self.get_new_ioloop()
@ -352,6 +395,7 @@ class AsyncHTTPTestCase(AsyncTestCase):
def tearDown(self):
self.http_server.stop()
self.io_loop.run_sync(self.http_server.close_all_connections)
if (not IOLoop.initialized() or
self.http_client.io_loop is not IOLoop.instance()):
self.http_client.close()
@ -414,18 +458,50 @@ def gen_test(func=None, timeout=None):
.. versionadded:: 3.1
The ``timeout`` argument and ``ASYNC_TEST_TIMEOUT`` environment
variable.
.. versionchanged:: 4.0
The wrapper now passes along ``*args, **kwargs`` so it can be used
on functions with arguments.
"""
if timeout is None:
timeout = get_async_test_timeout()
def wrap(f):
f = gen.coroutine(f)
# Stack up several decorators to allow us to access the generator
# object itself. In the innermost wrapper, we capture the generator
# and save it in an attribute of self. Next, we run the wrapped
# function through @gen.coroutine. Finally, the coroutine is
# wrapped again to make it synchronous with run_sync.
#
# This is a good case study arguing for either some sort of
# extensibility in the gen decorators or cancellation support.
@functools.wraps(f)
def wrapper(self):
return self.io_loop.run_sync(
functools.partial(f, self), timeout=timeout)
return wrapper
def pre_coroutine(self, *args, **kwargs):
result = f(self, *args, **kwargs)
if isinstance(result, types.GeneratorType):
self._test_generator = result
else:
self._test_generator = None
return result
coro = gen.coroutine(pre_coroutine)
@functools.wraps(coro)
def post_coroutine(self, *args, **kwargs):
try:
return self.io_loop.run_sync(
functools.partial(coro, self, *args, **kwargs),
timeout=timeout)
except TimeoutError as e:
# run_sync raises an error with an unhelpful traceback.
# If we throw it back into the generator the stack trace
# will be replaced by the point where the test is stopped.
self._test_generator.throw(e)
# In case the test contains an overly broad except clause,
# we may get back here. In this case re-raise the original
# exception, which is better than nothing.
raise
return post_coroutine
if func is not None:
# Used like:

View file

@ -41,7 +41,7 @@ class ObjectDict(dict):
class GzipDecompressor(object):
"""Streaming gzip decompressor.
The interface is like that of `zlib.decompressobj` (without the
The interface is like that of `zlib.decompressobj` (without some of the
optional arguments, but it understands gzip headers and checksums.
"""
def __init__(self):
@ -50,14 +50,24 @@ class GzipDecompressor(object):
# This works on cpython and pypy, but not jython.
self.decompressobj = zlib.decompressobj(16 + zlib.MAX_WBITS)
def decompress(self, value):
def decompress(self, value, max_length=None):
"""Decompress a chunk, returning newly-available data.
Some data may be buffered for later processing; `flush` must
be called when there is no more input data to ensure that
all data was processed.
If ``max_length`` is given, some input data may be left over
in ``unconsumed_tail``; you must retrieve this value and pass
it back to a future call to `decompress` if it is not empty.
"""
return self.decompressobj.decompress(value)
return self.decompressobj.decompress(value, max_length)
@property
def unconsumed_tail(self):
"""Returns the unconsumed portion left over
"""
return self.decompressobj.unconsumed_tail
def flush(self):
"""Return any remaining buffered data not yet returned by decompress.
@ -90,10 +100,6 @@ def import_object(name):
return __import__(name, None, None)
parts = name.split('.')
imp = 'from ' + '.'.join(parts[:-1]) + ' import ' + parts[-1]
#exec(imp)
obj = __import__('.'.join(parts[:-1]), None, None, [parts[-1]], 0)
try:
return getattr(obj, parts[-1])
@ -144,6 +150,24 @@ def exec_in(code, glob, loc=None):
""")
def errno_from_exception(e):
"""Provides the errno from an Exception object.
There are cases that the errno attribute was not set so we pull
the errno out of the args but if someone instatiates an Exception
without any args you will get a tuple error. So this function
abstracts all that behavior to give you a safe way to get the
errno.
"""
if hasattr(e, 'errno'):
return e.errno
elif e.args:
return e.args[0]
else:
return None
class Configurable(object):
"""Base class for configurable interfaces.
@ -255,6 +279,16 @@ class ArgReplacer(object):
# Not a positional parameter
self.arg_pos = None
def get_old_value(self, args, kwargs, default=None):
"""Returns the old value of the named argument without replacing it.
Returns ``default`` if the argument is not present.
"""
if self.arg_pos is not None and len(args) > self.arg_pos:
return args[self.arg_pos]
else:
return kwargs.get(self.name, default)
def replace(self, new_value, args, kwargs):
"""Replace the named argument in ``args, kwargs`` with ``new_value``.

View file

@ -73,9 +73,11 @@ import tornado
import traceback
import types
from tornado.concurrent import Future
from tornado.concurrent import Future, is_future
from tornado import escape
from tornado import gen
from tornado import httputil
from tornado import iostream
from tornado import locale
from tornado.log import access_log, app_log, gen_log
from tornado import stack_context
@ -160,6 +162,7 @@ class RequestHandler(object):
self._finished = False
self._auto_finish = True
self._transforms = None # will be set in _execute
self._prepared_future = None
self.path_args = None
self.path_kwargs = None
self.ui = ObjectDict((n, self._ui_method(m)) for n, m in
@ -173,10 +176,7 @@ class RequestHandler(object):
application.ui_modules)
self.ui["modules"] = self.ui["_tt_modules"]
self.clear()
# Check since connection is not available in WSGI
if getattr(self.request, "connection", None):
self.request.connection.set_close_callback(
self.on_connection_close)
self.request.connection.set_close_callback(self.on_connection_close)
self.initialize(**kwargs)
def initialize(self):
@ -267,7 +267,9 @@ class RequestHandler(object):
may not be called promptly after the end user closes their
connection.
"""
pass
if _has_stream_request_body(self.__class__):
if not self.request.body.done():
self.request.body.set_exception(iostream.StreamClosedError())
def clear(self):
"""Resets all headers and content for this response."""
@ -277,12 +279,6 @@ class RequestHandler(object):
"Date": httputil.format_timestamp(time.time()),
})
self.set_default_headers()
if (not self.request.supports_http_1_1() and
getattr(self.request, 'connection', None) and
not self.request.connection.no_keep_alive):
conn_header = self.request.headers.get("Connection")
if conn_header and (conn_header.lower() == "keep-alive"):
self._headers["Connection"] = "Keep-Alive"
self._write_buffer = []
self._status_code = 200
self._reason = httputil.responses[200]
@ -487,7 +483,7 @@ class RequestHandler(object):
@property
def cookies(self):
"""An alias for `self.request.cookies <.httpserver.HTTPRequest.cookies>`."""
"""An alias for `self.request.cookies <.httputil.HTTPServerRequest.cookies>`."""
return self.request.cookies
def get_cookie(self, name, default=None):
@ -649,12 +645,15 @@ class RequestHandler(object):
Note that lists are not converted to JSON because of a potential
cross-site security vulnerability. All JSON output should be
wrapped in a dictionary. More details at
http://haacked.com/archive/2008/11/20/anatomy-of-a-subtle-json-vulnerability.aspx
http://haacked.com/archive/2009/06/25/json-hijacking.aspx/ and
https://github.com/facebook/tornado/issues/1009
"""
if self._finished:
raise RuntimeError("Cannot write() after finish(). May be caused "
"by using async operations without the "
"@asynchronous decorator.")
if not isinstance(chunk, (bytes_type, unicode_type, dict)):
raise TypeError("write() only accepts bytes, unicode, and dict objects")
if isinstance(chunk, dict):
chunk = escape.json_encode(chunk)
self.set_header("Content-Type", "application/json; charset=UTF-8")
@ -820,35 +819,44 @@ class RequestHandler(object):
Note that only one flush callback can be outstanding at a time;
if another flush occurs before the previous flush's callback
has been run, the previous callback will be discarded.
"""
if self.application._wsgi:
# WSGI applications cannot usefully support flush, so just make
# it a no-op (and run the callback immediately).
if callback is not None:
callback()
return
.. versionchanged:: 4.0
Now returns a `.Future` if no callback is given.
"""
chunk = b"".join(self._write_buffer)
self._write_buffer = []
if not self._headers_written:
self._headers_written = True
for transform in self._transforms or []:
for transform in self._transforms:
self._status_code, self._headers, chunk = \
transform.transform_first_chunk(
self._status_code, self._headers, chunk, include_footers)
headers = self._generate_headers()
# Ignore the chunk and only write the headers for HEAD requests
if self.request.method == "HEAD":
chunk = None
# Finalize the cookie headers (which have been stored in a side
# object so an outgoing cookie could be overwritten before it
# is sent).
if hasattr(self, "_new_cookie"):
for cookie in self._new_cookie.values():
self.add_header("Set-Cookie", cookie.OutputString(None))
start_line = httputil.ResponseStartLine(self.request.version,
self._status_code,
self._reason)
return self.request.connection.write_headers(
start_line, self._headers, chunk, callback=callback)
else:
for transform in self._transforms:
chunk = transform.transform_chunk(chunk, include_footers)
headers = b""
# Ignore the chunk and only write the headers for HEAD requests
if self.request.method == "HEAD":
if headers:
self.request.write(headers, callback=callback)
return
self.request.write(headers + chunk, callback=callback)
# Ignore the chunk and only write the headers for HEAD requests
if self.request.method != "HEAD":
return self.request.connection.write(chunk, callback=callback)
else:
future = Future()
future.set_result(None)
return future
def finish(self, chunk=None):
"""Finishes this response, ending the HTTP request."""
@ -884,10 +892,9 @@ class RequestHandler(object):
# are keepalive connections)
self.request.connection.set_close_callback(None)
if not self.application._wsgi:
self.flush(include_footers=True)
self.request.finish()
self._log()
self.flush(include_footers=True)
self.request.finish()
self._log()
self._finished = True
self.on_finish()
# Break up a reference cycle between this handler and the
@ -1235,27 +1242,6 @@ class RequestHandler(object):
return base + get_url(self.settings, path, **kwargs)
def async_callback(self, callback, *args, **kwargs):
"""Obsolete - catches exceptions from the wrapped function.
This function is unnecessary since Tornado 1.1.
"""
if callback is None:
return None
if args or kwargs:
callback = functools.partial(callback, *args, **kwargs)
def wrapper(*args, **kwargs):
try:
return callback(*args, **kwargs)
except Exception as e:
if self._headers_written:
app_log.error("Exception after headers written",
exc_info=True)
else:
self._handle_request_exception(e)
return wrapper
def require_setting(self, name, feature="this feature"):
"""Raises an exception if the given app setting is not defined."""
if not self.application.settings.get(name):
@ -1322,6 +1308,7 @@ class RequestHandler(object):
self._handle_request_exception(value)
return True
@gen.coroutine
def _execute(self, transforms, *args, **kwargs):
"""Executes this request with the given output transforms."""
self._transforms = transforms
@ -1336,52 +1323,52 @@ class RequestHandler(object):
if self.request.method not in ("GET", "HEAD", "OPTIONS") and \
self.application.settings.get("xsrf_cookies"):
self.check_xsrf_cookie()
self._when_complete(self.prepare(), self._execute_method)
except Exception as e:
self._handle_request_exception(e)
def _when_complete(self, result, callback):
try:
if result is None:
callback()
elif isinstance(result, Future):
if result.done():
if result.result() is not None:
raise ValueError('Expected None, got %r' % result.result())
callback()
else:
# Delayed import of IOLoop because it's not available
# on app engine
from tornado.ioloop import IOLoop
IOLoop.current().add_future(
result, functools.partial(self._when_complete,
callback=callback))
else:
raise ValueError("Expected Future or None, got %r" % result)
except Exception as e:
self._handle_request_exception(e)
result = self.prepare()
if is_future(result):
result = yield result
if result is not None:
raise TypeError("Expected None, got %r" % result)
if self._prepared_future is not None:
# Tell the Application we've finished with prepare()
# and are ready for the body to arrive.
self._prepared_future.set_result(None)
if self._finished:
return
if _has_stream_request_body(self.__class__):
# In streaming mode request.body is a Future that signals
# the body has been completely received. The Future has no
# result; the data has been passed to self.data_received
# instead.
try:
yield self.request.body
except iostream.StreamClosedError:
return
def _execute_method(self):
if not self._finished:
method = getattr(self, self.request.method.lower())
self._when_complete(method(*self.path_args, **self.path_kwargs),
self._execute_finish)
result = method(*self.path_args, **self.path_kwargs)
if is_future(result):
result = yield result
if result is not None:
raise TypeError("Expected None, got %r" % result)
if self._auto_finish and not self._finished:
self.finish()
except Exception as e:
self._handle_request_exception(e)
if (self._prepared_future is not None and
not self._prepared_future.done()):
# In case we failed before setting _prepared_future, do it
# now (to unblock the HTTP server). Note that this is not
# in a finally block to avoid GC issues prior to Python 3.4.
self._prepared_future.set_result(None)
def _execute_finish(self):
if self._auto_finish and not self._finished:
self.finish()
def data_received(self, chunk):
"""Implement this method to handle streamed request data.
def _generate_headers(self):
reason = self._reason
lines = [utf8(self.request.version + " " +
str(self._status_code) +
" " + reason)]
lines.extend([utf8(n) + b": " + utf8(v) for n, v in self._headers.get_all()])
if hasattr(self, "_new_cookie"):
for cookie in self._new_cookie.values():
lines.append(utf8("Set-Cookie: " + cookie.OutputString(None)))
return b"\r\n".join(lines) + b"\r\n\r\n"
Requires the `.stream_request_body` decorator.
"""
raise NotImplementedError()
def _log(self):
"""Logs the current request.
@ -1495,8 +1482,6 @@ def asynchronous(method):
from tornado.ioloop import IOLoop
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
if self.application._wsgi:
raise Exception("@asynchronous is not supported for WSGI apps")
self._auto_finish = False
with stack_context.ExceptionStackContext(
self._stack_context_handle_exception):
@ -1523,6 +1508,40 @@ def asynchronous(method):
return wrapper
def stream_request_body(cls):
"""Apply to `RequestHandler` subclasses to enable streaming body support.
This decorator implies the following changes:
* `.HTTPServerRequest.body` is undefined, and body arguments will not
be included in `RequestHandler.get_argument`.
* `RequestHandler.prepare` is called when the request headers have been
read instead of after the entire body has been read.
* The subclass must define a method ``data_received(self, data):``, which
will be called zero or more times as data is available. Note that
if the request has an empty body, ``data_received`` may not be called.
* ``prepare`` and ``data_received`` may return Futures (such as via
``@gen.coroutine``, in which case the next method will not be called
until those futures have completed.
* The regular HTTP method (``post``, ``put``, etc) will be called after
the entire body has been read.
There is a subtle interaction between ``data_received`` and asynchronous
``prepare``: The first call to ``data_recieved`` may occur at any point
after the call to ``prepare`` has returned *or yielded*.
"""
if not issubclass(cls, RequestHandler):
raise TypeError("expected subclass of RequestHandler, got %r", cls)
cls._stream_request_body = True
return cls
def _has_stream_request_body(cls):
if not issubclass(cls, RequestHandler):
raise TypeError("expected subclass of RequestHandler, got %r", cls)
return getattr(cls, '_stream_request_body', False)
def removeslash(method):
"""Use this decorator to remove trailing slashes from the request path.
@ -1567,7 +1586,7 @@ def addslash(method):
return wrapper
class Application(object):
class Application(httputil.HTTPServerConnectionDelegate):
"""A collection of request handlers that make up a web application.
Instances of this class are callable and can be passed directly to
@ -1619,12 +1638,11 @@ class Application(object):
"""
def __init__(self, handlers=None, default_host="", transforms=None,
wsgi=False, **settings):
**settings):
if transforms is None:
self.transforms = []
if settings.get("gzip"):
self.transforms.append(GZipContentEncoding)
self.transforms.append(ChunkedTransferEncoding)
else:
self.transforms = transforms
self.handlers = []
@ -1636,7 +1654,6 @@ class Application(object):
'Template': TemplateModule,
}
self.ui_methods = {}
self._wsgi = wsgi
self._load_ui_modules(settings.get("ui_modules", {}))
self._load_ui_methods(settings.get("ui_methods", {}))
if self.settings.get("static_path"):
@ -1662,7 +1679,7 @@ class Application(object):
self.settings.setdefault('serve_traceback', True)
# Automatically reload modified modules
if self.settings.get('autoreload') and not wsgi:
if self.settings.get('autoreload'):
from tornado import autoreload
autoreload.start()
@ -1762,64 +1779,15 @@ class Application(object):
except TypeError:
pass
def start_request(self, connection):
# Modern HTTPServer interface
return _RequestDispatcher(self, connection)
def __call__(self, request):
"""Called by HTTPServer to execute the request."""
transforms = [t(request) for t in self.transforms]
handler = None
args = []
kwargs = {}
handlers = self._get_host_handlers(request)
if not handlers:
handler = RedirectHandler(
self, request, url="http://" + self.default_host + "/")
else:
for spec in handlers:
match = spec.regex.match(request.path)
if match:
handler = spec.handler_class(self, request, **spec.kwargs)
if spec.regex.groups:
# None-safe wrapper around url_unescape to handle
# unmatched optional groups correctly
def unquote(s):
if s is None:
return s
return escape.url_unescape(s, encoding=None,
plus=False)
# Pass matched groups to the handler. Since
# match.groups() includes both named and unnamed groups,
# we want to use either groups or groupdict but not both.
# Note that args are passed as bytes so the handler can
# decide what encoding to use.
if spec.regex.groupindex:
kwargs = dict(
(str(k), unquote(v))
for (k, v) in match.groupdict().items())
else:
args = [unquote(s) for s in match.groups()]
break
if not handler:
if self.settings.get('default_handler_class'):
handler_class = self.settings['default_handler_class']
handler_args = self.settings.get(
'default_handler_args', {})
else:
handler_class = ErrorHandler
handler_args = dict(status_code=404)
handler = handler_class(self, request, **handler_args)
# If template cache is disabled (usually in the debug mode),
# re-compile templates and reload static files on every
# request so you don't need to restart to see changes
if not self.settings.get("compiled_template_cache", True):
with RequestHandler._template_loader_lock:
for loader in RequestHandler._template_loaders.values():
loader.reset()
if not self.settings.get('static_hash_cache', True):
StaticFileHandler.reset()
handler._execute(transforms, *args, **kwargs)
return handler
# Legacy HTTPServer interface
dispatcher = _RequestDispatcher(self, None)
dispatcher.set_request(request)
return dispatcher.execute()
def reverse_url(self, name, *args):
"""Returns a URL path for handler named ``name``
@ -1856,6 +1824,113 @@ class Application(object):
handler._request_summary(), request_time)
class _RequestDispatcher(httputil.HTTPMessageDelegate):
def __init__(self, application, connection):
self.application = application
self.connection = connection
self.request = None
self.chunks = []
self.handler_class = None
self.handler_kwargs = None
self.path_args = []
self.path_kwargs = {}
def headers_received(self, start_line, headers):
self.set_request(httputil.HTTPServerRequest(
connection=self.connection, start_line=start_line, headers=headers))
if self.stream_request_body:
self.request.body = Future()
return self.execute()
def set_request(self, request):
self.request = request
self._find_handler()
self.stream_request_body = _has_stream_request_body(self.handler_class)
def _find_handler(self):
# Identify the handler to use as soon as we have the request.
# Save url path arguments for later.
app = self.application
handlers = app._get_host_handlers(self.request)
if not handlers:
self.handler_class = RedirectHandler
self.handler_kwargs = dict(url="http://" + app.default_host + "/")
return
for spec in handlers:
match = spec.regex.match(self.request.path)
if match:
self.handler_class = spec.handler_class
self.handler_kwargs = spec.kwargs
if spec.regex.groups:
# Pass matched groups to the handler. Since
# match.groups() includes both named and
# unnamed groups, we want to use either groups
# or groupdict but not both.
if spec.regex.groupindex:
self.path_kwargs = dict(
(str(k), _unquote_or_none(v))
for (k, v) in match.groupdict().items())
else:
self.path_args = [_unquote_or_none(s)
for s in match.groups()]
return
if app.settings.get('default_handler_class'):
self.handler_class = app.settings['default_handler_class']
self.handler_kwargs = app.settings.get(
'default_handler_args', {})
else:
self.handler_class = ErrorHandler
self.handler_kwargs = dict(status_code=404)
def data_received(self, data):
if self.stream_request_body:
return self.handler.data_received(data)
else:
self.chunks.append(data)
def finish(self):
if self.stream_request_body:
self.request.body.set_result(None)
else:
self.request.body = b''.join(self.chunks)
self.request._parse_body()
self.execute()
def on_connection_close(self):
if self.stream_request_body:
self.handler.on_connection_close()
else:
self.chunks = None
def execute(self):
# If template cache is disabled (usually in the debug mode),
# re-compile templates and reload static files on every
# request so you don't need to restart to see changes
if not self.application.settings.get("compiled_template_cache", True):
with RequestHandler._template_loader_lock:
for loader in RequestHandler._template_loaders.values():
loader.reset()
if not self.application.settings.get('static_hash_cache', True):
StaticFileHandler.reset()
self.handler = self.handler_class(self.application, self.request,
**self.handler_kwargs)
transforms = [t(self.request) for t in self.application.transforms]
if self.stream_request_body:
self.handler._prepared_future = Future()
# Note that if an exception escapes handler._execute it will be
# trapped in the Future it returns (which we are ignoring here).
# However, that shouldn't happen because _execute has a blanket
# except handler, and we cannot easily access the IOLoop here to
# call add_future.
self.handler._execute(transforms, *self.path_args, **self.path_kwargs)
# If we are streaming the request body, then execute() is finished
# when the handler has prepared to receive the body. If not,
# it doesn't matter when execute() finishes (so we return None)
return self.handler._prepared_future
class HTTPError(Exception):
"""An exception that will turn into an HTTP error response.
@ -2014,8 +2089,9 @@ class StaticFileHandler(RequestHandler):
cls._static_hashes = {}
def head(self, path):
self.get(path, include_body=False)
return self.get(path, include_body=False)
@gen.coroutine
def get(self, path, include_body=True):
# Set up our path instance variables.
self.path = self.parse_url_path(path)
@ -2040,9 +2116,9 @@ class StaticFileHandler(RequestHandler):
# the request will be treated as if the header didn't exist.
request_range = httputil._parse_request_range(range_header)
size = self.get_content_size()
if request_range:
start, end = request_range
size = self.get_content_size()
if (start is not None and start >= size) or end == 0:
# As per RFC 2616 14.35.1, a range is not satisfiable only: if
# the first requested byte is equal to or greater than the
@ -2067,18 +2143,26 @@ class StaticFileHandler(RequestHandler):
httputil._get_content_range(start, end, size))
else:
start = end = None
content = self.get_content(self.absolute_path, start, end)
if isinstance(content, bytes_type):
content = [content]
content_length = 0
for chunk in content:
if include_body:
if start is not None and end is not None:
content_length = end - start
elif end is not None:
content_length = end
elif start is not None:
content_length = size - start
else:
content_length = size
self.set_header("Content-Length", content_length)
if include_body:
content = self.get_content(self.absolute_path, start, end)
if isinstance(content, bytes_type):
content = [content]
for chunk in content:
self.write(chunk)
else:
content_length += len(chunk)
if not include_body:
yield self.flush()
else:
assert self.request.method == "HEAD"
self.set_header("Content-Length", content_length)
def compute_etag(self):
"""Sets the ``Etag`` header based on static url version.
@ -2258,10 +2342,13 @@ class StaticFileHandler(RequestHandler):
def get_content_size(self):
"""Retrieve the total size of the resource at the given path.
This method may be overridden by subclasses. It will only
be called if a partial result is requested from `get_content`
This method may be overridden by subclasses.
.. versionadded:: 3.1
.. versionchanged:: 4.0
This method is now always called, instead of only when
partial results are requested.
"""
stat_result = self._stat()
return stat_result[stat.ST_SIZE]
@ -2383,7 +2470,7 @@ class FallbackHandler(RequestHandler):
"""A `RequestHandler` that wraps another HTTP server callback.
The fallback is a callable object that accepts an
`~.httpserver.HTTPRequest`, such as an `Application` or
`~.httputil.HTTPServerRequest`, such as an `Application` or
`tornado.wsgi.WSGIContainer`. This is most useful to use both
Tornado ``RequestHandlers`` and WSGI in the same server. Typical
usage::
@ -2407,7 +2494,7 @@ class OutputTransform(object):
"""A transform modifies the result of an HTTP request (e.g., GZip encoding)
A new transform instance is created for every request. See the
ChunkedTransferEncoding example below if you want to implement a
GZipContentEncoding example below if you want to implement a
new Transform.
"""
def __init__(self, request):
@ -2424,16 +2511,24 @@ class GZipContentEncoding(OutputTransform):
"""Applies the gzip content encoding to the response.
See http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.11
.. versionchanged:: 4.0
Now compresses all mime types beginning with ``text/``, instead
of just a whitelist. (the whitelist is still used for certain
non-text mime types).
"""
CONTENT_TYPES = set([
"text/plain", "text/html", "text/css", "text/xml", "application/javascript",
"application/x-javascript", "application/xml", "application/atom+xml",
"text/javascript", "application/json", "application/xhtml+xml"])
# Whitelist of compressible mime types (in addition to any types
# beginning with "text/").
CONTENT_TYPES = set(["application/javascript", "application/x-javascript",
"application/xml", "application/atom+xml",
"application/json", "application/xhtml+xml"])
MIN_LENGTH = 5
def __init__(self, request):
self._gzipping = request.supports_http_1_1() and \
"gzip" in request.headers.get("Accept-Encoding", "")
self._gzipping = "gzip" in request.headers.get("Accept-Encoding", "")
def _compressible_type(self, ctype):
return ctype.startswith('text/') or ctype in self.CONTENT_TYPES
def transform_first_chunk(self, status_code, headers, chunk, finishing):
if 'Vary' in headers:
@ -2442,7 +2537,7 @@ class GZipContentEncoding(OutputTransform):
headers['Vary'] = b'Accept-Encoding'
if self._gzipping:
ctype = _unicode(headers.get("Content-Type", "")).split(";")[0]
self._gzipping = (ctype in self.CONTENT_TYPES) and \
self._gzipping = self._compressible_type(ctype) and \
(not finishing or len(chunk) >= self.MIN_LENGTH) and \
(finishing or "Content-Length" not in headers) and \
("Content-Encoding" not in headers)
@ -2468,42 +2563,16 @@ class GZipContentEncoding(OutputTransform):
return chunk
class ChunkedTransferEncoding(OutputTransform):
"""Applies the chunked transfer encoding to the response.
See http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.6.1
"""
def __init__(self, request):
self._chunking = request.supports_http_1_1()
def transform_first_chunk(self, status_code, headers, chunk, finishing):
# 304 responses have no body (not even a zero-length body), and so
# should not have either Content-Length or Transfer-Encoding headers.
if self._chunking and status_code != 304:
# No need to chunk the output if a Content-Length is specified
if "Content-Length" in headers or "Transfer-Encoding" in headers:
self._chunking = False
else:
headers["Transfer-Encoding"] = "chunked"
chunk = self.transform_chunk(chunk, finishing)
return status_code, headers, chunk
def transform_chunk(self, block, finishing):
if self._chunking:
# Don't write out empty chunks because that means END-OF-STREAM
# with chunked encoding
if block:
block = utf8("%x" % len(block)) + b"\r\n" + block + b"\r\n"
if finishing:
block += b"0\r\n\r\n"
return block
def authenticated(method):
"""Decorate methods with this to require that the user be logged in.
If the user is not logged in, they will be redirected to the configured
`login url <RequestHandler.get_login_url>`.
If you configure a login url with a query parameter, Tornado will
assume you know what you're doing and use it as-is. If not, it
will add a `next` parameter so the login page knows where to send
you once you're logged in.
"""
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
@ -2810,7 +2879,8 @@ def create_signed_value(secret, name, value, version=None, clock=None):
# A leading version number in decimal with no leading zeros, followed by a pipe.
_signed_value_version_re = re.compile(br"^([1-9][0-9]*)\|(.*)$")
def decode_signed_value(secret, name, value, max_age_days=31, clock=None,min_version=None):
def decode_signed_value(secret, name, value, max_age_days=31, clock=None, min_version=None):
if clock is None:
clock = time.time
if min_version is None:
@ -2850,6 +2920,7 @@ def decode_signed_value(secret, name, value, max_age_days=31, clock=None,min_ver
else:
return None
def _decode_signed_value_v1(secret, name, value, max_age_days, clock):
parts = utf8(value).split(b"|")
if len(parts) != 3:
@ -2886,9 +2957,9 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
field_value = rest[:n]
# In python 3, indexing bytes returns small integers; we must
# use a slice to get a byte string as in python 2.
if rest[n:n+1] != b'|':
if rest[n:n + 1] != b'|':
raise ValueError("malformed v2 signed value field")
rest = rest[n+1:]
rest = rest[n + 1:]
return field_value, rest
rest = value[2:] # remove version number
try:
@ -2921,7 +2992,20 @@ def _create_signature_v1(secret, *parts):
hash.update(utf8(part))
return utf8(hash.hexdigest())
def _create_signature_v2(secret, s):
hash = hmac.new(utf8(secret), digestmod=hashlib.sha256)
hash.update(utf8(s))
return utf8(hash.hexdigest())
def _unquote_or_none(s):
"""None-safe wrapper around url_unescape to handle unamteched optional
groups correctly.
Note that args are passed as bytes so the handler can decide what
encoding to use.
"""
if s is None:
return s
return escape.url_unescape(s, encoding=None, plus=False)

View file

@ -31,15 +31,25 @@ import tornado.escape
import tornado.web
from tornado.concurrent import TracebackFuture
from tornado.escape import utf8, native_str
from tornado.escape import utf8, native_str, to_unicode
from tornado import httpclient, httputil
from tornado.ioloop import IOLoop
from tornado.iostream import StreamClosedError
from tornado.log import gen_log, app_log
from tornado.netutil import Resolver
from tornado import simple_httpclient
from tornado.tcpclient import TCPClient
from tornado.util import bytes_type, unicode_type, _websocket_mask
try:
from urllib.parse import urlparse # py2
except ImportError:
from urlparse import urlparse # py3
try:
xrange # py2
except NameError:
xrange = range # py3
class WebSocketError(Exception):
pass
@ -102,28 +112,20 @@ class WebSocketHandler(tornado.web.RequestHandler):
def __init__(self, application, request, **kwargs):
tornado.web.RequestHandler.__init__(self, application, request,
**kwargs)
self.stream = request.connection.stream
self.ws_connection = None
self.close_code = None
self.close_reason = None
self.stream = None
def _execute(self, transforms, *args, **kwargs):
@tornado.web.asynchronous
def get(self, *args, **kwargs):
self.open_args = args
self.open_kwargs = kwargs
# Websocket only supports GET method
if self.request.method != 'GET':
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 405 Method Not Allowed\r\n\r\n"
))
self.stream.close()
return
# Upgrade header should be present and should be equal to WebSocket
if self.request.headers.get("Upgrade", "").lower() != 'websocket':
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 400 Bad Request\r\n\r\n"
"Can \"Upgrade\" only to \"WebSocket\"."
))
self.stream.close()
self.set_status(400)
self.finish("Can \"Upgrade\" only to \"WebSocket\".")
return
# Connection header should be upgrade. Some proxy servers/load balancers
@ -131,16 +133,31 @@ class WebSocketHandler(tornado.web.RequestHandler):
headers = self.request.headers
connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
if 'upgrade' not in connection:
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 400 Bad Request\r\n\r\n"
"\"Connection\" must be \"Upgrade\"."
))
self.stream.close()
self.set_status(400)
self.finish("\"Connection\" must be \"Upgrade\".")
return
# Handle WebSocket Origin naming convention differences
# The difference between version 8 and 13 is that in 8 the
# client sends a "Sec-Websocket-Origin" header and in 13 it's
# simply "Origin".
if "Origin" in self.request.headers:
origin = self.request.headers.get("Origin")
else:
origin = self.request.headers.get("Sec-Websocket-Origin", None)
# If there was an origin header, check to make sure it matches
# according to check_origin. When the origin is None, we assume it
# did not come from a browser and that it can be passed on.
if origin is not None and not self.check_origin(origin):
self.set_status(403)
self.finish("Cross origin websockets not allowed")
return
self.stream = self.request.connection.detach()
self.stream.set_close_callback(self.on_connection_close)
if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
self.ws_connection = WebSocketProtocol13(self)
self.ws_connection.accept_connection()
@ -154,6 +171,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
"Sec-WebSocket-Version: 8\r\n\r\n"))
self.stream.close()
def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket.
@ -214,18 +232,70 @@ class WebSocketHandler(tornado.web.RequestHandler):
pass
def on_close(self):
"""Invoked when the WebSocket is closed."""
"""Invoked when the WebSocket is closed.
If the connection was closed cleanly and a status code or reason
phrase was supplied, these values will be available as the attributes
``self.close_code`` and ``self.close_reason``.
.. versionchanged:: 4.0
Added ``close_code`` and ``close_reason`` attributes.
"""
pass
def close(self):
def close(self, code=None, reason=None):
"""Closes this Web Socket.
Once the close handshake is successful the socket will be closed.
``code`` may be a numeric status code, taken from the values
defined in `RFC 6455 section 7.4.1
<https://tools.ietf.org/html/rfc6455#section-7.4.1>`_.
``reason`` may be a textual message about why the connection is
closing. These values are made available to the client, but are
not otherwise interpreted by the websocket protocol.
The ``code`` and ``reason`` arguments are ignored in the "draft76"
protocol version.
.. versionchanged:: 4.0
Added the ``code`` and ``reason`` arguments.
"""
if self.ws_connection:
self.ws_connection.close()
self.ws_connection.close(code, reason)
self.ws_connection = None
def check_origin(self, origin):
"""Override to enable support for allowing alternate origins.
The ``origin`` argument is the value of the ``Origin`` HTTP
header, the url responsible for initiating this request. This
method is not called for clients that do not send this header;
such requests are always allowed (because all browsers that
implement WebSockets support this header, and non-browser
clients do not have the same cross-site security concerns).
Should return True to accept the request or False to reject it.
By default, rejects all requests with an origin on a host other
than this one.
This is a security protection against cross site scripting attacks on
browsers, since WebSockets are allowed to bypass the usual same-origin
policies and don't use CORS headers.
.. versionadded:: 4.0
"""
parsed_origin = urlparse(origin)
origin = parsed_origin.netloc
origin = origin.lower()
host = self.request.headers.get("Host")
# Check to see that origin matches host directly, including ports
return origin == host
def allow_draft76(self):
"""Override to enable support for the older "draft76" protocol.
@ -269,17 +339,6 @@ class WebSocketHandler(tornado.web.RequestHandler):
"""
return "wss" if self.request.protocol == "https" else "ws"
def async_callback(self, callback, *args, **kwargs):
"""Obsolete - catches exceptions from the wrapped function.
This function is normally unncecessary thanks to
`tornado.stack_context`.
"""
return self.ws_connection.async_callback(callback, *args, **kwargs)
def _not_supported(self, *args, **kwargs):
raise Exception("Method not supported for Web Sockets")
def on_connection_close(self):
if self.ws_connection:
self.ws_connection.on_connection_close()
@ -287,9 +346,17 @@ class WebSocketHandler(tornado.web.RequestHandler):
self.on_close()
def _wrap_method(method):
def _disallow_for_websocket(self, *args, **kwargs):
if self.stream is None:
method(self, *args, **kwargs)
else:
raise RuntimeError("Method not supported for Web Sockets")
return _disallow_for_websocket
for method in ["write", "redirect", "set_header", "send_error", "set_cookie",
"set_status", "flush", "finish"]:
setattr(WebSocketHandler, method, WebSocketHandler._not_supported)
setattr(WebSocketHandler, method,
_wrap_method(getattr(WebSocketHandler, method)))
class WebSocketProtocol(object):
@ -302,23 +369,17 @@ class WebSocketProtocol(object):
self.client_terminated = False
self.server_terminated = False
def async_callback(self, callback, *args, **kwargs):
"""Wrap callbacks with this if they are used on asynchronous requests.
def _run_callback(self, callback, *args, **kwargs):
"""Runs the given callback with exception handling.
Catches exceptions properly and closes this WebSocket if an exception
is uncaught.
On error, aborts the websocket connection and returns False.
"""
if args or kwargs:
callback = functools.partial(callback, *args, **kwargs)
def wrapper(*args, **kwargs):
try:
return callback(*args, **kwargs)
except Exception:
app_log.error("Uncaught exception in %s",
self.request.path, exc_info=True)
self._abort()
return wrapper
try:
callback(*args, **kwargs)
except Exception:
app_log.error("Uncaught exception in %s",
self.request.path, exc_info=True)
self._abort()
def on_connection_close(self):
self._abort()
@ -409,7 +470,8 @@ class WebSocketProtocol76(WebSocketProtocol):
def _write_response(self, challenge):
self.stream.write(challenge)
self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs)
self._run_callback(self.handler.open, *self.handler.open_args,
**self.handler.open_kwargs)
self._receive_message()
def _handle_websocket_headers(self):
@ -457,8 +519,8 @@ class WebSocketProtocol76(WebSocketProtocol):
def _on_end_delimiter(self, frame):
if not self.client_terminated:
self.async_callback(self.handler.on_message)(
frame[:-1].decode("utf-8", "replace"))
self._run_callback(self.handler.on_message,
frame[:-1].decode("utf-8", "replace"))
if not self.client_terminated:
self._receive_message()
@ -483,7 +545,7 @@ class WebSocketProtocol76(WebSocketProtocol):
"""Send ping frame."""
raise ValueError("Ping messages not supported by this version of websockets")
def close(self):
def close(self, code=None, reason=None):
"""Closes the WebSocket connection."""
if not self.server_terminated:
if not self.stream.closed():
@ -568,7 +630,8 @@ class WebSocketProtocol13(WebSocketProtocol):
"%s"
"\r\n" % (self._challenge_response(), subprotocol_header)))
self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs)
self._run_callback(self.handler.open, *self.handler.open_args,
**self.handler.open_kwargs)
self._receive_frame()
def _write_frame(self, fin, opcode, data):
@ -726,28 +789,40 @@ class WebSocketProtocol13(WebSocketProtocol):
except UnicodeDecodeError:
self._abort()
return
self.async_callback(self.handler.on_message)(decoded)
self._run_callback(self.handler.on_message, decoded)
elif opcode == 0x2:
# Binary data
self.async_callback(self.handler.on_message)(data)
self._run_callback(self.handler.on_message, decoded)
elif opcode == 0x8:
# Close
self.client_terminated = True
if len(data) >= 2:
self.handler.close_code = struct.unpack('>H', data[:2])[0]
if len(data) > 2:
self.handler.close_reason = to_unicode(data[2:])
self.close()
elif opcode == 0x9:
# Ping
self._write_frame(True, 0xA, data)
elif opcode == 0xA:
# Pong
self.async_callback(self.handler.on_pong)(data)
self._run_callback(self.handler.on_pong, data)
else:
self._abort()
def close(self):
def close(self, code=None, reason=None):
"""Closes the WebSocket connection."""
if not self.server_terminated:
if not self.stream.closed():
self._write_frame(True, 0x8, b"")
if code is None and reason is not None:
code = 1000 # "normal closure" status code
if code is None:
close_data = b''
else:
close_data = struct.pack('>H', code)
if reason is not None:
close_data += utf8(reason)
self._write_frame(True, 0x8, close_data)
self.server_terminated = True
if self.client_terminated:
if self._waiting is not None:
@ -783,18 +858,25 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
'Sec-WebSocket-Version': '13',
})
self.resolver = Resolver(io_loop=io_loop)
self.tcp_client = TCPClient(io_loop=io_loop)
super(WebSocketClientConnection, self).__init__(
io_loop, None, request, lambda: None, self._on_http_response,
104857600, self.resolver)
104857600, self.tcp_client, 65536)
def close(self):
def close(self, code=None, reason=None):
"""Closes the websocket connection.
``code`` and ``reason`` are documented under
`WebSocketHandler.close`.
.. versionadded:: 3.2
.. versionchanged:: 4.0
Added the ``code`` and ``reason`` arguments.
"""
if self.protocol is not None:
self.protocol.close()
self.protocol.close(code, reason)
self.protocol = None
def _on_close(self):
@ -810,8 +892,12 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
self.connect_future.set_exception(WebSocketError(
"Non-websocket response"))
def _handle_1xx(self, code):
assert code == 101
def headers_received(self, start_line, headers):
if start_line.code != 101:
return super(WebSocketClientConnection, self).headers_received(
start_line, headers)
self.headers = headers
assert self.headers['Upgrade'].lower() == 'websocket'
assert self.headers['Connection'].lower() == 'upgrade'
accept = WebSocketProtocol13.compute_accept_value(self.key)
@ -824,6 +910,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
self.stream = self.connection.detach()
self.stream.set_close_callback(self._on_close)
self.connect_future.set_result(self)
def write_message(self, message, binary=False):

View file

@ -20,9 +20,9 @@ WSGI is the Python standard for web servers, and allows for interoperability
between Tornado and other Python web frameworks and servers. This module
provides WSGI support in two ways:
* `WSGIApplication` is a version of `tornado.web.Application` that can run
inside a WSGI server. This is useful for running a Tornado app on another
HTTP server, such as Google App Engine. See the `WSGIApplication` class
* `WSGIAdapter` converts a `tornado.web.Application` to the WSGI application
interface. This is useful for running a Tornado app on another
HTTP server, such as Google App Engine. See the `WSGIAdapter` class
documentation for limitations that apply.
* `WSGIContainer` lets you run other WSGI applications and frameworks on the
Tornado HTTP server. For example, with this class you can mix Django
@ -32,15 +32,14 @@ provides WSGI support in two ways:
from __future__ import absolute_import, division, print_function, with_statement
import sys
import time
import copy
import tornado
from tornado.concurrent import Future
from tornado import escape
from tornado import httputil
from tornado.log import access_log
from tornado import web
from tornado.escape import native_str, parse_qs_bytes
from tornado.escape import native_str
from tornado.util import bytes_type, unicode_type
try:
@ -48,11 +47,6 @@ try:
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
try:
import Cookie # py2
except ImportError:
import http.cookies as Cookie # py3
try:
import urllib.parse as urllib_parse # py3
except ImportError:
@ -83,11 +77,84 @@ else:
class WSGIApplication(web.Application):
"""A WSGI equivalent of `tornado.web.Application`.
`WSGIApplication` is very similar to `tornado.web.Application`,
except no asynchronous methods are supported (since WSGI does not
support non-blocking requests properly). If you call
``self.flush()`` or other asynchronous methods in your request
handlers running in a `WSGIApplication`, we throw an exception.
.. deprecated:: 4.0
Use a regular `.Application` and wrap it in `WSGIAdapter` instead.
"""
def __call__(self, environ, start_response):
return WSGIAdapter(self)(environ, start_response)
# WSGI has no facilities for flow control, so just return an already-done
# Future when the interface requires it.
_dummy_future = Future()
_dummy_future.set_result(None)
class _WSGIConnection(httputil.HTTPConnection):
def __init__(self, method, start_response, context):
self.method = method
self.start_response = start_response
self.context = context
self._write_buffer = []
self._finished = False
self._expected_content_remaining = None
self._error = None
def set_close_callback(self, callback):
# WSGI has no facility for detecting a closed connection mid-request,
# so we can simply ignore the callback.
pass
def write_headers(self, start_line, headers, chunk=None, callback=None):
if self.method == 'HEAD':
self._expected_content_remaining = 0
elif 'Content-Length' in headers:
self._expected_content_remaining = int(headers['Content-Length'])
else:
self._expected_content_remaining = None
self.start_response(
'%s %s' % (start_line.code, start_line.reason),
[(native_str(k), native_str(v)) for (k, v) in headers.get_all()])
if chunk is not None:
self.write(chunk, callback)
elif callback is not None:
callback()
return _dummy_future
def write(self, chunk, callback=None):
if self._expected_content_remaining is not None:
self._expected_content_remaining -= len(chunk)
if self._expected_content_remaining < 0:
self._error = httputil.HTTPOutputError(
"Tried to write more data than Content-Length")
raise self._error
self._write_buffer.append(chunk)
if callback is not None:
callback()
return _dummy_future
def finish(self):
if (self._expected_content_remaining is not None and
self._expected_content_remaining != 0):
self._error = httputil.HTTPOutputError(
"Tried to write %d bytes less than Content-Length" %
self._expected_content_remaining)
raise self._error
self._finished = True
class _WSGIRequestContext(object):
def __init__(self, remote_ip, protocol):
self.remote_ip = remote_ip
self.protocol = protocol
def __str__(self):
return self.remote_ip
class WSGIAdapter(object):
"""Converts a `tornado.web.Application` instance into a WSGI application.
Example usage::
@ -100,121 +167,83 @@ class WSGIApplication(web.Application):
self.write("Hello, world")
if __name__ == "__main__":
application = tornado.wsgi.WSGIApplication([
application = tornado.web.Application([
(r"/", MainHandler),
])
server = wsgiref.simple_server.make_server('', 8888, application)
wsgi_app = tornado.wsgi.WSGIAdapter(application)
server = wsgiref.simple_server.make_server('', 8888, wsgi_app)
server.serve_forever()
See the `appengine demo
<https://github.com/tornadoweb/tornado/tree/master/demos/appengine>`_
<https://github.com/tornadoweb/tornado/tree/stable/demos/appengine>`_
for an example of using this module to run a Tornado app on Google
App Engine.
WSGI applications use the same `.RequestHandler` class, but not
``@asynchronous`` methods or ``flush()``. This means that it is
not possible to use `.AsyncHTTPClient`, or the `tornado.auth` or
`tornado.websocket` modules.
In WSGI mode asynchronous methods are not supported. This means
that it is not possible to use `.AsyncHTTPClient`, or the
`tornado.auth` or `tornado.websocket` modules.
.. versionadded:: 4.0
"""
def __init__(self, handlers=None, default_host="", **settings):
web.Application.__init__(self, handlers, default_host, transforms=[],
wsgi=True, **settings)
def __init__(self, application):
if isinstance(application, WSGIApplication):
self.application = lambda request: web.Application.__call__(
application, request)
else:
self.application = application
def __call__(self, environ, start_response):
handler = web.Application.__call__(self, HTTPRequest(environ))
assert handler._finished
reason = handler._reason
status = str(handler._status_code) + " " + reason
headers = list(handler._headers.get_all())
if hasattr(handler, "_new_cookie"):
for cookie in handler._new_cookie.values():
headers.append(("Set-Cookie", cookie.OutputString(None)))
start_response(status,
[(native_str(k), native_str(v)) for (k, v) in headers])
return handler._write_buffer
class HTTPRequest(object):
"""Mimics `tornado.httpserver.HTTPRequest` for WSGI applications."""
def __init__(self, environ):
"""Parses the given WSGI environment to construct the request."""
self.method = environ["REQUEST_METHOD"]
self.path = urllib_parse.quote(from_wsgi_str(environ.get("SCRIPT_NAME", "")))
self.path += urllib_parse.quote(from_wsgi_str(environ.get("PATH_INFO", "")))
self.uri = self.path
self.arguments = {}
self.query_arguments = {}
self.body_arguments = {}
self.query = environ.get("QUERY_STRING", "")
if self.query:
self.uri += "?" + self.query
self.arguments = parse_qs_bytes(native_str(self.query),
keep_blank_values=True)
self.query_arguments = copy.deepcopy(self.arguments)
self.version = "HTTP/1.1"
self.headers = httputil.HTTPHeaders()
method = environ["REQUEST_METHOD"]
uri = urllib_parse.quote(from_wsgi_str(environ.get("SCRIPT_NAME", "")))
uri += urllib_parse.quote(from_wsgi_str(environ.get("PATH_INFO", "")))
if environ.get("QUERY_STRING"):
uri += "?" + environ["QUERY_STRING"]
headers = httputil.HTTPHeaders()
if environ.get("CONTENT_TYPE"):
self.headers["Content-Type"] = environ["CONTENT_TYPE"]
headers["Content-Type"] = environ["CONTENT_TYPE"]
if environ.get("CONTENT_LENGTH"):
self.headers["Content-Length"] = environ["CONTENT_LENGTH"]
headers["Content-Length"] = environ["CONTENT_LENGTH"]
for key in environ:
if key.startswith("HTTP_"):
self.headers[key[5:].replace("_", "-")] = environ[key]
if self.headers.get("Content-Length"):
self.body = environ["wsgi.input"].read(
int(self.headers["Content-Length"]))
headers[key[5:].replace("_", "-")] = environ[key]
if headers.get("Content-Length"):
body = environ["wsgi.input"].read(
int(headers["Content-Length"]))
else:
self.body = ""
self.protocol = environ["wsgi.url_scheme"]
self.remote_ip = environ.get("REMOTE_ADDR", "")
body = ""
protocol = environ["wsgi.url_scheme"]
remote_ip = environ.get("REMOTE_ADDR", "")
if environ.get("HTTP_HOST"):
self.host = environ["HTTP_HOST"]
host = environ["HTTP_HOST"]
else:
self.host = environ["SERVER_NAME"]
# Parse request body
self.files = {}
httputil.parse_body_arguments(self.headers.get("Content-Type", ""),
self.body, self.body_arguments, self.files)
for k, v in self.body_arguments.items():
self.arguments.setdefault(k, []).extend(v)
self._start_time = time.time()
self._finish_time = None
def supports_http_1_1(self):
"""Returns True if this request supports HTTP/1.1 semantics"""
return self.version == "HTTP/1.1"
@property
def cookies(self):
"""A dictionary of Cookie.Morsel objects."""
if not hasattr(self, "_cookies"):
self._cookies = Cookie.SimpleCookie()
if "Cookie" in self.headers:
try:
self._cookies.load(
native_str(self.headers["Cookie"]))
except Exception:
self._cookies = None
return self._cookies
def full_url(self):
"""Reconstructs the full URL for this request."""
return self.protocol + "://" + self.host + self.uri
def request_time(self):
"""Returns the amount of time it took for this request to execute."""
if self._finish_time is None:
return time.time() - self._start_time
else:
return self._finish_time - self._start_time
host = environ["SERVER_NAME"]
connection = _WSGIConnection(method, start_response,
_WSGIRequestContext(remote_ip, protocol))
request = httputil.HTTPServerRequest(
method, uri, "HTTP/1.1", headers=headers, body=body,
host=host, connection=connection)
request._parse_body()
self.application(request)
if connection._error:
raise connection._error
if not connection._finished:
raise Exception("request did not finish synchronously")
return connection._write_buffer
class WSGIContainer(object):
r"""Makes a WSGI-compatible function runnable on Tornado's HTTP server.
.. warning::
WSGI is a *synchronous* interface, while Tornado's concurrency model
is based on single-threaded asynchronous execution. This means that
running a WSGI app with Tornado's `WSGIContainer` is *less scalable*
than running the same app in a multi-threaded WSGI server like
``gunicorn`` or ``uwsgi``. Use `WSGIContainer` only when there are
benefits to combining Tornado and WSGI in the same process that
outweigh the reduced scalability.
Wrap a WSGI function in a `WSGIContainer` and pass it to `.HTTPServer` to
run it. For example::
@ -281,7 +310,7 @@ class WSGIContainer(object):
@staticmethod
def environ(request):
"""Converts a `tornado.httpserver.HTTPRequest` to a WSGI environment.
"""Converts a `tornado.httputil.HTTPServerRequest` to a WSGI environment.
"""
hostport = request.host.split(":")
if len(hostport) == 2:
@ -327,3 +356,6 @@ class WSGIContainer(object):
summary = request.method + " " + request.uri + " (" + \
request.remote_ip + ")"
log_method("%d %s %.2fms", status_code, summary, request_time)
HTTPRequest = httputil.HTTPServerRequest