Update Tornado webserver to 4.1dev1 from 4.0b1 and add the certifi lib dependency.

This commit is contained in:
JackDandy 2014-10-14 05:24:01 +01:00
parent 0333a9aa3e
commit 2d18f5b8ab
53 changed files with 6628 additions and 4202 deletions

1
lib/certifi/__init__.py Normal file
View file

@ -0,0 +1 @@
from .core import where

2
lib/certifi/__main__.py Normal file
View file

@ -0,0 +1,2 @@
from certifi import where
print(where())

5134
lib/certifi/cacert.pem Normal file

File diff suppressed because it is too large Load diff

19
lib/certifi/core.py Normal file
View file

@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
certifi.py
~~~~~~~~~~
This module returns the installation location of cacert.pem.
"""
import os
def where():
f = os.path.split(__file__)[0]
return os.path.join(f, 'cacert.pem')
if __name__ == '__main__':
print(where())

View file

@ -25,5 +25,5 @@ from __future__ import absolute_import, division, print_function, with_statement
# is zero for an official release, positive for a development branch,
# or negative for a release candidate or beta (after the base version
# number has been incremented)
version = "4.0b1"
version_info = (4, 0, 0, -99)
version = "4.1.dev1"
version_info = (4, 1, 0, -100)

View file

@ -76,7 +76,7 @@ from tornado import escape
from tornado.httputil import url_concat
from tornado.log import gen_log
from tornado.stack_context import ExceptionStackContext
from tornado.util import bytes_type, u, unicode_type, ArgReplacer
from tornado.util import u, unicode_type, ArgReplacer
try:
import urlparse # py2
@ -333,7 +333,7 @@ class OAuthMixin(object):
The ``callback_uri`` may be omitted if you have previously
registered a callback URI with the third-party service. For
some sevices (including Friendfeed), you must use a
some services (including Friendfeed), you must use a
previously-registered callback URI and cannot specify a
callback via this method.
@ -1112,7 +1112,7 @@ class FacebookMixin(object):
args["cancel_url"] = urlparse.urljoin(
self.request.full_url(), cancel_uri)
if extended_permissions:
if isinstance(extended_permissions, (unicode_type, bytes_type)):
if isinstance(extended_permissions, (unicode_type, bytes)):
extended_permissions = [extended_permissions]
args["req_perms"] = ",".join(extended_permissions)
self.redirect("http://www.facebook.com/login.php?" +

File diff suppressed because it is too large Load diff

View file

@ -29,6 +29,7 @@ import sys
from tornado.stack_context import ExceptionStackContext, wrap
from tornado.util import raise_exc_info, ArgReplacer
from tornado.log import app_log
try:
from concurrent import futures
@ -173,8 +174,11 @@ class Future(object):
def _set_done(self):
self._done = True
for cb in self._callbacks:
# TODO: error handling
try:
cb(self)
except Exception:
app_log.exception('exception calling callback %r for %r',
cb, self)
self._callbacks = None
TracebackFuture = Future

View file

@ -19,10 +19,12 @@
from __future__ import absolute_import, division, print_function, with_statement
import collections
import functools
import logging
import pycurl
import threading
import time
from io import BytesIO
from tornado import httputil
from tornado import ioloop
@ -31,12 +33,6 @@ from tornado import stack_context
from tornado.escape import utf8, native_str
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main
from tornado.util import bytes_type
try:
from io import BytesIO # py3
except ImportError:
from cStringIO import StringIO as BytesIO # py2
class CurlAsyncHTTPClient(AsyncHTTPClient):
@ -45,7 +41,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
self._multi = pycurl.CurlMulti()
self._multi.setopt(pycurl.M_TIMERFUNCTION, self._set_timeout)
self._multi.setopt(pycurl.M_SOCKETFUNCTION, self._handle_socket)
self._curls = [_curl_create() for i in range(max_clients)]
self._curls = [self._curl_create() for i in range(max_clients)]
self._free_list = self._curls[:]
self._requests = collections.deque()
self._fds = {}
@ -211,7 +207,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
"callback": callback,
"curl_start_time": time.time(),
}
_curl_setup_request(curl, request, curl.info["buffer"],
self._curl_setup_request(curl, request, curl.info["buffer"],
curl.info["headers"])
self._multi.add_handle(curl)
@ -259,22 +255,14 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
def handle_callback_exception(self, callback):
self.io_loop.handle_callback_exception(callback)
class CurlError(HTTPError):
def __init__(self, errno, message):
HTTPError.__init__(self, 599, message)
self.errno = errno
def _curl_create():
def _curl_create(self):
curl = pycurl.Curl()
if gen_log.isEnabledFor(logging.DEBUG):
curl.setopt(pycurl.VERBOSE, 1)
curl.setopt(pycurl.DEBUGFUNCTION, _curl_debug)
curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug)
return curl
def _curl_setup_request(curl, request, buffer, headers):
def _curl_setup_request(self, curl, request, buffer, headers):
curl.setopt(pycurl.URL, native_str(request.url))
# libcurl's magic "Expect: 100-continue" behavior causes delays
@ -292,26 +280,19 @@ def _curl_setup_request(curl, request, buffer, headers):
if "Pragma" not in request.headers:
request.headers["Pragma"] = ""
# Request headers may be either a regular dict or HTTPHeaders object
if isinstance(request.headers, httputil.HTTPHeaders):
curl.setopt(pycurl.HTTPHEADER,
[native_str("%s: %s" % i) for i in request.headers.get_all()])
else:
curl.setopt(pycurl.HTTPHEADER,
[native_str("%s: %s" % i) for i in request.headers.items()])
["%s: %s" % (native_str(k), native_str(v))
for k, v in request.headers.get_all()])
if request.header_callback:
curl.setopt(pycurl.HEADERFUNCTION,
lambda line: request.header_callback(native_str(line)))
else:
curl.setopt(pycurl.HEADERFUNCTION,
lambda line: _curl_header_callback(headers,
native_str(line)))
functools.partial(self._curl_header_callback,
headers, request.header_callback))
if request.streaming_callback:
write_function = request.streaming_callback
write_function = lambda chunk: self.io_loop.add_callback(
request.streaming_callback, chunk)
else:
write_function = buffer.write
if bytes_type is str: # py2
if bytes is str: # py2
curl.setopt(pycurl.WRITEFUNCTION, write_function)
else: # py3
# Upstream pycurl doesn't support py3, but ubuntu 12.10 includes
@ -332,7 +313,7 @@ def _curl_setup_request(curl, request, buffer, headers):
curl.setopt(pycurl.USERAGENT, "Mozilla/5.0 (compatible; pycurl)")
if request.network_interface:
curl.setopt(pycurl.INTERFACE, request.network_interface)
if request.use_gzip:
if request.decompress_response:
curl.setopt(pycurl.ENCODING, "gzip,deflate")
else:
curl.setopt(pycurl.ENCODING, "none")
@ -390,25 +371,26 @@ def _curl_setup_request(curl, request, buffer, headers):
raise KeyError('unknown method ' + request.method)
# Handle curl's cryptic options for every individual HTTP method
if request.method in ("POST", "PUT"):
if request.method == "GET":
if request.body is not None:
raise ValueError('Body must be None for GET request')
elif request.method in ("POST", "PUT") or request.body:
if request.body is None:
raise AssertionError(
'Body must not be empty for "%s" request'
raise ValueError(
'Body must not be None for "%s" request'
% request.method)
request_buffer = BytesIO(utf8(request.body))
curl.setopt(pycurl.READFUNCTION, request_buffer.read)
if request.method == "POST":
def ioctl(cmd):
if cmd == curl.IOCMD_RESTARTREAD:
request_buffer.seek(0)
curl.setopt(pycurl.READFUNCTION, request_buffer.read)
curl.setopt(pycurl.IOCTLFUNCTION, ioctl)
if request.method == "POST":
curl.setopt(pycurl.POSTFIELDSIZE, len(request.body))
else:
curl.setopt(pycurl.UPLOAD, True)
curl.setopt(pycurl.INFILESIZE, len(request.body))
elif request.method == "GET":
if request.body is not None:
raise AssertionError('Body must be empty for GET request')
if request.auth_username is not None:
userpwd = "%s:%s" % (request.auth_username, request.auth_password or '')
@ -446,8 +428,10 @@ def _curl_setup_request(curl, request, buffer, headers):
if request.prepare_curl_callback is not None:
request.prepare_curl_callback(curl)
def _curl_header_callback(headers, header_line):
def _curl_header_callback(self, headers, header_callback, header_line):
header_line = native_str(header_line)
if header_callback is not None:
self.io_loop.add_callback(header_callback, header_line)
# header_line as returned by curl includes the end-of-line characters.
header_line = header_line.strip()
if header_line.startswith("HTTP/"):
@ -461,8 +445,7 @@ def _curl_header_callback(headers, header_line):
return
headers.parse_line(header_line)
def _curl_debug(debug_type, debug_msg):
def _curl_debug(self, debug_type, debug_msg):
debug_types = ('I', '<', '>', '<', '>')
if debug_type == 0:
gen_log.debug('%s', debug_msg.strip())
@ -472,6 +455,13 @@ def _curl_debug(debug_type, debug_msg):
elif debug_type == 4:
gen_log.debug('%s %r', debug_types[debug_type], debug_msg)
class CurlError(HTTPError):
def __init__(self, errno, message):
HTTPError.__init__(self, 599, message)
self.errno = errno
if __name__ == "__main__":
AsyncHTTPClient.configure(CurlAsyncHTTPClient)
main()

View file

@ -25,7 +25,7 @@ from __future__ import absolute_import, division, print_function, with_statement
import re
import sys
from tornado.util import bytes_type, unicode_type, basestring_type, u
from tornado.util import unicode_type, basestring_type, u
try:
from urllib.parse import parse_qs as _parse_qs # py3
@ -187,7 +187,7 @@ else:
return encoded
_UTF8_TYPES = (bytes_type, type(None))
_UTF8_TYPES = (bytes, type(None))
def utf8(value):
@ -215,7 +215,7 @@ def to_unicode(value):
"""
if isinstance(value, _TO_UNICODE_TYPES):
return value
if not isinstance(value, bytes_type):
if not isinstance(value, bytes):
raise TypeError(
"Expected bytes, unicode, or None; got %r" % type(value)
)
@ -246,7 +246,7 @@ def to_basestring(value):
"""
if isinstance(value, _BASESTRING_TYPES):
return value
if not isinstance(value, bytes_type):
if not isinstance(value, bytes):
raise TypeError(
"Expected bytes, unicode, or None; got %r" % type(value)
)
@ -264,7 +264,7 @@ def recursive_unicode(obj):
return list(recursive_unicode(i) for i in obj)
elif isinstance(obj, tuple):
return tuple(recursive_unicode(i) for i in obj)
elif isinstance(obj, bytes_type):
elif isinstance(obj, bytes):
return to_unicode(obj)
else:
return obj

View file

@ -29,16 +29,7 @@ could be written with ``gen`` as::
Most asynchronous functions in Tornado return a `.Future`;
yielding this object returns its `~.Future.result`.
For functions that do not return ``Futures``, `Task` works with any
function that takes a ``callback`` keyword argument (most Tornado functions
can be used in either style, although the ``Future`` style is preferred
since it is both shorter and provides better exception handling)::
@gen.coroutine
def get(self):
yield gen.Task(AsyncHTTPClient().fetch, "http://example.com")
You can also yield a list or dict of ``Futures`` and/or ``Tasks``, which will be
You can also yield a list or dict of ``Futures``, which will be
started at the same time and run in parallel; a list or dict of results will
be returned when they are all finished::
@ -54,30 +45,6 @@ be returned when they are all finished::
.. versionchanged:: 3.2
Dict support added.
For more complicated interfaces, `Task` can be split into two parts:
`Callback` and `Wait`::
class GenAsyncHandler2(RequestHandler):
@gen.coroutine
def get(self):
http_client = AsyncHTTPClient()
http_client.fetch("http://example.com",
callback=(yield gen.Callback("key")))
response = yield gen.Wait("key")
do_something_with_response(response)
self.render("template.html")
The ``key`` argument to `Callback` and `Wait` allows for multiple
asynchronous operations to be started at different times and proceed
in parallel: yield several callbacks with different keys, then wait
for them once all the async operations have started.
The result of a `Wait` or `Task` yield expression depends on how the callback
was run. If it was called with no arguments, the result is ``None``. If
it was called with one argument, the result is that argument. If it was
called with more than one argument or any keyword arguments, the result
is an `Arguments` object, which is a named tuple ``(args, kwargs)``.
"""
from __future__ import absolute_import, division, print_function, with_statement
@ -142,7 +109,10 @@ def engine(func):
raise ReturnValueIgnoredError(
"@gen.engine functions cannot return values: %r" %
(future.result(),))
future.add_done_callback(final_callback)
# The engine interface doesn't give us any way to return
# errors but to raise them into the stack context.
# Save the stack context here to use when the Future has resolved.
future.add_done_callback(stack_context.wrap(final_callback))
return wrapper
@ -169,6 +139,17 @@ def coroutine(func, replace_callback=True):
From the caller's perspective, ``@gen.coroutine`` is similar to
the combination of ``@return_future`` and ``@gen.engine``.
.. warning::
When exceptions occur inside a coroutine, the exception
information will be stored in the `.Future` object. You must
examine the result of the `.Future` object, or the exception
may go unnoticed by your code. This means yielding the function
if called from another coroutine, using something like
`.IOLoop.run_sync` for top-level calls, or passing the `.Future`
to `.IOLoop.add_future`.
"""
return _make_coroutine_wrapper(func, replace_callback=True)
@ -218,7 +199,18 @@ def _make_coroutine_wrapper(func, replace_callback):
future.set_exc_info(sys.exc_info())
else:
Runner(result, future, yielded)
try:
return future
finally:
# Subtle memory optimization: if next() raised an exception,
# the future's exc_info contains a traceback which
# includes this stack frame. This creates a cycle,
# which will be collected at the next full GC but has
# been shown to greatly increase memory usage of
# benchmarks (relative to the refcount-based scheme
# used in the absence of cycles). We can avoid the
# cycle by clearing the local variable after we return it.
future = None
future.set_result(result)
return future
return wrapper
@ -252,8 +244,8 @@ class Return(Exception):
class YieldPoint(object):
"""Base class for objects that may be yielded from the generator.
Applications do not normally need to use this class, but it may be
subclassed to provide additional yielding behavior.
.. deprecated:: 4.0
Use `Futures <.Future>` instead.
"""
def start(self, runner):
"""Called by the runner after the generator has yielded.
@ -289,6 +281,9 @@ class Callback(YieldPoint):
The callback may be called with zero or one arguments; if an argument
is given it will be returned by `Wait`.
.. deprecated:: 4.0
Use `Futures <.Future>` instead.
"""
def __init__(self, key):
self.key = key
@ -305,7 +300,11 @@ class Callback(YieldPoint):
class Wait(YieldPoint):
"""Returns the argument passed to the result of a previous `Callback`."""
"""Returns the argument passed to the result of a previous `Callback`.
.. deprecated:: 4.0
Use `Futures <.Future>` instead.
"""
def __init__(self, key):
self.key = key
@ -326,6 +325,9 @@ class WaitAll(YieldPoint):
a list of results in the same order.
`WaitAll` is equivalent to yielding a list of `Wait` objects.
.. deprecated:: 4.0
Use `Futures <.Future>` instead.
"""
def __init__(self, keys):
self.keys = keys
@ -341,20 +343,12 @@ class WaitAll(YieldPoint):
def Task(func, *args, **kwargs):
"""Runs a single asynchronous operation.
"""Adapts a callback-based asynchronous function for use in coroutines.
Takes a function (and optional additional arguments) and runs it with
those arguments plus a ``callback`` keyword argument. The argument passed
to the callback is returned as the result of the yield expression.
A `Task` is equivalent to a `Callback`/`Wait` pair (with a unique
key generated automatically)::
result = yield gen.Task(func, args)
func(args, callback=(yield gen.Callback(key)))
result = yield gen.Wait(key)
.. versionchanged:: 4.0
``gen.Task`` is now a function that returns a `.Future`, instead of
a subclass of `YieldPoint`. It still behaves the same way when

View file

@ -21,6 +21,8 @@
from __future__ import absolute_import, division, print_function, with_statement
import re
from tornado.concurrent import Future
from tornado.escape import native_str, utf8
from tornado import gen
@ -56,7 +58,7 @@ class HTTP1ConnectionParameters(object):
"""
def __init__(self, no_keep_alive=False, chunk_size=None,
max_header_size=None, header_timeout=None, max_body_size=None,
body_timeout=None, use_gzip=False):
body_timeout=None, decompress=False):
"""
:arg bool no_keep_alive: If true, always close the connection after
one request.
@ -65,7 +67,8 @@ class HTTP1ConnectionParameters(object):
:arg float header_timeout: how long to wait for all headers (seconds)
:arg int max_body_size: maximum amount of data for body
:arg float body_timeout: how long to wait while reading body (seconds)
:arg bool use_gzip: if true, decode incoming ``Content-Encoding: gzip``
:arg bool decompress: if true, decode incoming
``Content-Encoding: gzip``
"""
self.no_keep_alive = no_keep_alive
self.chunk_size = chunk_size or 65536
@ -73,7 +76,7 @@ class HTTP1ConnectionParameters(object):
self.header_timeout = header_timeout
self.max_body_size = max_body_size
self.body_timeout = body_timeout
self.use_gzip = use_gzip
self.decompress = decompress
class HTTP1Connection(httputil.HTTPConnection):
@ -141,7 +144,7 @@ class HTTP1Connection(httputil.HTTPConnection):
Returns a `.Future` that resolves to None after the full response has
been read.
"""
if self.params.use_gzip:
if self.params.decompress:
delegate = _GzipMessageDelegate(delegate, self.params.chunk_size)
return self._read_message(delegate)
@ -190,8 +193,17 @@ class HTTP1Connection(httputil.HTTPConnection):
skip_body = True
code = start_line.code
if code == 304:
# 304 responses may include the content-length header
# but do not actually have a body.
# http://tools.ietf.org/html/rfc7230#section-3.3
skip_body = True
if code >= 100 and code < 200:
# 1xx responses should never indicate the presence of
# a body.
if ('Content-Length' in headers or
'Transfer-Encoding' in headers):
raise httputil.HTTPInputError(
"Response code %d cannot have body" % code)
# TODO: client delegates will get headers_received twice
# in the case of a 100-continue. Document or change?
yield self._read_message(delegate)
@ -200,7 +212,8 @@ class HTTP1Connection(httputil.HTTPConnection):
not self._write_finished):
self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
if not skip_body:
body_future = self._read_body(headers, delegate)
body_future = self._read_body(
start_line.code if self.is_client else 0, headers, delegate)
if body_future is not None:
if self._body_timeout is None:
yield body_future
@ -293,6 +306,8 @@ class HTTP1Connection(httputil.HTTPConnection):
self._clear_callbacks()
stream = self.stream
self.stream = None
if not self._finish_future.done():
self._finish_future.set_result(None)
return stream
def set_body_timeout(self, timeout):
@ -454,6 +469,7 @@ class HTTP1Connection(httputil.HTTPConnection):
if start_line.version == "HTTP/1.1":
return connection_header != "close"
elif ("Content-Length" in headers
or headers.get("Transfer-Encoding", "").lower() == "chunked"
or start_line.method in ("HEAD", "GET")):
return connection_header == "keep-alive"
return False
@ -470,7 +486,11 @@ class HTTP1Connection(httputil.HTTPConnection):
self._finish_future.set_result(None)
def _parse_headers(self, data):
data = native_str(data.decode('latin1'))
# The lstrip removes newlines that some implementations sometimes
# insert between messages of a reused connection. Per RFC 7230,
# we SHOULD ignore at least one empty line before the request.
# http://tools.ietf.org/html/rfc7230#section-3.5
data = native_str(data.decode('latin1')).lstrip("\r\n")
eol = data.find("\r\n")
start_line = data[:eol]
try:
@ -481,12 +501,36 @@ class HTTP1Connection(httputil.HTTPConnection):
data[eol:100])
return start_line, headers
def _read_body(self, headers, delegate):
content_length = headers.get("Content-Length")
if content_length:
content_length = int(content_length)
def _read_body(self, code, headers, delegate):
if "Content-Length" in headers:
if "," in headers["Content-Length"]:
# Proxies sometimes cause Content-Length headers to get
# duplicated. If all the values are identical then we can
# use them but if they differ it's an error.
pieces = re.split(r',\s*', headers["Content-Length"])
if any(i != pieces[0] for i in pieces):
raise httputil.HTTPInputError(
"Multiple unequal Content-Lengths: %r" %
headers["Content-Length"])
headers["Content-Length"] = pieces[0]
content_length = int(headers["Content-Length"])
if content_length > self._max_body_size:
raise httputil.HTTPInputError("Content-Length too long")
else:
content_length = None
if code == 204:
# This response code is not allowed to have a non-empty body,
# and has an implicit length of zero instead of read-until-close.
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
if ("Transfer-Encoding" in headers or
content_length not in (None, 0)):
raise httputil.HTTPInputError(
"Response with code %d should not have body" % code)
content_length = 0
if content_length is not None:
return self._read_fixed_body(content_length, delegate)
if headers.get("Transfer-Encoding") == "chunked":
return self._read_chunked_body(delegate)
@ -581,6 +625,9 @@ class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
self._delegate.data_received(tail)
return self._delegate.finish()
def on_connection_close(self):
return self._delegate.on_connection_close()
class HTTP1ServerConnection(object):
"""An HTTP/1.x server."""

View file

@ -33,6 +33,9 @@ information, see
http://curl.haxx.se/libcurl/c/curl_easy_setopt.html#CURLOPTCONNECTTIMEOUTMS
and comments in curl_httpclient.py).
To select ``curl_httpclient``, call `AsyncHTTPClient.configure` at startup::
AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
"""
from __future__ import absolute_import, division, print_function, with_statement
@ -60,7 +63,12 @@ class HTTPClient(object):
response = http_client.fetch("http://www.google.com/")
print response.body
except httpclient.HTTPError as e:
print "Error:", e
# HTTPError is raised for non-200 responses; the response
# can be found in e.response.
print("Error: " + str(e))
except Exception as e:
# Other errors are possible, such as IOError.
print("Error: " + str(e))
http_client.close()
"""
def __init__(self, async_client_class=None, **kwargs):
@ -279,7 +287,7 @@ class HTTPRequest(object):
request_timeout=20.0,
follow_redirects=True,
max_redirects=5,
use_gzip=True,
decompress_response=True,
proxy_password='',
allow_nonstandard_methods=False,
validate_cert=True)
@ -296,7 +304,7 @@ class HTTPRequest(object):
validate_cert=None, ca_certs=None,
allow_ipv6=None,
client_key=None, client_cert=None, body_producer=None,
expect_100_continue=False):
expect_100_continue=False, decompress_response=None):
r"""All parameters except ``url`` are optional.
:arg string url: URL to fetch
@ -330,7 +338,11 @@ class HTTPRequest(object):
or return the 3xx response?
:arg int max_redirects: Limit for ``follow_redirects``
:arg string user_agent: String to send as ``User-Agent`` header
:arg bool use_gzip: Request gzip encoding from the server
:arg bool decompress_response: Request a compressed response from
the server and decompress it after downloading. Default is True.
New in Tornado 4.0.
:arg bool use_gzip: Deprecated alias for ``decompress_response``
since Tornado 4.0.
:arg string network_interface: Network interface to use for request.
``curl_httpclient`` only; see note below.
:arg callable streaming_callback: If set, ``streaming_callback`` will
@ -373,7 +385,6 @@ class HTTPRequest(object):
before sending the request body. Only supported with
simple_httpclient.
.. note::
When using ``curl_httpclient`` certain options may be
@ -414,7 +425,10 @@ class HTTPRequest(object):
self.follow_redirects = follow_redirects
self.max_redirects = max_redirects
self.user_agent = user_agent
self.use_gzip = use_gzip
if decompress_response is not None:
self.decompress_response = decompress_response
else:
self.decompress_response = use_gzip
self.network_interface = network_interface
self.streaming_callback = streaming_callback
self.header_callback = header_callback

View file

@ -50,6 +50,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
import tornado.httpserver
import tornado.ioloop
from tornado import httputil
def handle_request(request):
message = "You requested %s\n" % request.uri
@ -129,13 +130,14 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
way other than `tornado.netutil.bind_sockets`.
.. versionchanged:: 4.0
Added ``gzip``, ``chunk_size``, ``max_header_size``,
Added ``decompress_request``, ``chunk_size``, ``max_header_size``,
``idle_connection_timeout``, ``body_timeout``, ``max_body_size``
arguments. Added support for `.HTTPServerConnectionDelegate`
instances as ``request_callback``.
"""
def __init__(self, request_callback, no_keep_alive=False, io_loop=None,
xheaders=False, ssl_options=None, protocol=None, gzip=False,
xheaders=False, ssl_options=None, protocol=None,
decompress_request=False,
chunk_size=None, max_header_size=None,
idle_connection_timeout=None, body_timeout=None,
max_body_size=None, max_buffer_size=None):
@ -144,7 +146,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
self.xheaders = xheaders
self.protocol = protocol
self.conn_params = HTTP1ConnectionParameters(
use_gzip=gzip,
decompress=decompress_request,
chunk_size=chunk_size,
max_header_size=max_header_size,
header_timeout=idle_connection_timeout or 3600,

View file

@ -33,7 +33,7 @@ import time
from tornado.escape import native_str, parse_qs_bytes, utf8
from tornado.log import gen_log
from tornado.util import ObjectDict, bytes_type
from tornado.util import ObjectDict
try:
import Cookie # py2
@ -335,7 +335,7 @@ class HTTPServerRequest(object):
# set remote IP and protocol
context = getattr(connection, 'context', None)
self.remote_ip = getattr(context, 'remote_ip')
self.remote_ip = getattr(context, 'remote_ip', None)
self.protocol = getattr(context, 'protocol', "http")
self.host = host or self.headers.get("Host") or "127.0.0.1"
@ -379,7 +379,7 @@ class HTTPServerRequest(object):
Use ``request.connection`` and the `.HTTPConnection` methods
to write the response.
"""
assert isinstance(chunk, bytes_type)
assert isinstance(chunk, bytes)
self.connection.write(chunk, callback=callback)
def finish(self):
@ -562,11 +562,18 @@ class HTTPConnection(object):
def url_concat(url, args):
"""Concatenate url and argument dictionary regardless of whether
"""Concatenate url and arguments regardless of whether
url has existing query parameters.
``args`` may be either a dictionary or a list of key-value pairs
(the latter allows for multiple values with the same key.
>>> url_concat("http://example.com/foo", dict(c="d"))
'http://example.com/foo?c=d'
>>> url_concat("http://example.com/foo?a=b", dict(c="d"))
'http://example.com/foo?a=b&c=d'
>>> url_concat("http://example.com/foo?a=b", [("c", "d"), ("c", "d2")])
'http://example.com/foo?a=b&c=d&c=d2'
"""
if not args:
return url
@ -803,6 +810,8 @@ def parse_response_start_line(line):
# _parseparam and _parse_header are copied and modified from python2.7's cgi.py
# The original 2.7 version of this code did not correctly support some
# combinations of semicolons and double quotes.
# It has also been modified to support valueless parameters as seen in
# websocket extension negotiations.
def _parseparam(s):
@ -836,9 +845,31 @@ def _parse_header(line):
value = value[1:-1]
value = value.replace('\\\\', '\\').replace('\\"', '"')
pdict[name] = value
else:
pdict[p] = None
return key, pdict
def _encode_header(key, pdict):
"""Inverse of _parse_header.
>>> _encode_header('permessage-deflate',
... {'client_max_window_bits': 15, 'client_no_context_takeover': None})
'permessage-deflate; client_max_window_bits=15; client_no_context_takeover'
"""
if not pdict:
return key
out = [key]
# Sort the parameters just to make it easy to test.
for k, v in sorted(pdict.items()):
if v is None:
out.append(k)
else:
# TODO: quote if necessary.
out.append('%s=%s' % (k, v))
return '; '.join(out)
def doctests():
import doctest
return doctest.DocTestSuite()

View file

@ -197,7 +197,7 @@ class IOLoop(Configurable):
An `IOLoop` automatically becomes current for its thread
when it is started, but it is sometimes useful to call
`make_current` explictly before starting the `IOLoop`,
`make_current` explicitly before starting the `IOLoop`,
so that code run at startup time can find the right
instance.
"""
@ -477,7 +477,7 @@ class IOLoop(Configurable):
.. versionadded:: 4.0
"""
self.call_at(self.time() + delay, callback, *args, **kwargs)
return self.call_at(self.time() + delay, callback, *args, **kwargs)
def call_at(self, when, callback, *args, **kwargs):
"""Runs the ``callback`` at the absolute time designated by ``when``.
@ -493,7 +493,7 @@ class IOLoop(Configurable):
.. versionadded:: 4.0
"""
self.add_timeout(when, callback, *args, **kwargs)
return self.add_timeout(when, callback, *args, **kwargs)
def remove_timeout(self, timeout):
"""Cancels a pending timeout.
@ -724,7 +724,7 @@ class PollIOLoop(IOLoop):
#
# If someone has already set a wakeup fd, we don't want to
# disturb it. This is an issue for twisted, which does its
# SIGCHILD processing in response to its own wakeup fd being
# SIGCHLD processing in response to its own wakeup fd being
# written to. As long as the wakeup fd is registered on the IOLoop,
# the loop will still wake up and everything should work.
old_wakeup_fd = None
@ -754,17 +754,18 @@ class PollIOLoop(IOLoop):
# Do not run anything until we have determined which ones
# are ready, so timeouts that call add_timeout cannot
# schedule anything in this iteration.
due_timeouts = []
if self._timeouts:
now = self.time()
while self._timeouts:
if self._timeouts[0].callback is None:
# the timeout was cancelled
# The timeout was cancelled. Note that the
# cancellation check is repeated below for timeouts
# that are cancelled by another timeout or callback.
heapq.heappop(self._timeouts)
self._cancellations -= 1
elif self._timeouts[0].deadline <= now:
timeout = heapq.heappop(self._timeouts)
callbacks.append(timeout.callback)
del timeout
due_timeouts.append(heapq.heappop(self._timeouts))
else:
break
if (self._cancellations > 512
@ -778,9 +779,12 @@ class PollIOLoop(IOLoop):
for callback in callbacks:
self._run_callback(callback)
for timeout in due_timeouts:
if timeout.callback is not None:
self._run_callback(timeout.callback)
# Closures may be holding on to a lot of memory, so allow
# them to be freed before we go into our poll wait.
callbacks = callback = None
callbacks = callback = due_timeouts = timeout = None
if self._callbacks:
# If any callbacks or timeouts called add_callback,
@ -969,9 +973,10 @@ class PeriodicCallback(object):
if not self._running:
return
try:
self.callback()
return self.callback()
except Exception:
self.io_loop.handle_callback_exception(self.callback)
finally:
self._schedule_next()
def _schedule_next(self):

View file

@ -39,7 +39,7 @@ from tornado import ioloop
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError
from tornado import stack_context
from tornado.util import bytes_type, errno_from_exception
from tornado.util import errno_from_exception
try:
from tornado.platform.posix import _set_nonblocking
@ -324,7 +324,7 @@ class BaseIOStream(object):
.. versionchanged:: 4.0
Now returns a `.Future` if no callback is given.
"""
assert isinstance(data, bytes_type)
assert isinstance(data, bytes)
self._check_closed()
# We use bool(_write_buffer) as a proxy for write_buffer_size>0,
# so never put empty strings in the buffer.
@ -505,7 +505,7 @@ class BaseIOStream(object):
def wrapper():
self._pending_callbacks -= 1
try:
callback(*args)
return callback(*args)
except Exception:
app_log.error("Uncaught exception, closing connection.",
exc_info=True)
@ -517,6 +517,7 @@ class BaseIOStream(object):
# Re-raise the exception so that IOLoop.handle_callback_exception
# can see it and log the error
raise
finally:
self._maybe_add_error_listener()
# We schedule callbacks to be run on the next IOLoop iteration
# rather than running them directly for several reasons:
@ -553,7 +554,7 @@ class BaseIOStream(object):
# Pretend to have a pending callback so that an EOF in
# _read_to_buffer doesn't trigger an immediate close
# callback. At the end of this method we'll either
# estabilsh a real pending callback via
# establish a real pending callback via
# _read_from_buffer or run the close callback.
#
# We need two try statements here so that
@ -992,6 +993,11 @@ class IOStream(BaseIOStream):
"""
self._connecting = True
if callback is not None:
self._connect_callback = stack_context.wrap(callback)
future = None
else:
future = self._connect_future = TracebackFuture()
try:
self.socket.connect(address)
except socket.error as e:
@ -1007,12 +1013,7 @@ class IOStream(BaseIOStream):
gen_log.warning("Connect error on fd %s: %s",
self.socket.fileno(), e)
self.close(exc_info=True)
return
if callback is not None:
self._connect_callback = stack_context.wrap(callback)
future = None
else:
future = self._connect_future = TracebackFuture()
return future
self._add_io_state(self.io_loop.WRITE)
return future
@ -1184,8 +1185,14 @@ class SSLIOStream(IOStream):
return self.close(exc_info=True)
raise
except socket.error as err:
if err.args[0] in _ERRNO_CONNRESET:
# Some port scans (e.g. nmap in -sT mode) have been known
# to cause do_handshake to raise EBADF, so make that error
# quiet as well.
# https://groups.google.com/forum/?fromgroups#!topic/python-tornado/ApucKJat1_0
if (err.args[0] in _ERRNO_CONNRESET or
err.args[0] == errno.EBADF):
return self.close(exc_info=True)
raise
except AttributeError:
# On Linux, if the connection was reset before the call to
# wrap_socket, do_handshake will fail with an

View file

@ -179,7 +179,7 @@ class LogFormatter(logging.Formatter):
def enable_pretty_logging(options=None, logger=None):
"""Turns on formatted logging output as configured.
This is called automaticaly by `tornado.options.parse_command_line`
This is called automatically by `tornado.options.parse_command_line`
and `tornado.options.parse_config_file`.
"""
if options is None:

View file

@ -35,6 +35,11 @@ except ImportError:
# ssl is not available on Google App Engine
ssl = None
try:
xrange # py2
except NameError:
xrange = range # py3
if hasattr(ssl, 'match_hostname') and hasattr(ssl, 'CertificateError'): # python 3.2+
ssl_match_hostname = ssl.match_hostname
SSLCertificateError = ssl.CertificateError
@ -60,8 +65,11 @@ _ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
if hasattr(errno, "WSAEWOULDBLOCK"):
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,)
# Default backlog used when calling sock.listen()
_DEFAULT_BACKLOG = 128
def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags=None):
def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
backlog=_DEFAULT_BACKLOG, flags=None):
"""Creates listening sockets bound to the given port and address.
Returns a list of socket objects (multiple sockets are returned if
@ -141,7 +149,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags
return sockets
if hasattr(socket, 'AF_UNIX'):
def bind_unix_socket(file, mode=0o600, backlog=128):
def bind_unix_socket(file, mode=0o600, backlog=_DEFAULT_BACKLOG):
"""Creates a listening unix socket.
If a socket with the given name already exists, it will be deleted.
@ -184,7 +192,18 @@ def add_accept_handler(sock, callback, io_loop=None):
io_loop = IOLoop.current()
def accept_handler(fd, events):
while True:
# More connections may come in while we're handling callbacks;
# to prevent starvation of other tasks we must limit the number
# of connections we accept at a time. Ideally we would accept
# up to the number of connections that were waiting when we
# entered this method, but this information is not available
# (and rearranging this method to call accept() as many times
# as possible before running any callbacks would have adverse
# effects on load balancing in multiprocess configurations).
# Instead, we use the (default) listen backlog as a rough
# heuristic for the number of connections we can reasonably
# accept at once.
for i in xrange(_DEFAULT_BACKLOG):
try:
connection, address = sock.accept()
except socket.error as e:

View file

@ -79,7 +79,7 @@ import sys
import os
import textwrap
from tornado.escape import _unicode
from tornado.escape import _unicode, native_str
from tornado.log import define_logging_options
from tornado import stack_context
from tornado.util import basestring_type, exec_in
@ -271,10 +271,14 @@ class OptionParser(object):
If ``final`` is ``False``, parse callbacks will not be run.
This is useful for applications that wish to combine configurations
from multiple sources.
.. versionchanged:: 4.1
Config files are now always interpreted as utf-8 instead of
the system default encoding.
"""
config = {}
with open(path) as f:
exec_in(f.read(), config, config)
with open(path, 'rb') as f:
exec_in(native_str(f.read()), config, config)
for name in config:
if name in self._options:
self._options[name].set(config[name])

View file

@ -10,12 +10,10 @@ unfinished callbacks on the event loop that fail when it resumes)
"""
from __future__ import absolute_import, division, print_function, with_statement
import datetime
import functools
from tornado.ioloop import IOLoop
from tornado import stack_context
from tornado.util import timedelta_to_seconds
try:
# Import the real asyncio module for py33+ first. Older versions of the

View file

@ -141,7 +141,7 @@ class TornadoDelayedCall(object):
class TornadoReactor(PosixReactorBase):
"""Twisted reactor built on the Tornado IOLoop.
Since it is intented to be used in applications where the top-level
Since it is intended to be used in applications where the top-level
event loop is ``io_loop.start()`` rather than ``reactor.run()``,
it is implemented a little differently than other Twisted reactors.
We override `mainLoop` instead of `doIteration` and must implement

View file

@ -39,7 +39,7 @@ from tornado.util import errno_from_exception
try:
import multiprocessing
except ImportError:
# Multiprocessing is not availble on Google App Engine.
# Multiprocessing is not available on Google App Engine.
multiprocessing = None
try:
@ -240,7 +240,7 @@ class Subprocess(object):
The callback takes one argument, the return code of the process.
This method uses a ``SIGCHILD`` handler, which is a global setting
This method uses a ``SIGCHLD`` handler, which is a global setting
and may conflict if you have other libraries trying to handle the
same signal. If you are using more than one ``IOLoop`` it may
be necessary to call `Subprocess.initialize` first to designate
@ -257,7 +257,7 @@ class Subprocess(object):
@classmethod
def initialize(cls, io_loop=None):
"""Initializes the ``SIGCHILD`` handler.
"""Initializes the ``SIGCHLD`` handler.
The signal handler is run on an `.IOLoop` to avoid locking issues.
Note that the `.IOLoop` used for signal handling need not be the
@ -275,7 +275,7 @@ class Subprocess(object):
@classmethod
def uninitialize(cls):
"""Removes the ``SIGCHILD`` handler."""
"""Removes the ``SIGCHLD`` handler."""
if not cls._initialized:
return
signal.signal(signal.SIGCHLD, cls._old_sigchld)

View file

@ -1,77 +0,0 @@
'''
Created on May 31, 2014
@author: Fenriswolf
'''
import logging
import inspect
import re
from collections import OrderedDict
from tornado.web import Application, RequestHandler, HTTPError
app_routing_log = logging.getLogger("tornado.application.routing")
class RoutingApplication(Application):
def __init__(self, handlers=None, default_host="", transforms=None, wsgi=False, **settings):
Application.__init__(self, handlers, default_host, transforms, wsgi, **settings)
self.handler_map = OrderedDict()
def expose(self, rule='', methods = ['GET'], kwargs=None, name=None):
"""
A decorator that is used to register a given URL rule.
"""
def decorator(func, *args, **kwargs):
func_name = func.__name__
frm = inspect.stack()[1]
class_name = frm[3]
module_name = frm[0].f_back.f_globals["__name__"]
full_class_name = module_name + '.' + class_name
for method in methods:
func_rule = rule if rule else None
if not func_rule:
if func_name == 'index':
func_rule = class_name
else:
func_rule = class_name + '/' + func_name
func_rule = r'/%s(.*)(/?)' % func_rule
if full_class_name not in self.handler_map:
self.handler_map.setdefault(full_class_name, {})[method] = [(func_rule, func_name)]
else:
self.handler_map[full_class_name][method] += [(func_rule, func_name)]
app_routing_log.info("register %s %s to %s.%s" % (method, func_rule, full_class_name, func_name))
return func
return decorator
def setRouteHandlers(self):
handlers = [(rule[0], full_class_name)
for full_class_name, methods in self.handler_map.items()
for rules in methods.values()
for rule in rules]
self.add_handlers(".*$", handlers)
class RequestRoutingHandler(RequestHandler):
def _get_func_name(self):
full_class_name = self.__module__ + '.' + self.__class__.__name__
rules = self.application.handler_map.get(full_class_name, {}).get(self.request.method, [])
for rule, func_name in rules:
if not rule or not func_name:
continue
match = re.match(rule, self.request.path)
if match:
return func_name
raise HTTPError(404, "")
def _execute_method(self):
if not self._finished:
func_name = self._get_func_name()
method = getattr(self, func_name)
self._when_complete(method(*self.path_args, **self.path_kwargs),
self._execute_finish)

View file

@ -19,11 +19,8 @@ import functools
import re
import socket
import sys
from io import BytesIO
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
try:
import urlparse # py2
@ -37,7 +34,7 @@ except ImportError:
ssl = None
try:
import certifi
import lib.certifi
except ImportError:
certifi = None
@ -222,6 +219,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
stack_context.wrap(self._on_timeout))
self.tcp_client.connect(host, port, af=af,
ssl_options=ssl_options,
max_buffer_size=self.max_buffer_size,
callback=self._on_connect)
def _get_ssl_options(self, scheme):
@ -277,7 +275,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
stream.close()
return
self.stream = stream
self.stream.set_close_callback(self._on_close)
self.stream.set_close_callback(self.on_connection_close)
self._remove_timeout()
if self.final_callback is None:
return
@ -316,18 +314,18 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
if self.request.user_agent:
self.request.headers["User-Agent"] = self.request.user_agent
if not self.request.allow_nonstandard_methods:
if self.request.method in ("POST", "PATCH", "PUT"):
if (self.request.body is None and
self.request.body_producer is None):
raise AssertionError(
'Body must not be empty for "%s" request'
% self.request.method)
else:
if (self.request.body is not None or
self.request.body_producer is not None):
raise AssertionError(
'Body must be empty for "%s" request'
% self.request.method)
# Some HTTP methods nearly always have bodies while others
# almost never do. Fail in this case unless the user has
# opted out of sanity checks with allow_nonstandard_methods.
body_expected = self.request.method in ("POST", "PATCH", "PUT")
body_present = (self.request.body is not None or
self.request.body_producer is not None)
if ((body_expected and not body_present) or
(body_present and not body_expected)):
raise ValueError(
'Body must %sbe None for method %s (unelss '
'allow_nonstandard_methods is true)' %
('not ' if body_expected else '', self.request.method))
if self.request.expect_100_continue:
self.request.headers["Expect"] = "100-continue"
if self.request.body is not None:
@ -338,7 +336,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
if (self.request.method == "POST" and
"Content-Type" not in self.request.headers):
self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
if self.request.use_gzip:
if self.request.decompress_response:
self.request.headers["Accept-Encoding"] = "gzip"
req_path = ((self.parsed.path or '/') +
(('?' + self.parsed.query) if self.parsed.query else ''))
@ -348,7 +346,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
HTTP1ConnectionParameters(
no_keep_alive=True,
max_header_size=self.max_header_size,
use_gzip=self.request.use_gzip),
decompress=self.request.decompress_response),
self._sockaddr)
start_line = httputil.RequestStartLine(self.request.method,
req_path, 'HTTP/1.1')
@ -418,12 +416,15 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
# pass it along, unless it's just the stream being closed.
return isinstance(value, StreamClosedError)
def _on_close(self):
def on_connection_close(self):
if self.final_callback is not None:
message = "Connection closed"
if self.stream.error:
raise self.stream.error
try:
raise HTTPError(599, message)
except HTTPError:
self._handle_exception(*sys.exc_info())
def headers_received(self, first_line, headers):
if self.request.expect_100_continue and first_line.code == 100:
@ -433,20 +434,6 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
self.code = first_line.code
self.reason = first_line.reason
if "Content-Length" in self.headers:
if "," in self.headers["Content-Length"]:
# Proxies sometimes cause Content-Length headers to get
# duplicated. If all the values are identical then we can
# use them but if they differ it's an error.
pieces = re.split(r',\s*', self.headers["Content-Length"])
if any(i != pieces[0] for i in pieces):
raise ValueError("Multiple unequal Content-Lengths: %r" %
self.headers["Content-Length"])
self.headers["Content-Length"] = pieces[0]
content_length = int(self.headers["Content-Length"])
else:
content_length = None
if self.request.header_callback is not None:
# Reassemble the start line.
self.request.header_callback('%s %s %s\r\n' % first_line)
@ -454,14 +441,6 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
self.request.header_callback("%s: %s\r\n" % (k, v))
self.request.header_callback('\r\n')
if 100 <= self.code < 200 or self.code == 204:
# These response codes never have bodies
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
if ("Transfer-Encoding" in self.headers or
content_length not in (None, 0)):
raise ValueError("Response with code %d should not have body" %
self.code)
def finish(self):
data = b''.join(self.chunks)
self._remove_timeout()

View file

@ -41,13 +41,13 @@ Example usage::
sys.exit(1)
with StackContext(die_on_error):
# Any exception thrown here *or in callback and its desendents*
# Any exception thrown here *or in callback and its descendants*
# will cause the process to exit instead of spinning endlessly
# in the ioloop.
http_client.fetch(url, callback)
ioloop.start()
Most applications shouln't have to work with `StackContext` directly.
Most applications shouldn't have to work with `StackContext` directly.
Here are a few rules of thumb for when it's necessary:
* If you're writing an asynchronous library that doesn't rely on a

View file

@ -163,7 +163,7 @@ class TCPClient(object):
functools.partial(self._create_stream, max_buffer_size))
af, addr, stream = yield connector.start()
# TODO: For better performance we could cache the (af, addr)
# information here and re-use it on sbusequent connections to
# information here and re-use it on subsequent connections to
# the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
if ssl_options is not None:
stream = yield stream.start_tls(False, ssl_options=ssl_options,

View file

@ -199,7 +199,7 @@ import threading
from tornado import escape
from tornado.log import app_log
from tornado.util import bytes_type, ObjectDict, exec_in, unicode_type
from tornado.util import ObjectDict, exec_in, unicode_type
try:
from cStringIO import StringIO # py2
@ -261,7 +261,7 @@ class Template(object):
"linkify": escape.linkify,
"datetime": datetime,
"_tt_utf8": escape.utf8, # for internal use
"_tt_string_types": (unicode_type, bytes_type),
"_tt_string_types": (unicode_type, bytes),
# __name__ and __loader__ allow the traceback mechanism to find
# the generated source code.
"__name__": self.name.replace('.', '_'),

View file

@ -1,4 +1,4 @@
Test coverage is almost non-existent, but it's a start. Be sure to
set PYTHONPATH apprioriately (generally to the root directory of your
set PYTHONPATH appropriately (generally to the root directory of your
tornado checkout) when running tests to make sure you're getting the
version of the tornado package that you expect.

View file

@ -8,7 +8,8 @@ from tornado.stack_context import ExceptionStackContext
from tornado.testing import AsyncHTTPTestCase
from tornado.test import httpclient_test
from tornado.test.util import unittest
from tornado.web import Application, RequestHandler
from tornado.web import Application, RequestHandler, URLSpec
try:
import pycurl

View file

@ -4,8 +4,8 @@
from __future__ import absolute_import, division, print_function, with_statement
import tornado.escape
from tornado.escape import utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape, to_unicode, json_decode, json_encode
from tornado.util import u, unicode_type, bytes_type
from tornado.escape import utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape, to_unicode, json_decode, json_encode, squeeze, recursive_unicode
from tornado.util import u, unicode_type
from tornado.test.util import unittest
linkify_tests = [
@ -212,6 +212,22 @@ class EscapeTestCase(unittest.TestCase):
# convert automatically if they are utf8; on python 3 byte strings
# are not allowed.
self.assertEqual(json_decode(json_encode(u("\u00e9"))), u("\u00e9"))
if bytes_type is str:
if bytes is str:
self.assertEqual(json_decode(json_encode(utf8(u("\u00e9")))), u("\u00e9"))
self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9")
def test_squeeze(self):
self.assertEqual(squeeze(u('sequences of whitespace chars'))
, u('sequences of whitespace chars'))
def test_recursive_unicode(self):
tests = {
'dict': {b"foo": b"bar"},
'list': [b"foo", b"bar"],
'tuple': (b"foo", b"bar"),
'bytes': b"foo"
}
self.assertEqual(recursive_unicode(tests['dict']), {u("foo"): u("bar")})
self.assertEqual(recursive_unicode(tests['list']), [u("foo"), u("bar")])
self.assertEqual(recursive_unicode(tests['tuple']), (u("foo"), u("bar")))
self.assertEqual(recursive_unicode(tests['bytes']), u("foo"))

View file

@ -8,6 +8,8 @@ from contextlib import closing
import functools
import sys
import threading
import datetime
from io import BytesIO
from tornado.escape import utf8
from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient
@ -19,13 +21,9 @@ from tornado import netutil
from tornado.stack_context import ExceptionStackContext, NullContext
from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog
from tornado.test.util import unittest, skipOnTravis
from tornado.util import u, bytes_type
from tornado.util import u
from tornado.web import Application, RequestHandler, url
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO
from tornado.httputil import format_timestamp, HTTPHeaders
class HelloWorldHandler(RequestHandler):
@ -41,6 +39,18 @@ class PostHandler(RequestHandler):
self.get_argument("arg1"), self.get_argument("arg2")))
class PutHandler(RequestHandler):
def put(self):
self.write("Put body: ")
self.write(self.request.body)
class RedirectHandler(RequestHandler):
def prepare(self):
self.redirect(self.get_argument("url"),
status=int(self.get_argument("status", "302")))
class ChunkHandler(RequestHandler):
def get(self):
self.write("asdf")
@ -83,6 +93,13 @@ class ContentLength304Handler(RequestHandler):
pass
class PatchHandler(RequestHandler):
def patch(self):
"Return the request payload - so we can check it is being kept"
self.write(self.request.body)
class AllMethodsHandler(RequestHandler):
SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ('OTHER',)
@ -101,6 +118,8 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
return Application([
url("/hello", HelloWorldHandler),
url("/post", PostHandler),
url("/put", PutHandler),
url("/redirect", RedirectHandler),
url("/chunk", ChunkHandler),
url("/auth", AuthHandler),
url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
@ -108,8 +127,15 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
url("/user_agent", UserAgentHandler),
url("/304_with_content_length", ContentLength304Handler),
url("/all_methods", AllMethodsHandler),
url('/patch', PatchHandler),
], gzip=True)
def test_patch_receives_payload(self):
body = b"some patch data"
response = self.fetch("/patch", method='PATCH', body=body)
self.assertEqual(response.code, 200)
self.assertEqual(response.body, body)
@skipOnTravis
def test_hello_world(self):
response = self.fetch("/hello")
@ -263,7 +289,7 @@ Transfer-Encoding: chunked
def test_types(self):
response = self.fetch("/hello")
self.assertEqual(type(response.body), bytes_type)
self.assertEqual(type(response.body), bytes)
self.assertEqual(type(response.headers["Content-Type"]), str)
self.assertEqual(type(response.code), int)
self.assertEqual(type(response.effective_url), str)
@ -314,11 +340,28 @@ Transfer-Encoding: chunked
# Construct a new instance of the configured client class
client = self.http_client.__class__(self.io_loop, force_instance=True,
defaults=defaults)
try:
client.fetch(self.get_url('/user_agent'), callback=self.stop)
response = self.wait()
self.assertEqual(response.body, b'TestDefaultUserAgent')
finally:
client.close()
def test_header_types(self):
# Header values may be passed as character or utf8 byte strings,
# in a plain dictionary or an HTTPHeaders object.
# Keys must always be the native str type.
# All combinations should have the same results on the wire.
for value in [u("MyUserAgent"), b"MyUserAgent"]:
for container in [dict, HTTPHeaders]:
headers = container()
headers['User-Agent'] = value
resp = self.fetch('/user_agent', headers=headers)
self.assertEqual(
resp.body, b"MyUserAgent",
"response=%r, value=%r, container=%r" %
(resp.body, value, container))
def test_304_with_content_length(self):
# According to the spec 304 responses SHOULD NOT include
# Content-Length or other entity headers, but some servers do it
@ -388,17 +431,39 @@ Transfer-Encoding: chunked
self.assertEqual(response.body, b'OTHER')
@gen_test
def test_body(self):
def test_body_sanity_checks(self):
hello_url = self.get_url('/hello')
with self.assertRaises(AssertionError) as context:
with self.assertRaises(ValueError) as context:
yield self.http_client.fetch(hello_url, body='data')
self.assertTrue('must be empty' in str(context.exception))
self.assertTrue('must be None' in str(context.exception))
with self.assertRaises(AssertionError) as context:
with self.assertRaises(ValueError) as context:
yield self.http_client.fetch(hello_url, method='POST')
self.assertTrue('must not be empty' in str(context.exception))
self.assertTrue('must not be None' in str(context.exception))
# This test causes odd failures with the combination of
# curl_httpclient (at least with the version of libcurl available
# on ubuntu 12.04), TwistedIOLoop, and epoll. For POST (but not PUT),
# curl decides the response came back too soon and closes the connection
# to start again. It does this *before* telling the socket callback to
# unregister the FD. Some IOLoop implementations have special kernel
# integration to discover this immediately. Tornado's IOLoops
# ignore errors on remove_handler to accommodate this behavior, but
# Twisted's reactor does not. The removeReader call fails and so
# do all future removeAll calls (which our tests do at cleanup).
#
#def test_post_307(self):
# response = self.fetch("/redirect?status=307&url=/post",
# method="POST", body=b"arg1=foo&arg2=bar")
# self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
def test_put_307(self):
response = self.fetch("/redirect?status=307&url=/put",
method="PUT", body=b"hello")
response.rethrow()
self.assertEqual(response.body, b"Put body: hello")
class RequestProxyTest(unittest.TestCase):
@ -515,3 +580,9 @@ class HTTPRequestTestCase(unittest.TestCase):
request = HTTPRequest('http://example.com')
request.body = 'foo'
self.assertEqual(request.body, utf8('foo'))
def test_if_modified_since(self):
http_date = datetime.datetime.utcnow()
request = HTTPRequest('http://example.com', if_modified_since=http_date)
self.assertEqual(request.headers,
{'If-Modified-Since': format_timestamp(http_date)})

View file

@ -9,12 +9,12 @@ from tornado.http1connection import HTTP1Connection
from tornado.httpserver import HTTPServer
from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine
from tornado.iostream import IOStream
from tornado.log import gen_log, app_log
from tornado.log import gen_log
from tornado.netutil import ssl_options_to_context
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test
from tornado.test.util import unittest, skipOnTravis
from tornado.util import u, bytes_type
from tornado.util import u
from tornado.web import Application, RequestHandler, asynchronous, stream_request_body
from contextlib import closing
import datetime
@ -25,11 +25,7 @@ import socket
import ssl
import sys
import tempfile
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
from io import BytesIO
def read_stream_body(stream, callback):
@ -297,10 +293,10 @@ class TypeCheckHandler(RequestHandler):
# secure cookies
self.check_type('arg_key', list(self.request.arguments.keys())[0], str)
self.check_type('arg_value', list(self.request.arguments.values())[0][0], bytes_type)
self.check_type('arg_value', list(self.request.arguments.values())[0][0], bytes)
def post(self):
self.check_type('body', self.request.body, bytes_type)
self.check_type('body', self.request.body, bytes)
self.write(self.errors)
def get(self):
@ -358,7 +354,7 @@ class HTTPServerTest(AsyncHTTPTestCase):
# if the data is not utf8. On python 2 parse_qs will work,
# but then the recursive_unicode call in EchoHandler will
# fail.
if str is bytes_type:
if str is bytes:
return
with ExpectLog(gen_log, 'Invalid x-www-form-urlencoded body'):
response = self.fetch(
@ -586,6 +582,8 @@ class KeepAliveTest(AsyncHTTPTestCase):
class HelloHandler(RequestHandler):
def get(self):
self.finish('Hello world')
def post(self):
self.finish('Hello world')
class LargeHandler(RequestHandler):
def get(self):
@ -687,6 +685,17 @@ class KeepAliveTest(AsyncHTTPTestCase):
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.close()
def test_http10_keepalive_extra_crlf(self):
self.http_version = b'HTTP/1.0'
self.connect()
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n')
self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.close()
def test_pipelined_requests(self):
self.connect()
self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
@ -715,6 +724,19 @@ class KeepAliveTest(AsyncHTTPTestCase):
self.read_headers()
self.close()
def test_keepalive_chunked(self):
self.http_version = b'HTTP/1.0'
self.connect()
self.stream.write(b'POST / HTTP/1.0\r\nConnection: keep-alive\r\n'
b'Transfer-Encoding: chunked\r\n'
b'\r\n0\r\n')
self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
self.read_response()
self.assertEqual(self.headers['Connection'], 'Keep-Alive')
self.close()
class GzipBaseTest(object):
def get_app(self):
@ -736,7 +758,7 @@ class GzipBaseTest(object):
class GzipTest(GzipBaseTest, AsyncHTTPTestCase):
def get_httpserver_options(self):
return dict(gzip=True)
return dict(decompress_request=True)
def test_gzip(self):
response = self.post_gzip('foo=bar')
@ -764,7 +786,7 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase):
return SimpleAsyncHTTPClient(io_loop=self.io_loop)
def get_httpserver_options(self):
return dict(chunk_size=self.CHUNK_SIZE, gzip=True)
return dict(chunk_size=self.CHUNK_SIZE, decompress_request=True)
class MessageDelegate(HTTPMessageDelegate):
def __init__(self, connection):

View file

@ -2,7 +2,7 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado.httputil import url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp
from tornado.httputil import url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp, HTTPServerRequest, parse_request_start_line
from tornado.escape import utf8
from tornado.log import gen_log
from tornado.testing import ExpectLog
@ -253,3 +253,26 @@ class FormatTimestampTest(unittest.TestCase):
def test_datetime(self):
self.check(datetime.datetime.utcfromtimestamp(self.TIMESTAMP))
# HTTPServerRequest is mainly tested incidentally to the server itself,
# but this tests the parts of the class that can be tested in isolation.
class HTTPServerRequestTest(unittest.TestCase):
def test_default_constructor(self):
# All parameters are formally optional, but uri is required
# (and has been for some time). This test ensures that no
# more required parameters slip in.
HTTPServerRequest(uri='/')
class ParseRequestStartLineTest(unittest.TestCase):
METHOD = "GET"
PATH = "/foo"
VERSION = "HTTP/1.1"
def test_parse_request_start_line(self):
start_line = " ".join([self.METHOD, self.PATH, self.VERSION])
parsed_start_line = parse_request_start_line(start_line)
self.assertEqual(parsed_start_line.method, self.METHOD)
self.assertEqual(parsed_start_line.path, self.PATH)
self.assertEqual(parsed_start_line.version, self.VERSION)

View file

@ -173,6 +173,25 @@ class TestIOLoop(AsyncTestCase):
self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop))
self.wait()
def test_remove_timeout_from_timeout(self):
calls = [False, False]
# Schedule several callbacks and wait for them all to come due at once.
# t2 should be cancelled by t1, even though it is already scheduled to
# be run before the ioloop even looks at it.
now = self.io_loop.time()
def t1():
calls[0] = True
self.io_loop.remove_timeout(t2_handle)
self.io_loop.add_timeout(now + 0.01, t1)
def t2():
calls[1] = True
t2_handle = self.io_loop.add_timeout(now + 0.02, t2)
self.io_loop.add_timeout(now + 0.03, self.stop)
time.sleep(0.03)
self.wait()
self.assertEqual(calls, [True, False])
def test_timeout_with_arguments(self):
# This tests that all the timeout methods pass through *args correctly.
results = []
@ -185,6 +204,23 @@ class TestIOLoop(AsyncTestCase):
self.wait()
self.assertEqual(results, [1, 2, 3, 4])
def test_add_timeout_return(self):
# All the timeout methods return non-None handles that can be
# passed to remove_timeout.
handle = self.io_loop.add_timeout(self.io_loop.time(), lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_call_at_return(self):
handle = self.io_loop.call_at(self.io_loop.time(), lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_call_later_return(self):
handle = self.io_loop.call_later(0, lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_close_file_object(self):
"""When a file object is used instead of a numeric file descriptor,
the object should be closed (by IOLoop.close(all_fds=True),
@ -298,6 +334,33 @@ class TestIOLoop(AsyncTestCase):
with ExpectLog(app_log, "Exception in callback"):
self.wait()
@skipIfNonUnix
def test_remove_handler_from_handler(self):
# Create two sockets with simultaneous read events.
client, server = socket.socketpair()
try:
client.send(b'abc')
server.send(b'abc')
# After reading from one fd, remove the other from the IOLoop.
chunks = []
def handle_read(fd, events):
chunks.append(fd.recv(1024))
if fd is client:
self.io_loop.remove_handler(server)
else:
self.io_loop.remove_handler(client)
self.io_loop.add_handler(client, handle_read, self.io_loop.READ)
self.io_loop.add_handler(server, handle_read, self.io_loop.READ)
self.io_loop.call_later(0.01, self.stop)
self.wait()
# Only one fd was read; the other was cleanly removed.
self.assertEqual(chunks, [b'abc'])
finally:
client.close()
server.close()
# Deliberately not a subclass of AsyncTestCase so the IOLoop isn't
# automatically set as current.

View file

@ -10,7 +10,7 @@ from tornado.stack_context import NullContext
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
from tornado.test.util import unittest, skipIfNonUnix
from tornado.web import RequestHandler, Application
import certifi
import lib.certifi
import errno
import logging
import os
@ -511,7 +511,7 @@ class TestIOStreamMixin(object):
server, client = self.make_iostream_pair()
server.set_close_callback(self.stop)
try:
# Start a read that will be fullfilled asynchronously.
# Start a read that will be fulfilled asynchronously.
server.read_bytes(1, lambda data: None)
client.write(b'a')
# Stub out read_from_fd to make it fail.

View file

@ -57,3 +57,43 @@ class EnglishTest(unittest.TestCase):
date = datetime.datetime(2013, 4, 28, 18, 35)
self.assertEqual(locale.format_date(date, full_format=True),
'April 28, 2013 at 6:35 pm')
self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(seconds=2), full_format=False),
'2 seconds ago')
self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(minutes=2), full_format=False),
'2 minutes ago')
self.assertEqual(locale.format_date(datetime.datetime.utcnow() - datetime.timedelta(hours=2), full_format=False),
'2 hours ago')
now = datetime.datetime.utcnow()
self.assertEqual(locale.format_date(now - datetime.timedelta(days=1), full_format=False, shorter=True),
'yesterday')
date = now - datetime.timedelta(days=2)
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
locale._weekdays[date.weekday()])
date = now - datetime.timedelta(days=300)
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
'%s %d' % (locale._months[date.month - 1], date.day))
date = now - datetime.timedelta(days=500)
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
'%s %d, %d' % (locale._months[date.month - 1], date.day, date.year))
def test_friendly_number(self):
locale = tornado.locale.get('en_US')
self.assertEqual(locale.friendly_number(1000000), '1,000,000')
def test_list(self):
locale = tornado.locale.get('en_US')
self.assertEqual(locale.list([]), '')
self.assertEqual(locale.list(['A']), 'A')
self.assertEqual(locale.list(['A', 'B']), 'A and B')
self.assertEqual(locale.list(['A', 'B', 'C']), 'A, B and C')
def test_format_day(self):
locale = tornado.locale.get('en_US')
date = datetime.datetime(2013, 4, 28, 18, 35)
self.assertEqual(locale.format_day(date=date, dow=True), 'Sunday, April 28')
self.assertEqual(locale.format_day(date=date, dow=False), 'April 28')

View file

@ -29,7 +29,7 @@ from tornado.escape import utf8
from tornado.log import LogFormatter, define_logging_options, enable_pretty_logging
from tornado.options import OptionParser
from tornado.test.util import unittest
from tornado.util import u, bytes_type, basestring_type
from tornado.util import u, basestring_type
@contextlib.contextmanager
@ -95,8 +95,9 @@ class LogFormatterTest(unittest.TestCase):
self.assertEqual(self.get_output(), utf8(repr(b"\xe9")))
def test_utf8_logging(self):
with ignore_bytes_warning():
self.logger.error(u("\u00e9").encode("utf8"))
if issubclass(bytes_type, basestring_type):
if issubclass(bytes, basestring_type):
# on python 2, utf8 byte strings (and by extension ascii byte
# strings) are passed through as-is.
self.assertEqual(self.get_output(), utf8(u("\u00e9")))

View file

@ -34,15 +34,6 @@ else:
class _ResolverTestMixin(object):
def skipOnCares(self):
# Some DNS-hijacking ISPs (e.g. Time Warner) return non-empty results
# with an NXDOMAIN status code. Most resolvers treat this as an error;
# C-ares returns the results, making the "bad_host" tests unreliable.
# C-ares will try to resolve even malformed names, such as the
# name with spaces used in this test.
if self.resolver.__class__.__name__ == 'CaresResolver':
self.skipTest("CaresResolver doesn't recognize fake NXDOMAIN")
def test_localhost(self):
self.resolver.resolve('localhost', 80, callback=self.stop)
result = self.wait()
@ -55,8 +46,11 @@ class _ResolverTestMixin(object):
self.assertIn((socket.AF_INET, ('127.0.0.1', 80)),
addrinfo)
# It is impossible to quickly and consistently generate an error in name
# resolution, so test this case separately, using mocks as needed.
class _ResolverErrorTestMixin(object):
def test_bad_host(self):
self.skipOnCares()
def handler(exc_typ, exc_val, exc_tb):
self.stop(exc_val)
return True # Halt propagation.
@ -69,11 +63,13 @@ class _ResolverTestMixin(object):
@gen_test
def test_future_interface_bad_host(self):
self.skipOnCares()
with self.assertRaises(Exception):
yield self.resolver.resolve('an invalid domain', 80,
socket.AF_UNSPEC)
def _failing_getaddrinfo(*args):
"""Dummy implementation of getaddrinfo for use in mocks"""
raise socket.gaierror("mock: lookup failed")
@skipIfNoNetwork
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
@ -82,6 +78,21 @@ class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
self.resolver = BlockingResolver(io_loop=self.io_loop)
# getaddrinfo-based tests need mocking to reliably generate errors;
# some configurations are slow to produce errors and take longer than
# our default timeout.
class BlockingResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
def setUp(self):
super(BlockingResolverErrorTest, self).setUp()
self.resolver = BlockingResolver(io_loop=self.io_loop)
self.real_getaddrinfo = socket.getaddrinfo
socket.getaddrinfo = _failing_getaddrinfo
def tearDown(self):
socket.getaddrinfo = self.real_getaddrinfo
super(BlockingResolverErrorTest, self).tearDown()
@skipIfNoNetwork
@unittest.skipIf(futures is None, "futures module not present")
class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
@ -94,6 +105,18 @@ class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
super(ThreadedResolverTest, self).tearDown()
class ThreadedResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
def setUp(self):
super(ThreadedResolverErrorTest, self).setUp()
self.resolver = BlockingResolver(io_loop=self.io_loop)
self.real_getaddrinfo = socket.getaddrinfo
socket.getaddrinfo = _failing_getaddrinfo
def tearDown(self):
socket.getaddrinfo = self.real_getaddrinfo
super(ThreadedResolverErrorTest, self).tearDown()
@skipIfNoNetwork
@unittest.skipIf(futures is None, "futures module not present")
@unittest.skipIf(sys.platform == 'win32', "preexec_fn not available on win32")
@ -121,6 +144,12 @@ class ThreadedResolverImportTest(unittest.TestCase):
self.fail("import timed out")
# We do not test errors with CaresResolver:
# Some DNS-hijacking ISPs (e.g. Time Warner) return non-empty results
# with an NXDOMAIN status code. Most resolvers treat this as an error;
# C-ares returns the results, making the "bad_host" tests unreliable.
# C-ares will try to resolve even malformed names, such as the
# name with spaces used in this test.
@skipIfNoNetwork
@unittest.skipIf(pycares is None, "pycares module not present")
class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
@ -129,10 +158,13 @@ class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
self.resolver = CaresResolver(io_loop=self.io_loop)
# TwistedResolver produces consistent errors in our test cases so we
# can test the regular and error cases in the same class.
@skipIfNoNetwork
@unittest.skipIf(twisted is None, "twisted module not present")
@unittest.skipIf(getattr(twisted, '__version__', '0.0') < "12.1", "old version of twisted")
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin):
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin,
_ResolverErrorTestMixin):
def setUp(self):
super(TwistedResolverTest, self).setUp()
self.resolver = TwistedResolver(io_loop=self.io_loop)

View file

@ -1,2 +1,3 @@
port=443
port=443
username='李康'

View file

@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, with_statement
import datetime
@ -32,9 +33,11 @@ class OptionsTest(unittest.TestCase):
def test_parse_config_file(self):
options = OptionParser()
options.define("port", default=80)
options.define("username", default='foo')
options.parse_config_file(os.path.join(os.path.dirname(__file__),
"options_test.cfg"))
self.assertEquals(options.port, 443)
self.assertEqual(options.username, "李康")
def test_parse_callbacks(self):
options = OptionParser()

View file

@ -14,7 +14,7 @@ from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httputil import HTTPHeaders
from tornado.ioloop import IOLoop
from tornado.log import gen_log, app_log
from tornado.log import gen_log
from tornado.netutil import Resolver, bind_sockets
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _default_ca_certs
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler
@ -294,10 +294,13 @@ class SimpleHTTPClientTestMixin(object):
self.assertEqual(response.code, 204)
# 204 status doesn't need a content-length, but tornado will
# add a zero content-length anyway.
#
# A test without a content-length header is included below
# in HTTP204NoContentTestCase.
self.assertEqual(response.headers["Content-length"], "0")
# 204 status with non-zero content length is malformed
with ExpectLog(app_log, "Uncaught exception"):
with ExpectLog(gen_log, "Malformed HTTP message"):
response = self.fetch("/no_content?error=1")
self.assertEqual(response.code, 599)
@ -476,6 +479,27 @@ class HTTP100ContinueTestCase(AsyncHTTPTestCase):
self.assertEqual(res.body, b'A')
class HTTP204NoContentTestCase(AsyncHTTPTestCase):
def respond_204(self, request):
# A 204 response never has a body, even if doesn't have a content-length
# (which would otherwise mean read-until-close). Tornado always
# sends a content-length, so we simulate here a server that sends
# no content length and does not close the connection.
#
# Tests of a 204 response with a Content-Length header are included
# in SimpleHTTPClientTestMixin.
request.connection.stream.write(
b"HTTP/1.1 204 No content\r\n\r\n")
def get_app(self):
return self.respond_204
def test_204_no_content(self):
resp = self.fetch('/')
self.assertEqual(resp.code, 204)
self.assertEqual(resp.body, b'')
class HostnameMappingTestCase(AsyncHTTPTestCase):
def setUp(self):
super(HostnameMappingTestCase, self).setUp()

View file

@ -72,7 +72,9 @@ class TCPClientTest(AsyncTestCase):
super(TCPClientTest, self).tearDown()
def skipIfLocalhostV4(self):
Resolver().resolve('localhost', 0, callback=self.stop)
# The port used here doesn't matter, but some systems require it
# to be non-zero if we do not also pass AI_PASSIVE.
Resolver().resolve('localhost', 80, callback=self.stop)
addrinfo = self.wait()
families = set(addr[0] for addr in addrinfo)
if socket.AF_INET6 not in families:

View file

@ -7,7 +7,7 @@ import traceback
from tornado.escape import utf8, native_str, to_unicode
from tornado.template import Template, DictLoader, ParseError, Loader
from tornado.test.util import unittest
from tornado.util import u, bytes_type, ObjectDict, unicode_type
from tornado.util import u, ObjectDict, unicode_type
class TemplateTest(unittest.TestCase):
@ -374,7 +374,7 @@ raw: {% raw name %}""",
"{% autoescape py_escape %}s = {{ name }}\n"})
def py_escape(s):
self.assertEqual(type(s), bytes_type)
self.assertEqual(type(s), bytes)
return repr(native_str(s))
def render(template, name):

View file

@ -3,7 +3,8 @@
from __future__ import absolute_import, division, print_function, with_statement
from tornado import gen, ioloop
from tornado.testing import AsyncTestCase, gen_test
from tornado.log import app_log
from tornado.testing import AsyncTestCase, gen_test, ExpectLog
from tornado.test.util import unittest
import contextlib
@ -13,7 +14,7 @@ import traceback
@contextlib.contextmanager
def set_environ(name, value):
old_value = os.environ.get('name')
old_value = os.environ.get(name)
os.environ[name] = value
try:
@ -62,6 +63,17 @@ class AsyncTestCaseTest(AsyncTestCase):
self.io_loop.add_timeout(self.io_loop.time() + 0.03, self.stop)
self.wait(timeout=0.15)
def test_multiple_errors(self):
def fail(message):
raise Exception(message)
self.io_loop.add_callback(lambda: fail("error one"))
self.io_loop.add_callback(lambda: fail("error two"))
# The first error gets raised; the second gets logged.
with ExpectLog(app_log, "multiple unhandled exceptions"):
with self.assertRaises(Exception) as cm:
self.wait()
self.assertEqual(str(cm.exception), "error one")
class AsyncTestCaseWrapperTest(unittest.TestCase):
def test_undecorated_generator(self):

View file

@ -1,9 +1,10 @@
# coding: utf-8
from __future__ import absolute_import, division, print_function, with_statement
import sys
import datetime
from tornado.escape import utf8
from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer
from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer, timedelta_to_seconds
from tornado.test.util import unittest
try:
@ -170,3 +171,9 @@ class ArgReplacerTest(unittest.TestCase):
self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
self.assertEqual(self.replacer.replace('new', args, kwargs),
('old', (1,), dict(y=2, callback='new', z=3)))
class TimedeltaToSecondsTest(unittest.TestCase):
def test_timedelta_to_seconds(self):
time_delta = datetime.timedelta(hours=1)
self.assertEqual(timedelta_to_seconds(time_delta), 3600.0)

View file

@ -4,13 +4,14 @@ from tornado import gen
from tornado.escape import json_decode, utf8, to_unicode, recursive_unicode, native_str, to_basestring
from tornado.httputil import format_timestamp
from tornado.iostream import IOStream
from tornado import locale
from tornado.log import app_log, gen_log
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.template import DictLoader
from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test
from tornado.test.util import unittest
from tornado.util import u, bytes_type, ObjectDict, unicode_type
from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish
from tornado.util import u, ObjectDict, unicode_type, timedelta_to_seconds
from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler
import binascii
import contextlib
@ -21,7 +22,6 @@ import logging
import os
import re
import socket
import sys
try:
import urllib.parse as urllib_parse # py3
@ -163,11 +163,21 @@ class CookieTest(WebTestCase):
# Attributes from the first call are not carried over.
self.set_cookie("a", "e")
class SetCookieMaxAgeHandler(RequestHandler):
def get(self):
self.set_cookie("foo", "bar", max_age=10)
class SetCookieExpiresDaysHandler(RequestHandler):
def get(self):
self.set_cookie("foo", "bar", expires_days=10)
return [("/set", SetCookieHandler),
("/get", GetCookieHandler),
("/set_domain", SetCookieDomainHandler),
("/special_char", SetCookieSpecialCharHandler),
("/set_overwrite", SetCookieOverwriteHandler),
("/set_max_age", SetCookieMaxAgeHandler),
("/set_expires_days", SetCookieExpiresDaysHandler),
]
def test_set_cookie(self):
@ -222,6 +232,23 @@ class CookieTest(WebTestCase):
self.assertEqual(sorted(headers),
["a=e; Path=/", "c=d; Domain=example.com; Path=/"])
def test_set_cookie_max_age(self):
response = self.fetch("/set_max_age")
headers = response.headers.get_list("Set-Cookie")
self.assertEqual(sorted(headers),
["foo=bar; Max-Age=10; Path=/"])
def test_set_cookie_expires_days(self):
response = self.fetch("/set_expires_days")
header = response.headers.get("Set-Cookie")
match = re.match("foo=bar; expires=(?P<expires>.+); Path=/", header)
self.assertIsNotNone(match)
expires = datetime.datetime.utcnow() + datetime.timedelta(days=10)
header_expires = datetime.datetime(
*email.utils.parsedate(match.groupdict()["expires"])[:6])
self.assertTrue(abs(timedelta_to_seconds(expires - header_expires)) < 10)
class AuthRedirectRequestHandler(RequestHandler):
def initialize(self, login_url):
@ -302,7 +329,7 @@ class EchoHandler(RequestHandler):
if type(key) != str:
raise Exception("incorrect type for key: %r" % type(key))
for value in self.request.arguments[key]:
if type(value) != bytes_type:
if type(value) != bytes:
raise Exception("incorrect type for value: %r" %
type(value))
for value in self.get_arguments(key):
@ -370,10 +397,10 @@ class TypeCheckHandler(RequestHandler):
if list(self.cookies.keys()) != ['asdf']:
raise Exception("unexpected values for cookie keys: %r" %
self.cookies.keys())
self.check_type('get_secure_cookie', self.get_secure_cookie('asdf'), bytes_type)
self.check_type('get_secure_cookie', self.get_secure_cookie('asdf'), bytes)
self.check_type('get_cookie', self.get_cookie('asdf'), str)
self.check_type('xsrf_token', self.xsrf_token, bytes_type)
self.check_type('xsrf_token', self.xsrf_token, bytes)
self.check_type('xsrf_form_html', self.xsrf_form_html(), str)
self.check_type('reverse_url', self.reverse_url('typecheck', 'foo'), str)
@ -399,7 +426,7 @@ class TypeCheckHandler(RequestHandler):
class DecodeArgHandler(RequestHandler):
def decode_argument(self, value, name=None):
if type(value) != bytes_type:
if type(value) != bytes:
raise Exception("unexpected type for value: %r" % type(value))
# use self.request.arguments directly to avoid recursion
if 'encoding' in self.request.arguments:
@ -409,7 +436,7 @@ class DecodeArgHandler(RequestHandler):
def get(self, arg):
def describe(s):
if type(s) == bytes_type:
if type(s) == bytes:
return ["bytes", native_str(binascii.b2a_hex(s))]
elif type(s) == unicode_type:
return ["unicode", s]
@ -550,6 +577,8 @@ class WSGISafeWebTest(WebTestCase):
url("/optional_path/(.+)?", OptionalPathHandler),
url("/multi_header", MultiHeaderHandler),
url("/redirect", RedirectHandler),
url("/web_redirect_permanent", WebRedirectHandler, {"url": "/web_redirect_newpath"}),
url("/web_redirect", WebRedirectHandler, {"url": "/web_redirect_newpath", "permanent": False}),
url("/header_injection", HeaderInjectionHandler),
url("/get_argument", GetArgumentHandler),
url("/get_arguments", GetArgumentsHandler),
@ -675,6 +704,14 @@ js_embed()
response = self.fetch("/redirect?status=307", follow_redirects=False)
self.assertEqual(response.code, 307)
def test_web_redirect(self):
response = self.fetch("/web_redirect_permanent", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], '/web_redirect_newpath')
response = self.fetch("/web_redirect", follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertEqual(response.headers['Location'], '/web_redirect_newpath')
def test_header_injection(self):
response = self.fetch("/header_injection")
self.assertEqual(response.body, b"ok")
@ -1348,7 +1385,9 @@ class GzipTestCase(SimpleHandlerTestCase):
self.write('hello world')
def get_app_kwargs(self):
return dict(gzip=True)
return dict(
gzip=True,
static_path=os.path.join(os.path.dirname(__file__), 'static'))
def test_gzip(self):
response = self.fetch('/')
@ -1361,6 +1400,17 @@ class GzipTestCase(SimpleHandlerTestCase):
'gzip')
self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
def test_gzip_static(self):
# The streaming responses in StaticFileHandler have subtle
# interactions with the gzip output so test this case separately.
response = self.fetch('/robots.txt')
self.assertEqual(
response.headers.get(
'Content-Encoding',
response.headers.get('X-Consumed-Content-Encoding')),
'gzip')
self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
def test_gzip_not_requested(self):
response = self.fetch('/', use_gzip=False)
self.assertNotIn('Content-Encoding', response.headers)
@ -1554,19 +1604,26 @@ class MultipleExceptionTest(SimpleHandlerTestCase):
@wsgi_safe
class SetCurrentUserTest(SimpleHandlerTestCase):
class SetLazyPropertiesTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def prepare(self):
self.current_user = 'Ben'
self.locale = locale.get('en_US')
def get_user_locale(self):
raise NotImplementedError()
def get_current_user(self):
raise NotImplementedError()
def get(self):
self.write('Hello %s' % self.current_user)
self.write('Hello %s (%s)' % (self.current_user, self.locale.code))
def test_set_current_user(self):
def test_set_properties(self):
# Ensure that current_user can be assigned to normally for apps
# that want to forgo the lazy get_current_user property
response = self.fetch('/')
self.assertEqual(response.body, b'Hello Ben')
self.assertEqual(response.body, b'Hello Ben (en_US)')
@wsgi_safe
@ -2193,6 +2250,20 @@ class XSRFTest(SimpleHandlerTestCase):
headers=self.cookie_headers())
self.assertEqual(response.code, 403)
def test_xsrf_success_short_token(self):
response = self.fetch(
"/", method="POST",
body=urllib_parse.urlencode(dict(_xsrf='deadbeef')),
headers=self.cookie_headers(token='deadbeef'))
self.assertEqual(response.code, 200)
def test_xsrf_success_non_hex_token(self):
response = self.fetch(
"/", method="POST",
body=urllib_parse.urlencode(dict(_xsrf='xoxo')),
headers=self.cookie_headers(token='xoxo'))
self.assertEqual(response.code, 200)
def test_xsrf_success_post_body(self):
response = self.fetch(
"/", method="POST",
@ -2299,3 +2370,38 @@ class FinishExceptionTest(SimpleHandlerTestCase):
self.assertEqual('Basic realm="something"',
response.headers.get('WWW-Authenticate'))
self.assertEqual(b'authentication required', response.body)
class DecoratorTest(WebTestCase):
def get_handlers(self):
class RemoveSlashHandler(RequestHandler):
@removeslash
def get(self):
pass
class AddSlashHandler(RequestHandler):
@addslash
def get(self):
pass
return [("/removeslash/", RemoveSlashHandler),
("/addslash", AddSlashHandler),
]
def test_removeslash(self):
response = self.fetch("/removeslash/", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], "/removeslash")
response = self.fetch("/removeslash/?foo=bar", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], "/removeslash?foo=bar")
def test_addslash(self):
response = self.fetch("/addslash", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], "/addslash/")
response = self.fetch("/addslash?foo=bar", follow_redirects=False)
self.assertEqual(response.code, 301)
self.assertEqual(response.headers['Location'], "/addslash/?foo=bar")

View file

@ -3,11 +3,13 @@ from __future__ import absolute_import, division, print_function, with_statement
import traceback
from tornado.concurrent import Future
from tornado import gen
from tornado.httpclient import HTTPError, HTTPRequest
from tornado.log import gen_log
from tornado.log import gen_log, app_log
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
from tornado.test.util import unittest
from tornado.web import Application, RequestHandler
from tornado.util import u
try:
import tornado.websocket
@ -33,8 +35,12 @@ class TestWebSocketHandler(WebSocketHandler):
This allows for deterministic cleanup of the associated socket.
"""
def initialize(self, close_future):
def initialize(self, close_future, compression_options=None):
self.close_future = close_future
self.compression_options = compression_options
def get_compression_options(self):
return self.compression_options
def on_close(self):
self.close_future.set_result((self.close_code, self.close_reason))
@ -45,6 +51,11 @@ class EchoHandler(TestWebSocketHandler):
self.write_message(message, isinstance(message, bytes))
class ErrorInOnMessageHandler(TestWebSocketHandler):
def on_message(self, message):
1/0
class HeaderHandler(TestWebSocketHandler):
def open(self):
try:
@ -67,7 +78,34 @@ class CloseReasonHandler(TestWebSocketHandler):
self.close(1001, "goodbye")
class WebSocketTest(AsyncHTTPTestCase):
class AsyncPrepareHandler(TestWebSocketHandler):
@gen.coroutine
def prepare(self):
yield gen.moment
def on_message(self, message):
self.write_message(message)
class WebSocketBaseTestCase(AsyncHTTPTestCase):
@gen.coroutine
def ws_connect(self, path, compression_options=None):
ws = yield websocket_connect(
'ws://localhost:%d%s' % (self.get_http_port(), path),
compression_options=compression_options)
raise gen.Return(ws)
@gen.coroutine
def close(self, ws):
"""Close a websocket connection and wait for the server side.
If we don't wait here, there are sometimes leak warnings in the
tests.
"""
ws.close()
yield self.close_future
class WebSocketTest(WebSocketBaseTestCase):
def get_app(self):
self.close_future = Future()
return Application([
@ -76,6 +114,10 @@ class WebSocketTest(AsyncHTTPTestCase):
('/header', HeaderHandler, dict(close_future=self.close_future)),
('/close_reason', CloseReasonHandler,
dict(close_future=self.close_future)),
('/error_in_on_message', ErrorInOnMessageHandler,
dict(close_future=self.close_future)),
('/async_prepare', AsyncPrepareHandler,
dict(close_future=self.close_future)),
])
def test_http_request(self):
@ -85,14 +127,11 @@ class WebSocketTest(AsyncHTTPTestCase):
@gen_test
def test_websocket_gen(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port(),
io_loop=self.io_loop)
ws = yield self.ws_connect('/echo')
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
yield self.close(ws)
def test_websocket_callbacks(self):
websocket_connect(
@ -107,20 +146,41 @@ class WebSocketTest(AsyncHTTPTestCase):
ws.close()
self.wait()
@gen_test
def test_binary_message(self):
ws = yield self.ws_connect('/echo')
ws.write_message(b'hello \xe9', binary=True)
response = yield ws.read_message()
self.assertEqual(response, b'hello \xe9')
yield self.close(ws)
@gen_test
def test_unicode_message(self):
ws = yield self.ws_connect('/echo')
ws.write_message(u('hello \u00e9'))
response = yield ws.read_message()
self.assertEqual(response, u('hello \u00e9'))
yield self.close(ws)
@gen_test
def test_error_in_on_message(self):
ws = yield self.ws_connect('/error_in_on_message')
ws.write_message('hello')
with ExpectLog(app_log, "Uncaught exception"):
response = yield ws.read_message()
self.assertIs(response, None)
yield self.close(ws)
@gen_test
def test_websocket_http_fail(self):
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(
'ws://localhost:%d/notfound' % self.get_http_port(),
io_loop=self.io_loop)
yield self.ws_connect('/notfound')
self.assertEqual(cm.exception.code, 404)
@gen_test
def test_websocket_http_success(self):
with self.assertRaises(WebSocketError):
yield websocket_connect(
'ws://localhost:%d/non_ws' % self.get_http_port(),
io_loop=self.io_loop)
yield self.ws_connect('/non_ws')
@gen_test
def test_websocket_network_fail(self):
@ -139,6 +199,7 @@ class WebSocketTest(AsyncHTTPTestCase):
'ws://localhost:%d/echo' % self.get_http_port())
ws.write_message('hello')
ws.write_message('world')
# Close the underlying stream.
ws.stream.close()
yield self.close_future
@ -150,13 +211,11 @@ class WebSocketTest(AsyncHTTPTestCase):
headers={'X-Test': 'hello'}))
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
yield self.close(ws)
@gen_test
def test_server_close_reason(self):
ws = yield websocket_connect(
'ws://localhost:%d/close_reason' % self.get_http_port())
ws = yield self.ws_connect('/close_reason')
msg = yield ws.read_message()
# A message of None means the other side closed the connection.
self.assertIs(msg, None)
@ -165,13 +224,21 @@ class WebSocketTest(AsyncHTTPTestCase):
@gen_test
def test_client_close_reason(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port())
ws = yield self.ws_connect('/echo')
ws.close(1001, 'goodbye')
code, reason = yield self.close_future
self.assertEqual(code, 1001)
self.assertEqual(reason, 'goodbye')
@gen_test
def test_async_prepare(self):
# Previously, an async prepare method triggered a bug that would
# result in a timeout on test shutdown (and a memory leak).
ws = yield self.ws_connect('/async_prepare')
ws.write_message('hello')
res = yield ws.read_message()
self.assertEqual(res, 'hello')
@gen_test
def test_check_origin_valid_no_path(self):
port = self.get_http_port()
@ -184,8 +251,7 @@ class WebSocketTest(AsyncHTTPTestCase):
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
yield self.close(ws)
@gen_test
def test_check_origin_valid_with_path(self):
@ -199,8 +265,7 @@ class WebSocketTest(AsyncHTTPTestCase):
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
ws.close()
yield self.close_future
yield self.close(ws)
@gen_test
def test_check_origin_invalid_partial_url(self):
@ -245,6 +310,78 @@ class WebSocketTest(AsyncHTTPTestCase):
self.assertEqual(cm.exception.code, 403)
class CompressionTestMixin(object):
MESSAGE = 'Hello world. Testing 123 123'
def get_app(self):
self.close_future = Future()
return Application([
('/echo', EchoHandler, dict(
close_future=self.close_future,
compression_options=self.get_server_compression_options())),
])
def get_server_compression_options(self):
return None
def get_client_compression_options(self):
return None
@gen_test
def test_message_sizes(self):
ws = yield self.ws_connect(
'/echo',
compression_options=self.get_client_compression_options())
# Send the same message three times so we can measure the
# effect of the context_takeover options.
for i in range(3):
ws.write_message(self.MESSAGE)
response = yield ws.read_message()
self.assertEqual(response, self.MESSAGE)
self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3)
self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3)
self.verify_wire_bytes(ws.protocol._wire_bytes_in,
ws.protocol._wire_bytes_out)
yield self.close(ws)
class UncompressedTestMixin(CompressionTestMixin):
"""Specialization of CompressionTestMixin when we expect no compression."""
def verify_wire_bytes(self, bytes_in, bytes_out):
# Bytes out includes the 4-byte mask key per message.
self.assertEqual(bytes_out, 3 * (len(self.MESSAGE) + 6))
self.assertEqual(bytes_in, 3 * (len(self.MESSAGE) + 2))
class NoCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
pass
# If only one side tries to compress, the extension is not negotiated.
class ServerOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
def get_server_compression_options(self):
return {}
class ClientOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
def get_client_compression_options(self):
return {}
class DefaultCompressionTest(CompressionTestMixin, WebSocketBaseTestCase):
def get_server_compression_options(self):
return {}
def get_client_compression_options(self):
return {}
def verify_wire_bytes(self, bytes_in, bytes_out):
self.assertLess(bytes_out, 3 * (len(self.MESSAGE) + 6))
self.assertLess(bytes_in, 3 * (len(self.MESSAGE) + 2))
# Bytes out includes the 4 bytes mask key per message.
self.assertEqual(bytes_out, bytes_in + 12)
class MaskFunctionMixin(object):
# Subclasses should define self.mask(mask, data)
def test_mask(self):

View file

@ -28,7 +28,7 @@ except ImportError:
IOLoop = None
netutil = None
SimpleAsyncHTTPClient = None
from tornado.log import gen_log
from tornado.log import gen_log, app_log
from tornado.stack_context import ExceptionStackContext
from tornado.util import raise_exc_info, basestring_type
import functools
@ -114,8 +114,8 @@ class _TestMethodWrapper(object):
def __init__(self, orig_method):
self.orig_method = orig_method
def __call__(self):
result = self.orig_method()
def __call__(self, *args, **kwargs):
result = self.orig_method(*args, **kwargs)
if isinstance(result, types.GeneratorType):
raise TypeError("Generator test methods should be decorated with "
"tornado.testing.gen_test")
@ -237,7 +237,11 @@ class AsyncTestCase(unittest.TestCase):
return IOLoop()
def _handle_exception(self, typ, value, tb):
if self.__failure is None:
self.__failure = (typ, value, tb)
else:
app_log.error("multiple unhandled exceptions in test",
exc_info=(typ, value, tb))
self.stop()
return True
@ -395,7 +399,8 @@ class AsyncHTTPTestCase(AsyncTestCase):
def tearDown(self):
self.http_server.stop()
self.io_loop.run_sync(self.http_server.close_all_connections)
self.io_loop.run_sync(self.http_server.close_all_connections,
timeout=get_async_test_timeout())
if (not IOLoop.initialized() or
self.http_client.io_loop is not IOLoop.instance()):
self.http_client.close()

View file

@ -115,16 +115,17 @@ def import_object(name):
if type('') is not type(b''):
def u(s):
return s
bytes_type = bytes
unicode_type = str
basestring_type = str
else:
def u(s):
return s.decode('unicode_escape')
bytes_type = str
unicode_type = unicode
basestring_type = basestring
# Deprecated alias that was used before we dropped py25 support.
# Left here in case anyone outside Tornado is using it.
bytes_type = bytes
if sys.version_info > (3,):
exec("""
@ -154,7 +155,7 @@ def errno_from_exception(e):
"""Provides the errno from an Exception object.
There are cases that the errno attribute was not set so we pull
the errno out of the args but if someone instatiates an Exception
the errno out of the args but if someone instantiates an Exception
without any args you will get a tuple error. So this function
abstracts all that behavior to give you a safe way to get the
errno.
@ -202,7 +203,7 @@ class Configurable(object):
impl = cls
args.update(kwargs)
instance = super(Configurable, cls).__new__(impl)
# initialize vs __init__ chosen for compatiblity with AsyncHTTPClient
# initialize vs __init__ chosen for compatibility with AsyncHTTPClient
# singleton magic. If we get rid of that we can switch to __init__
# here too.
instance.initialize(**args)
@ -237,7 +238,7 @@ class Configurable(object):
some parameters.
"""
base = cls.configurable_base()
if isinstance(impl, (unicode_type, bytes_type)):
if isinstance(impl, (unicode_type, bytes)):
impl = import_object(impl)
if impl is not None and not issubclass(impl, cls):
raise ValueError("Invalid subclass of %s" % cls)

View file

@ -35,8 +35,7 @@ Here is a simple "Hello, world" example app::
application.listen(8888)
tornado.ioloop.IOLoop.instance().start()
See the :doc:`Tornado overview <overview>` for more details and a good getting
started guide.
See the :doc:`guide` for additional information.
Thread-safety notes
-------------------
@ -48,6 +47,7 @@ not thread-safe. In particular, methods such as
you use multiple threads it is important to use `.IOLoop.add_callback`
to transfer control back to the main thread before finishing the
request.
"""
from __future__ import absolute_import, division, print_function, with_statement
@ -72,6 +72,7 @@ import time
import tornado
import traceback
import types
from io import BytesIO
from tornado.concurrent import Future, is_future
from tornado import escape
@ -83,12 +84,8 @@ from tornado.log import access_log, app_log, gen_log
from tornado import stack_context
from tornado import template
from tornado.escape import utf8, _unicode
from tornado.util import bytes_type, import_object, ObjectDict, raise_exc_info, unicode_type, _websocket_mask
from tornado.util import import_object, ObjectDict, raise_exc_info, unicode_type, _websocket_mask
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
try:
import Cookie # py2
@ -344,7 +341,7 @@ class RequestHandler(object):
_INVALID_HEADER_CHAR_RE = re.compile(br"[\x00-\x1f]")
def _convert_header_value(self, value):
if isinstance(value, bytes_type):
if isinstance(value, bytes):
pass
elif isinstance(value, unicode_type):
value = value.encode('utf-8')
@ -652,7 +649,7 @@ class RequestHandler(object):
raise RuntimeError("Cannot write() after finish(). May be caused "
"by using async operations without the "
"@asynchronous decorator.")
if not isinstance(chunk, (bytes_type, unicode_type, dict)):
if not isinstance(chunk, (bytes, unicode_type, dict)):
raise TypeError("write() only accepts bytes, unicode, and dict objects")
if isinstance(chunk, dict):
chunk = escape.json_encode(chunk)
@ -677,7 +674,7 @@ class RequestHandler(object):
js_embed.append(utf8(embed_part))
file_part = module.javascript_files()
if file_part:
if isinstance(file_part, (unicode_type, bytes_type)):
if isinstance(file_part, (unicode_type, bytes)):
js_files.append(file_part)
else:
js_files.extend(file_part)
@ -686,7 +683,7 @@ class RequestHandler(object):
css_embed.append(utf8(embed_part))
file_part = module.css_files()
if file_part:
if isinstance(file_part, (unicode_type, bytes_type)):
if isinstance(file_part, (unicode_type, bytes)):
css_files.append(file_part)
else:
css_files.extend(file_part)
@ -919,7 +916,7 @@ class RequestHandler(object):
return
self.clear()
reason = None
reason = kwargs.get('reason')
if 'exc_info' in kwargs:
exception = kwargs['exc_info'][1]
if isinstance(exception, HTTPError) and exception.reason:
@ -959,12 +956,15 @@ class RequestHandler(object):
@property
def locale(self):
"""The local for the current session.
"""The locale for the current session.
Determined by either `get_user_locale`, which you can override to
set the locale based on, e.g., a user preference stored in a
database, or `get_browser_locale`, which uses the ``Accept-Language``
header.
.. versionchanged: 4.1
Added a property setter.
"""
if not hasattr(self, "_locale"):
self._locale = self.get_user_locale()
@ -973,6 +973,10 @@ class RequestHandler(object):
assert self._locale
return self._locale
@locale.setter
def locale(self, value):
self._locale = value
def get_user_locale(self):
"""Override to determine the locale from the authenticated user.
@ -1128,14 +1132,15 @@ class RequestHandler(object):
else:
# Treat unknown versions as not present instead of failing.
return None, None, None
elif len(cookie) == 32:
else:
version = 1
try:
token = binascii.a2b_hex(utf8(cookie))
except (binascii.Error, TypeError):
token = utf8(cookie)
# We don't have a usable timestamp in older versions.
timestamp = int(time.time())
return (version, token, timestamp)
else:
return None, None, None
def check_xsrf_cookie(self):
"""Verifies that the ``_xsrf`` cookie matches the ``_xsrf`` argument.
@ -1627,7 +1632,7 @@ class Application(httputil.HTTPServerConnectionDelegate):
**settings):
if transforms is None:
self.transforms = []
if settings.get("gzip"):
if settings.get("compress_response") or settings.get("gzip"):
self.transforms.append(GZipContentEncoding)
else:
self.transforms = transforms
@ -2164,11 +2169,14 @@ class StaticFileHandler(RequestHandler):
if include_body:
content = self.get_content(self.absolute_path, start, end)
if isinstance(content, bytes_type):
if isinstance(content, bytes):
content = [content]
for chunk in content:
try:
self.write(chunk)
yield self.flush()
except iostream.StreamClosedError:
return
else:
assert self.request.method == "HEAD"
@ -2335,7 +2343,7 @@ class StaticFileHandler(RequestHandler):
"""
data = cls.get_content(abspath)
hasher = hashlib.md5()
if isinstance(data, bytes_type):
if isinstance(data, bytes):
hasher.update(data)
else:
for chunk in data:
@ -2547,7 +2555,6 @@ class GZipContentEncoding(OutputTransform):
ctype = _unicode(headers.get("Content-Type", "")).split(";")[0]
self._gzipping = self._compressible_type(ctype) and \
(not finishing or len(chunk) >= self.MIN_LENGTH) and \
(finishing or "Content-Length" not in headers) and \
("Content-Encoding" not in headers)
if self._gzipping:
headers["Content-Encoding"] = "gzip"
@ -2555,7 +2562,14 @@ class GZipContentEncoding(OutputTransform):
self._gzip_file = gzip.GzipFile(mode="w", fileobj=self._gzip_value)
chunk = self.transform_chunk(chunk, finishing)
if "Content-Length" in headers:
# The original content length is no longer correct.
# If this is the last (and only) chunk, we can set the new
# content-length; otherwise we remove it and fall back to
# chunked encoding.
if finishing:
headers["Content-Length"] = str(len(chunk))
else:
del headers["Content-Length"]
return status_code, headers, chunk
def transform_chunk(self, chunk, finishing):
@ -2704,7 +2718,7 @@ class TemplateModule(UIModule):
def javascript_files(self):
result = []
for f in self._get_resources("javascript_files"):
if isinstance(f, (unicode_type, bytes_type)):
if isinstance(f, (unicode_type, bytes)):
result.append(f)
else:
result.extend(f)
@ -2716,7 +2730,7 @@ class TemplateModule(UIModule):
def css_files(self):
result = []
for f in self._get_resources("css_files"):
if isinstance(f, (unicode_type, bytes_type)):
if isinstance(f, (unicode_type, bytes)):
result.append(f)
else:
result.extend(f)
@ -2754,7 +2768,7 @@ class URLSpec(object):
in the regex will be passed in to the handler's get/post/etc
methods as arguments.
* ``handler_class``: `RequestHandler` subclass to be invoked.
* ``handler``: `RequestHandler` subclass to be invoked.
* ``kwargs`` (optional): A dictionary of additional arguments
to be passed to the handler's constructor.
@ -2821,7 +2835,7 @@ class URLSpec(object):
return self._path
converted_args = []
for a in args:
if not isinstance(a, (unicode_type, bytes_type)):
if not isinstance(a, (unicode_type, bytes)):
a = str(a)
converted_args.append(escape.url_escape(utf8(a), plus=False))
return self._path % tuple(converted_args)

View file

@ -26,6 +26,7 @@ import os
import struct
import tornado.escape
import tornado.web
import zlib
from tornado.concurrent import TracebackFuture
from tornado.escape import utf8, native_str, to_unicode
@ -35,7 +36,7 @@ from tornado.iostream import StreamClosedError
from tornado.log import gen_log, app_log
from tornado import simple_httpclient
from tornado.tcpclient import TCPClient
from tornado.util import bytes_type, _websocket_mask
from tornado.util import _websocket_mask
try:
from urllib.parse import urlparse # py2
@ -105,6 +106,21 @@ class WebSocketHandler(tornado.web.RequestHandler):
};
This script pops up an alert box that says "You said: Hello, world".
Web browsers allow any site to open a websocket connection to any other,
instead of using the same-origin policy that governs other network
access from javascript. This can be surprising and is a potential
security hole, so since Tornado 4.0 `WebSocketHandler` requires
applications that wish to receive cross-origin websockets to opt in
by overriding the `~WebSocketHandler.check_origin` method (see that
method's docs for details). Failure to do so is the most likely
cause of 403 errors when making a websocket connection.
When using a secure websocket connection (``wss://``) with a self-signed
certificate, the connection from a browser may fail because it wants
to show the "accept this certificate" dialog but has nowhere to show it.
You must first visit a regular HTML page using the same certificate
to accept it before the websocket connection will succeed.
"""
def __init__(self, application, request, **kwargs):
tornado.web.RequestHandler.__init__(self, application, request,
@ -156,9 +172,11 @@ class WebSocketHandler(tornado.web.RequestHandler):
self.stream.set_close_callback(self.on_connection_close)
if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
self.ws_connection = WebSocketProtocol13(self)
self.ws_connection = WebSocketProtocol13(
self, compression_options=self.get_compression_options())
self.ws_connection.accept_connection()
else:
if not self.stream.closed():
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 426 Upgrade Required\r\n"
"Sec-WebSocket-Version: 8\r\n\r\n"))
@ -198,6 +216,19 @@ class WebSocketHandler(tornado.web.RequestHandler):
"""
return None
def get_compression_options(self):
"""Override to return compression options for the connection.
If this method returns None (the default), compression will
be disabled. If it returns a dict (even an empty one), it
will be enabled. The contents of the dict may be used to
control the memory and CPU usage of the compression,
but no such options are currently implemented.
.. versionadded:: 4.1
"""
return None
def open(self):
"""Invoked when a new WebSocket is opened.
@ -275,6 +306,19 @@ class WebSocketHandler(tornado.web.RequestHandler):
browsers, since WebSockets are allowed to bypass the usual same-origin
policies and don't use CORS headers.
To accept all cross-origin traffic (which was the default prior to
Tornado 4.0), simply override this method to always return true::
def check_origin(self, origin):
return True
To allow connections from any subdomain of your site, you might
do something like::
def check_origin(self, origin):
parsed_origin = urllib.parse.urlparse(origin)
return parsed_origin.netloc.endswith(".mydomain.com")
.. versionadded:: 4.0
"""
parsed_origin = urlparse(origin)
@ -308,6 +352,15 @@ class WebSocketHandler(tornado.web.RequestHandler):
self.ws_connection = None
self.on_close()
def send_error(self, *args, **kwargs):
if self.stream is None:
super(WebSocketHandler, self).send_error(*args, **kwargs)
else:
# If we get an uncaught exception during the handshake,
# we have no choice but to abruptly close the connection.
# TODO: for uncaught exceptions after the handshake,
# we can close the connection more gracefully.
self.stream.close()
def _wrap_method(method):
def _disallow_for_websocket(self, *args, **kwargs):
@ -316,7 +369,7 @@ def _wrap_method(method):
else:
raise RuntimeError("Method not supported for Web Sockets")
return _disallow_for_websocket
for method in ["write", "redirect", "set_header", "send_error", "set_cookie",
for method in ["write", "redirect", "set_header", "set_cookie",
"set_status", "flush", "finish"]:
setattr(WebSocketHandler, method,
_wrap_method(getattr(WebSocketHandler, method)))
@ -355,13 +408,68 @@ class WebSocketProtocol(object):
self.close() # let the subclass cleanup
class _PerMessageDeflateCompressor(object):
def __init__(self, persistent, max_wbits):
if max_wbits is None:
max_wbits = zlib.MAX_WBITS
# There is no symbolic constant for the minimum wbits value.
if not (8 <= max_wbits <= zlib.MAX_WBITS):
raise ValueError("Invalid max_wbits value %r; allowed range 8-%d",
max_wbits, zlib.MAX_WBITS)
self._max_wbits = max_wbits
if persistent:
self._compressor = self._create_compressor()
else:
self._compressor = None
def _create_compressor(self):
return zlib.compressobj(-1, zlib.DEFLATED, -self._max_wbits)
def compress(self, data):
compressor = self._compressor or self._create_compressor()
data = (compressor.compress(data) +
compressor.flush(zlib.Z_SYNC_FLUSH))
assert data.endswith(b'\x00\x00\xff\xff')
return data[:-4]
class _PerMessageDeflateDecompressor(object):
def __init__(self, persistent, max_wbits):
if max_wbits is None:
max_wbits = zlib.MAX_WBITS
if not (8 <= max_wbits <= zlib.MAX_WBITS):
raise ValueError("Invalid max_wbits value %r; allowed range 8-%d",
max_wbits, zlib.MAX_WBITS)
self._max_wbits = max_wbits
if persistent:
self._decompressor = self._create_decompressor()
else:
self._decompressor = None
def _create_decompressor(self):
return zlib.decompressobj(-self._max_wbits)
def decompress(self, data):
decompressor = self._decompressor or self._create_decompressor()
return decompressor.decompress(data + b'\x00\x00\xff\xff')
class WebSocketProtocol13(WebSocketProtocol):
"""Implementation of the WebSocket protocol from RFC 6455.
This class supports versions 7 and 8 of the protocol in addition to the
final version 13.
"""
def __init__(self, handler, mask_outgoing=False):
# Bit masks for the first byte of a frame.
FIN = 0x80
RSV1 = 0x40
RSV2 = 0x20
RSV3 = 0x10
RSV_MASK = RSV1 | RSV2 | RSV3
OPCODE_MASK = 0x0f
def __init__(self, handler, mask_outgoing=False,
compression_options=None):
WebSocketProtocol.__init__(self, handler)
self.mask_outgoing = mask_outgoing
self._final_frame = False
@ -372,6 +480,19 @@ class WebSocketProtocol13(WebSocketProtocol):
self._fragmented_message_buffer = None
self._fragmented_message_opcode = None
self._waiting = None
self._compression_options = compression_options
self._decompressor = None
self._compressor = None
self._frame_compressed = None
# The total uncompressed size of all messages received or sent.
# Unicode messages are encoded to utf8.
# Only for testing; subject to change.
self._message_bytes_in = 0
self._message_bytes_out = 0
# The total size of all packets received or sent. Includes
# the effect of compression, frame overhead, and control frames.
self._wire_bytes_in = 0
self._wire_bytes_out = 0
def accept_connection(self):
try:
@ -416,24 +537,99 @@ class WebSocketProtocol13(WebSocketProtocol):
assert selected in subprotocols
subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
extension_header = ''
extensions = self._parse_extensions_header(self.request.headers)
for ext in extensions:
if (ext[0] == 'permessage-deflate' and
self._compression_options is not None):
# TODO: negotiate parameters if compression_options
# specifies limits.
self._create_compressors('server', ext[1])
if ('client_max_window_bits' in ext[1] and
ext[1]['client_max_window_bits'] is None):
# Don't echo an offered client_max_window_bits
# parameter with no value.
del ext[1]['client_max_window_bits']
extension_header = ('Sec-WebSocket-Extensions: %s\r\n' %
httputil._encode_header(
'permessage-deflate', ext[1]))
break
if self.stream.closed():
self._abort()
return
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Accept: %s\r\n"
"%s"
"\r\n" % (self._challenge_response(), subprotocol_header)))
"%s%s"
"\r\n" % (self._challenge_response(),
subprotocol_header, extension_header)))
self._run_callback(self.handler.open, *self.handler.open_args,
**self.handler.open_kwargs)
self._receive_frame()
def _write_frame(self, fin, opcode, data):
def _parse_extensions_header(self, headers):
extensions = headers.get("Sec-WebSocket-Extensions", '')
if extensions:
return [httputil._parse_header(e.strip())
for e in extensions.split(',')]
return []
def _process_server_headers(self, key, headers):
"""Process the headers sent by the server to this client connection.
'key' is the websocket handshake challenge/response key.
"""
assert headers['Upgrade'].lower() == 'websocket'
assert headers['Connection'].lower() == 'upgrade'
accept = self.compute_accept_value(key)
assert headers['Sec-Websocket-Accept'] == accept
extensions = self._parse_extensions_header(headers)
for ext in extensions:
if (ext[0] == 'permessage-deflate' and
self._compression_options is not None):
self._create_compressors('client', ext[1])
else:
raise ValueError("unsupported extension %r", ext)
def _get_compressor_options(self, side, agreed_parameters):
"""Converts a websocket agreed_parameters set to keyword arguments
for our compressor objects.
"""
options = dict(
persistent=(side + '_no_context_takeover') not in agreed_parameters)
wbits_header = agreed_parameters.get(side + '_max_window_bits', None)
if wbits_header is None:
options['max_wbits'] = zlib.MAX_WBITS
else:
options['max_wbits'] = int(wbits_header)
return options
def _create_compressors(self, side, agreed_parameters):
# TODO: handle invalid parameters gracefully
allowed_keys = set(['server_no_context_takeover',
'client_no_context_takeover',
'server_max_window_bits',
'client_max_window_bits'])
for key in agreed_parameters:
if key not in allowed_keys:
raise ValueError("unsupported compression parameter %r" % key)
other_side = 'client' if (side == 'server') else 'server'
self._compressor = _PerMessageDeflateCompressor(
**self._get_compressor_options(side, agreed_parameters))
self._decompressor = _PerMessageDeflateDecompressor(
**self._get_compressor_options(other_side, agreed_parameters))
def _write_frame(self, fin, opcode, data, flags=0):
if fin:
finbit = 0x80
finbit = self.FIN
else:
finbit = 0
frame = struct.pack("B", finbit | opcode)
frame = struct.pack("B", finbit | opcode | flags)
l = len(data)
if self.mask_outgoing:
mask_bit = 0x80
@ -449,7 +645,11 @@ class WebSocketProtocol13(WebSocketProtocol):
mask = os.urandom(4)
data = mask + _websocket_mask(mask, data)
frame += data
self._wire_bytes_out += len(frame)
try:
self.stream.write(frame)
except StreamClosedError:
self._abort()
def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket."""
@ -458,15 +658,17 @@ class WebSocketProtocol13(WebSocketProtocol):
else:
opcode = 0x1
message = tornado.escape.utf8(message)
assert isinstance(message, bytes_type)
try:
self._write_frame(True, opcode, message)
except StreamClosedError:
self._abort()
assert isinstance(message, bytes)
self._message_bytes_out += len(message)
flags = 0
if self._compressor:
message = self._compressor.compress(message)
flags |= self.RSV1
self._write_frame(True, opcode, message, flags=flags)
def write_ping(self, data):
"""Send ping frame."""
assert isinstance(data, bytes_type)
assert isinstance(data, bytes)
self._write_frame(True, 0x9, data)
def _receive_frame(self):
@ -476,11 +678,15 @@ class WebSocketProtocol13(WebSocketProtocol):
self._abort()
def _on_frame_start(self, data):
self._wire_bytes_in += len(data)
header, payloadlen = struct.unpack("BB", data)
self._final_frame = header & 0x80
reserved_bits = header & 0x70
self._frame_opcode = header & 0xf
self._final_frame = header & self.FIN
reserved_bits = header & self.RSV_MASK
self._frame_opcode = header & self.OPCODE_MASK
self._frame_opcode_is_control = self._frame_opcode & 0x8
if self._decompressor is not None:
self._frame_compressed = bool(reserved_bits & self.RSV1)
reserved_bits &= ~self.RSV1
if reserved_bits:
# client is using as-yet-undefined extensions; abort
self._abort()
@ -506,6 +712,7 @@ class WebSocketProtocol13(WebSocketProtocol):
self._abort()
def _on_frame_length_16(self, data):
self._wire_bytes_in += len(data)
self._frame_length = struct.unpack("!H", data)[0]
try:
if self._masked_frame:
@ -516,6 +723,7 @@ class WebSocketProtocol13(WebSocketProtocol):
self._abort()
def _on_frame_length_64(self, data):
self._wire_bytes_in += len(data)
self._frame_length = struct.unpack("!Q", data)[0]
try:
if self._masked_frame:
@ -526,6 +734,7 @@ class WebSocketProtocol13(WebSocketProtocol):
self._abort()
def _on_masking_key(self, data):
self._wire_bytes_in += len(data)
self._frame_mask = data
try:
self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
@ -533,9 +742,11 @@ class WebSocketProtocol13(WebSocketProtocol):
self._abort()
def _on_masked_frame_data(self, data):
# Don't touch _wire_bytes_in; we'll do it in _on_frame_data.
self._on_frame_data(_websocket_mask(self._frame_mask, data))
def _on_frame_data(self, data):
self._wire_bytes_in += len(data)
if self._frame_opcode_is_control:
# control frames may be interleaved with a series of fragmented
# data frames, so control frames must not interact with
@ -576,8 +787,12 @@ class WebSocketProtocol13(WebSocketProtocol):
if self.client_terminated:
return
if self._frame_compressed:
data = self._decompressor.decompress(data)
if opcode == 0x1:
# UTF-8 data
self._message_bytes_in += len(data)
try:
decoded = data.decode("utf-8")
except UnicodeDecodeError:
@ -586,7 +801,8 @@ class WebSocketProtocol13(WebSocketProtocol):
self._run_callback(self.handler.on_message, decoded)
elif opcode == 0x2:
# Binary data
self._run_callback(self.handler.on_message, decoded)
self._message_bytes_in += len(data)
self._run_callback(self.handler.on_message, data)
elif opcode == 0x8:
# Close
self.client_terminated = True
@ -636,7 +852,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
This class should not be instantiated directly; use the
`websocket_connect` function instead.
"""
def __init__(self, io_loop, request):
def __init__(self, io_loop, request, compression_options=None):
self.compression_options = compression_options
self.connect_future = TracebackFuture()
self.read_future = None
self.read_queue = collections.deque()
@ -651,6 +868,14 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
'Sec-WebSocket-Key': self.key,
'Sec-WebSocket-Version': '13',
})
if self.compression_options is not None:
# Always offer to let the server set our max_wbits (and even though
# we don't offer it, we will accept a client_no_context_takeover
# from the server).
# TODO: set server parameters for deflate extension
# if requested in self.compression_options.
request.headers['Sec-WebSocket-Extensions'] = (
'permessage-deflate; client_max_window_bits')
self.tcp_client = TCPClient(io_loop=io_loop)
super(WebSocketClientConnection, self).__init__(
@ -673,10 +898,12 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
self.protocol.close(code, reason)
self.protocol = None
def _on_close(self):
def on_connection_close(self):
if not self.connect_future.done():
self.connect_future.set_exception(StreamClosedError())
self.on_message(None)
self.resolver.close()
super(WebSocketClientConnection, self)._on_close()
self.tcp_client.close()
super(WebSocketClientConnection, self).on_connection_close()
def _on_http_response(self, response):
if not self.connect_future.done():
@ -692,12 +919,10 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
start_line, headers)
self.headers = headers
assert self.headers['Upgrade'].lower() == 'websocket'
assert self.headers['Connection'].lower() == 'upgrade'
accept = WebSocketProtocol13.compute_accept_value(self.key)
assert self.headers['Sec-Websocket-Accept'] == accept
self.protocol = WebSocketProtocol13(self, mask_outgoing=True)
self.protocol = WebSocketProtocol13(
self, mask_outgoing=True,
compression_options=self.compression_options)
self.protocol._process_server_headers(self.key, self.headers)
self.protocol._receive_frame()
if self._timeout is not None:
@ -705,7 +930,12 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
self._timeout = None
self.stream = self.connection.detach()
self.stream.set_close_callback(self._on_close)
self.stream.set_close_callback(self.on_connection_close)
# Once we've taken over the connection, clear the final callback
# we set on the http request. This deactivates the error handling
# in simple_httpclient that would otherwise interfere with our
# ability to see exceptions.
self.final_callback = None
self.connect_future.set_result(self)
@ -742,14 +972,21 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
pass
def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None):
def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
compression_options=None):
"""Client-side websocket support.
Takes a url and returns a Future whose result is a
`WebSocketClientConnection`.
``compression_options`` is interpreted in the same way as the
return value of `.WebSocketHandler.get_compression_options`.
.. versionchanged:: 3.2
Also accepts ``HTTPRequest`` objects in place of urls.
.. versionchanged:: 4.1
Added ``compression_options``.
"""
if io_loop is None:
io_loop = IOLoop.current()
@ -763,7 +1000,7 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None):
request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
request = httpclient._RequestProxy(
request, httpclient.HTTPRequest._DEFAULTS)
conn = WebSocketClientConnection(io_loop, request)
conn = WebSocketClientConnection(io_loop, request, compression_options)
if callback is not None:
io_loop.add_future(conn.connect_future, callback)
return conn.connect_future

View file

@ -32,6 +32,7 @@ provides WSGI support in two ways:
from __future__ import absolute_import, division, print_function, with_statement
import sys
from io import BytesIO
import tornado
from tornado.concurrent import Future
@ -40,12 +41,8 @@ from tornado import httputil
from tornado.log import access_log
from tornado import web
from tornado.escape import native_str
from tornado.util import bytes_type, unicode_type
from tornado.util import unicode_type
try:
from io import BytesIO # python 3
except ImportError:
from cStringIO import StringIO as BytesIO # python 2
try:
import urllib.parse as urllib_parse # py3
@ -58,7 +55,7 @@ except ImportError:
# here to minimize the temptation to use them in non-wsgi contexts.
if str is unicode_type:
def to_wsgi_str(s):
assert isinstance(s, bytes_type)
assert isinstance(s, bytes)
return s.decode('latin1')
def from_wsgi_str(s):
@ -66,7 +63,7 @@ if str is unicode_type:
return s.encode('latin1')
else:
def to_wsgi_str(s):
assert isinstance(s, bytes_type)
assert isinstance(s, bytes)
return s
def from_wsgi_str(s):