Updated tornado source code.

Created custom class for rss feed parser.
This commit is contained in:
echel0n 2014-06-29 03:05:33 -07:00
parent feabf20c8c
commit 0c57676aed
25 changed files with 393 additions and 457 deletions

View file

@ -1,21 +1,36 @@
var message_url = sbRoot + '/ui/get_messages'; var message_url = sbRoot + '/ui/get_messages/';
$.pnotify.defaults.pnotify_width = "340px"; $.pnotify.defaults.width = "400px";
$.pnotify.defaults.pnotify_history = false; $.pnotify.defaults.styling = "jqueryui";
$.pnotify.defaults.pnotify_delay = 4000; $.pnotify.defaults.history = false;
$.pnotify.defaults.shadow = false;
$.pnotify.defaults.delay = 4000;
$.pnotify.defaults.maxonscreen = 5;
function check_notifications() { function check_notifications() {
$.getJSON(message_url, function(data){ var poll_interval = 5000;
$.ajax({
url: message_url,
success: function (data) {
poll_interval = 5000;
$.each(data, function (name, data) { $.each(data, function (name, data) {
$.pnotify({ $.pnotify({
pnotify_type: data.type, type: data.type,
pnotify_hide: data.type == 'notice', hide: data.type == 'notice',
pnotify_title: data.title, title: data.title,
pnotify_text: data.message text: data.message
}); });
}); });
},
error: function () {
poll_interval = 15000;
},
type: "GET",
dataType: "json",
complete: function () {
setTimeout(check_notifications, poll_interval);
},
timeout: 15000 // timeout every 15 secs
}); });
setTimeout(check_notifications, 3000)
} }
$(document).ready(function () { $(document).ready(function () {

View file

@ -756,8 +756,7 @@ def initialize(consoleLogging=True):
USE_PUSHOVER = bool(check_setting_int(CFG, 'Pushover', 'use_pushover', 0)) USE_PUSHOVER = bool(check_setting_int(CFG, 'Pushover', 'use_pushover', 0))
PUSHOVER_NOTIFY_ONSNATCH = bool(check_setting_int(CFG, 'Pushover', 'pushover_notify_onsnatch', 0)) PUSHOVER_NOTIFY_ONSNATCH = bool(check_setting_int(CFG, 'Pushover', 'pushover_notify_onsnatch', 0))
PUSHOVER_NOTIFY_ONDOWNLOAD = bool(check_setting_int(CFG, 'Pushover', 'pushover_notify_ondownload', 0)) PUSHOVER_NOTIFY_ONDOWNLOAD = bool(check_setting_int(CFG, 'Pushover', 'pushover_notify_ondownload', 0))
PUSHOVER_NOTIFY_ONSUBTITLEDOWNLOAD = bool( PUSHOVER_NOTIFY_ONSUBTITLEDOWNLOAD = bool(check_setting_int(CFG, 'Pushover', 'pushover_notify_onsubtitledownload', 0))
check_setting_int(CFG, 'Pushover', 'pushover_notify_onsubtitledownload', 0))
PUSHOVER_USERKEY = check_setting_str(CFG, 'Pushover', 'pushover_userkey', '') PUSHOVER_USERKEY = check_setting_str(CFG, 'Pushover', 'pushover_userkey', '')
PUSHOVER_APIKEY = check_setting_str(CFG, 'Pushover', 'pushover_apikey', '') PUSHOVER_APIKEY = check_setting_str(CFG, 'Pushover', 'pushover_apikey', '')
USE_LIBNOTIFY = bool(check_setting_int(CFG, 'Libnotify', 'use_libnotify', 0)) USE_LIBNOTIFY = bool(check_setting_int(CFG, 'Libnotify', 'use_libnotify', 0))

View file

@ -263,9 +263,9 @@ def download_file(url, filename):
def findCertainShow(showList, indexerid=None): def findCertainShow(showList, indexerid=None):
if indexerid: if indexerid:
results = filter(lambda x: x.indexerid == indexerid, showList) results = filter(lambda x: int(x.indexerid) == int(indexerid), showList)
else: else:
results = filter(lambda x: x.indexerid == indexerid, showList) results = filter(lambda x: int(x.indexerid) == int(indexerid), showList)
if len(results) == 0: if len(results) == 0:
return None return None

View file

@ -21,6 +21,7 @@ import generic
from sickbeard import logger from sickbeard import logger
from sickbeard import tvcache from sickbeard import tvcache
from sickbeard.exceptions import AuthException
class WombleProvider(generic.NZBProvider): class WombleProvider(generic.NZBProvider):
@ -40,13 +41,40 @@ class WombleCache(tvcache.TVCache):
# only poll Womble's Index every 15 minutes max # only poll Womble's Index every 15 minutes max
self.minTime = 15 self.minTime = 15
def _getRSSData(self): def updateCache(self):
url = self.provider.url + 'rss/?sec=TV-x264&fr=false'
# delete anything older then 7 days
logger.log(u"Clearing " + self.provider.name + " cache")
self._clearCache()
data = None
if not self.shouldUpdate():
for url in [self.provider.url + 'rss/?sec=tv-sd&fr=false', self.provider.url + 'rss/?sec=tv-hd&fr=false']:
logger.log(u"Womble's Index cache update URL: " + url, logger.DEBUG) logger.log(u"Womble's Index cache update URL: " + url, logger.DEBUG)
return self.getRSSFeed(url) data = self.getRSSFeed(url)
# As long as we got something from the provider we count it as an update
if not data:
return []
# By now we know we've got data and no auth errors, all we need to do is put it in the database
cl = []
for item in data.entries:
ci = self._parseItem(item)
if ci is not None:
cl.append(ci)
if cl:
myDB = self._getDB()
myDB.mass_action(cl)
# set last updated
if data:
self.setLastUpdate()
def _checkAuth(self, data): def _checkAuth(self, data):
return data != 'Invalid Link' return data != 'Invalid Link'
provider = WombleProvider() provider = WombleProvider()

62
sickbeard/rssfeeds.py Normal file
View file

@ -0,0 +1,62 @@
import os
import threading
import urllib
import urlparse
import re
import time
import sickbeard
from sickbeard import logger
from sickbeard import encodingKludge as ek
from sickbeard.exceptions import ex
from lib.shove import Shove
from lib.feedcache import cache
feed_lock = threading.Lock()
class RSSFeeds:
def __init__(self, db_name):
try:
self.fs = self.fs = Shove('sqlite:///' + ek.ek(os.path.join, sickbeard.CACHE_DIR, db_name + '.db'), compress=True)
self.fc = cache.Cache(self.fs)
except Exception, e:
logger.log(u"RSS error: " + ex(e), logger.ERROR)
raise
def __enter__(self):
return self
def __exit__(self, type, value, tb):
self.fc = None
self.fs.close()
def clearCache(self, age=None):
if not self.fc:
return
self.fc.purge(age)
def getRSSFeed(self, url, post_data=None):
if not self.fc:
return
with feed_lock:
parsed = list(urlparse.urlparse(url))
parsed[2] = re.sub("/{2,}", "/", parsed[2]) # replace two or more / with one
if post_data:
url += urllib.urlencode(post_data)
feed = self.fc.fetch(url)
if not feed:
logger.log(u"RSS Error loading URL: " + url, logger.ERROR)
return
elif 'error' in feed.feed:
logger.log(u"RSS ERROR:[%s] CODE:[%s]" % (feed.feed['error']['description'], feed.feed['error']['code']),
logger.DEBUG)
return
elif not feed.entries:
logger.log(u"No RSS items found using URL: " + url, logger.WARNING)
return
return feed

View file

@ -22,29 +22,21 @@ import os
import time import time
import datetime import datetime
import urllib
import urlparse
import re
import threading import threading
import sickbeard import sickbeard
from lib.shove import Shove
from lib.feedcache import cache
from sickbeard import db from sickbeard import db
from sickbeard import logger from sickbeard import logger
from sickbeard.common import Quality, cpu_presets from sickbeard.common import Quality
from sickbeard import helpers, show_name_helpers from sickbeard import helpers, show_name_helpers
from sickbeard.exceptions import MultipleShowObjectsException from sickbeard.exceptions import MultipleShowObjectsException
from sickbeard.exceptions import AuthException from sickbeard.exceptions import AuthException
from sickbeard import encodingKludge as ek
from name_parser.parser import NameParser, InvalidNameException from name_parser.parser import NameParser, InvalidNameException
from sickbeard.rssfeeds import RSSFeeds
cache_lock = threading.Lock() cache_lock = threading.Lock()
class CacheDBConnection(db.DBConnection): class CacheDBConnection(db.DBConnection):
def __init__(self, providerName): def __init__(self, providerName):
db.DBConnection.__init__(self, "cache.db") db.DBConnection.__init__(self, "cache.db")
@ -87,14 +79,16 @@ class TVCache():
return CacheDBConnection(self.providerID) return CacheDBConnection(self.providerID)
def _clearCache(self): def _clearCache(self):
if not self.shouldClearCache(): if self.shouldClearCache():
return
curDate = datetime.date.today() - datetime.timedelta(weeks=1) curDate = datetime.date.today() - datetime.timedelta(weeks=1)
myDB = self._getDB() myDB = self._getDB()
myDB.action("DELETE FROM [" + self.providerID + "] WHERE time < ?", [int(time.mktime(curDate.timetuple()))]) myDB.action("DELETE FROM [" + self.providerID + "] WHERE time < ?", [int(time.mktime(curDate.timetuple()))])
# clear RSS Feed cache
with RSSFeeds(self.providerID) as feed:
feed.clearCache(int(time.mktime(curDate.timetuple())))
def _getRSSData(self): def _getRSSData(self):
data = None data = None
@ -126,9 +120,8 @@ class TVCache():
return [] return []
if self._checkAuth(data): if self._checkAuth(data):
items = data.entries
cl = [] cl = []
for item in items: for item in data.entries:
ci = self._parseItem(item) ci = self._parseItem(item)
if ci is not None: if ci is not None:
cl.append(ci) cl.append(ci)
@ -143,34 +136,10 @@ class TVCache():
return [] return []
def getRSSFeed(self, url, post_data=None, request_headers=None): def getRSSFeed(self, url, post_data=None):
# create provider storaqe cache with RSSFeeds(self.providerID) as feed:
storage = Shove('sqlite:///' + ek.ek(os.path.join, sickbeard.CACHE_DIR, self.provider.name) + '.db') data = feed.getRSSFeed(url, post_data)
fc = cache.Cache(storage) return data
parsed = list(urlparse.urlparse(url))
parsed[2] = re.sub("/{2,}", "/", parsed[2]) # replace two or more / with one
if post_data:
url += urllib.urlencode(post_data)
f = fc.fetch(url, request_headers=request_headers)
if not f:
logger.log(u"Error loading " + self.providerID + " URL: " + url, logger.ERROR)
return None
elif 'error' in f.feed:
logger.log(u"Newznab ERROR:[%s] CODE:[%s]" % (f.feed['error']['description'], f.feed['error']['code']),
logger.DEBUG)
return None
elif not f.entries:
logger.log(u"No items found on " + self.providerID + " using URL: " + url, logger.WARNING)
return None
storage.close()
return f
def _translateTitle(self, title): def _translateTitle(self, title):
return title.replace(' ', '.') return title.replace(' ', '.')

View file

@ -19,6 +19,7 @@
from __future__ import with_statement from __future__ import with_statement
import base64 import base64
import inspect import inspect
import traceback
import urlparse import urlparse
import zipfile import zipfile
@ -136,7 +137,9 @@ class MainHandler(RequestHandler):
super(MainHandler, self).__init__(application, request, **kwargs) super(MainHandler, self).__init__(application, request, **kwargs)
global req_headers global req_headers
sickbeard.REMOTE_IP = self.request.remote_ip sickbeard.REMOTE_IP = self.request.headers.get('X-Forwarded-For',
self.request.headers.get('X-Real-Ip', self.request.remote_ip))
req_headers = self.request.headers req_headers = self.request.headers
def http_error_401_handler(self): def http_error_401_handler(self):
@ -158,11 +161,12 @@ class MainHandler(RequestHandler):
return self.redirectTo('/home/') return self.redirectTo('/home/')
def write_error(self, status_code, **kwargs): def write_error(self, status_code, **kwargs):
if status_code == 404: if status_code == 401:
return self.redirectTo('/home/')
elif status_code == 401:
self.finish(self.http_error_401_handler()) self.finish(self.http_error_401_handler())
elif status_code == 404:
self.redirectTo('/home/')
else: else:
logger.log(traceback.format_exc(), logger.DEBUG)
super(MainHandler, self).write_error(status_code, **kwargs) super(MainHandler, self).write_error(status_code, **kwargs)
def _dispatch(self): def _dispatch(self):
@ -209,22 +213,17 @@ class MainHandler(RequestHandler):
raise HTTPError(404) raise HTTPError(404)
def redirectTo(self, url): def redirectTo(self, url):
self._transforms = []
url = urlparse.urljoin(sickbeard.WEB_ROOT, url) url = urlparse.urljoin(sickbeard.WEB_ROOT, url)
logger.log(u"Redirecting to: " + url, logger.DEBUG) logger.log(u"Redirecting to: " + url, logger.DEBUG)
self.redirect(url, status=303) self._transforms = []
self.redirect(url)
def get(self, *args, **kwargs): def get(self, *args, **kwargs):
response = self._dispatch() self.write(self._dispatch())
if response:
self.finish(response)
def post(self, *args, **kwargs): def post(self, *args, **kwargs):
response = self._dispatch() self._dispatch()
if response:
self.finish(response)
def robots_txt(self, *args, **kwargs): def robots_txt(self, *args, **kwargs):
""" Keep web crawlers out """ """ Keep web crawlers out """
@ -456,13 +455,13 @@ class MainHandler(RequestHandler):
browser = WebFileBrowser browser = WebFileBrowser
class PageTemplate(Template): class PageTemplate(Template):
def __init__(self, *args, **KWs): def __init__(self, *args, **KWs):
global req_headers
KWs['file'] = os.path.join(sickbeard.PROG_DIR, "gui/" + sickbeard.GUI_NAME + "/interfaces/default/", KWs['file'] = os.path.join(sickbeard.PROG_DIR, "gui/" + sickbeard.GUI_NAME + "/interfaces/default/",
KWs['file']) KWs['file'])
super(PageTemplate, self).__init__(*args, **KWs) super(PageTemplate, self).__init__(*args, **KWs)
global req_headers
self.sbRoot = sickbeard.WEB_ROOT self.sbRoot = sickbeard.WEB_ROOT
self.sbHttpPort = sickbeard.WEB_PORT self.sbHttpPort = sickbeard.WEB_PORT
@ -512,9 +511,7 @@ class IndexerWebUI(MainHandler):
def _munge(string): def _munge(string):
to_return = unicode(string).encode('utf-8', 'xmlcharrefreplace') return unicode(string).encode('utf-8', 'xmlcharrefreplace')
return to_return
def _genericMessage(subject, message): def _genericMessage(subject, message):
t = PageTemplate(file="genericMessage.tmpl") t = PageTemplate(file="genericMessage.tmpl")
@ -4296,15 +4293,15 @@ class Home(MainHandler):
return json.dumps({'result': 'failure'}) return json.dumps({'result': 'failure'})
class UI(MainHandler): class UI(MainHandler):
def add_message(self, *args, **kwargs): def add_message(self):
ui.notifications.message('Test 1', 'This is test number 1') ui.notifications.message('Test 1', 'This is test number 1')
ui.notifications.error('Test 2', 'This is test number 2') ui.notifications.error('Test 2', 'This is test number 2')
return "ok" return "ok"
def get_messages(self, *args, **kwargs): def get_messages(self):
messages = {} messages = {}
cur_notification_num = 1 cur_notification_num = 1
for cur_notification in ui.notifications.get_notifications(): for cur_notification in ui.notifications.get_notifications():

View file

@ -25,5 +25,5 @@ from __future__ import absolute_import, division, print_function, with_statement
# is zero for an official release, positive for a development branch, # is zero for an official release, positive for a development branch,
# or negative for a release candidate or beta (after the base version # or negative for a release candidate or beta (after the base version
# number has been incremented) # number has been incremented)
version = "4.0.dev1" version = "4.0b1"
version_info = (4, 0, 0, -100) version_info = (4, 0, 0, -99)

View file

@ -883,7 +883,8 @@ class FriendFeedMixin(OAuthMixin):
class GoogleMixin(OpenIdMixin, OAuthMixin): class GoogleMixin(OpenIdMixin, OAuthMixin):
"""Google Open ID / OAuth authentication. """Google Open ID / OAuth authentication.
*Deprecated:* New applications should use `GoogleOAuth2Mixin` .. deprecated:: 4.0
New applications should use `GoogleOAuth2Mixin`
below instead of this class. As of May 19, 2014, Google has stopped below instead of this class. As of May 19, 2014, Google has stopped
supporting registration-free authentication. supporting registration-free authentication.
@ -1053,7 +1054,8 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
class FacebookMixin(object): class FacebookMixin(object):
"""Facebook Connect authentication. """Facebook Connect authentication.
*Deprecated:* New applications should use `FacebookGraphMixin` .. deprecated:: 1.1
New applications should use `FacebookGraphMixin`
below instead of this class. This class does not support the below instead of this class. This class does not support the
Future-based interface seen on other classes in this module. Future-based interface seen on other classes in this module.

View file

@ -51,18 +51,6 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
self._fds = {} self._fds = {}
self._timeout = None self._timeout = None
try:
self._socket_action = self._multi.socket_action
except AttributeError:
# socket_action is found in pycurl since 7.18.2 (it's been
# in libcurl longer than that but wasn't accessible to
# python).
gen_log.warning("socket_action method missing from pycurl; "
"falling back to socket_all. Upgrading "
"libcurl and pycurl will improve performance")
self._socket_action = \
lambda fd, action: self._multi.socket_all()
# libcurl has bugs that sometimes cause it to not report all # libcurl has bugs that sometimes cause it to not report all
# relevant file descriptors and timeouts to TIMERFUNCTION/ # relevant file descriptors and timeouts to TIMERFUNCTION/
# SOCKETFUNCTION. Mitigate the effects of such bugs by # SOCKETFUNCTION. Mitigate the effects of such bugs by
@ -142,7 +130,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
action |= pycurl.CSELECT_OUT action |= pycurl.CSELECT_OUT
while True: while True:
try: try:
ret, num_handles = self._socket_action(fd, action) ret, num_handles = self._multi.socket_action(fd, action)
except pycurl.error as e: except pycurl.error as e:
ret = e.args[0] ret = e.args[0]
if ret != pycurl.E_CALL_MULTI_PERFORM: if ret != pycurl.E_CALL_MULTI_PERFORM:
@ -155,7 +143,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
self._timeout = None self._timeout = None
while True: while True:
try: try:
ret, num_handles = self._socket_action( ret, num_handles = self._multi.socket_action(
pycurl.SOCKET_TIMEOUT, 0) pycurl.SOCKET_TIMEOUT, 0)
except pycurl.error as e: except pycurl.error as e:
ret = e.args[0] ret = e.args[0]
@ -223,11 +211,6 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
"callback": callback, "callback": callback,
"curl_start_time": time.time(), "curl_start_time": time.time(),
} }
# Disable IPv6 to mitigate the effects of this bug
# on curl versions <= 7.21.0
# http://sourceforge.net/tracker/?func=detail&aid=3017819&group_id=976&atid=100976
if pycurl.version_info()[2] <= 0x71500: # 7.21.0
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
_curl_setup_request(curl, request, curl.info["buffer"], _curl_setup_request(curl, request, curl.info["buffer"],
curl.info["headers"]) curl.info["headers"])
self._multi.add_handle(curl) self._multi.add_handle(curl)
@ -383,7 +366,6 @@ def _curl_setup_request(curl, request, buffer, headers):
if request.allow_ipv6 is False: if request.allow_ipv6 is False:
# Curl behaves reasonably when DNS resolution gives an ipv6 address # Curl behaves reasonably when DNS resolution gives an ipv6 address
# that we can't reach, so allow ipv6 unless the user asks to disable. # that we can't reach, so allow ipv6 unless the user asks to disable.
# (but see version check in _process_queue above)
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4) curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
else: else:
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_WHATEVER) curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_WHATEVER)

View file

@ -22,14 +22,17 @@ to switch to ``curl_httpclient`` for reasons such as the following:
* ``curl_httpclient`` was the default prior to Tornado 2.0. * ``curl_httpclient`` was the default prior to Tornado 2.0.
Note that if you are using ``curl_httpclient``, it is highly recommended that Note that if you are using ``curl_httpclient``, it is highly
you use a recent version of ``libcurl`` and ``pycurl``. Currently the minimum recommended that you use a recent version of ``libcurl`` and
supported version is 7.18.2, and the recommended version is 7.21.1 or newer. ``pycurl``. Currently the minimum supported version of libcurl is
It is highly recommended that your ``libcurl`` installation is built with 7.21.1, and the minimum version of pycurl is 7.18.2. It is highly
asynchronous DNS resolver (threaded or c-ares), otherwise you may encounter recommended that your ``libcurl`` installation is built with
various problems with request timeouts (for more information, see asynchronous DNS resolver (threaded or c-ares), otherwise you may
encounter various problems with request timeouts (for more
information, see
http://curl.haxx.se/libcurl/c/curl_easy_setopt.html#CURLOPTCONNECTTIMEOUTMS http://curl.haxx.se/libcurl/c/curl_easy_setopt.html#CURLOPTCONNECTTIMEOUTMS
and comments in curl_httpclient.py). and comments in curl_httpclient.py).
""" """
from __future__ import absolute_import, division, print_function, with_statement from __future__ import absolute_import, division, print_function, with_statement
@ -144,12 +147,21 @@ class AsyncHTTPClient(Configurable):
def __new__(cls, io_loop=None, force_instance=False, **kwargs): def __new__(cls, io_loop=None, force_instance=False, **kwargs):
io_loop = io_loop or IOLoop.current() io_loop = io_loop or IOLoop.current()
if io_loop in cls._async_clients() and not force_instance: if force_instance:
return cls._async_clients()[io_loop] instance_cache = None
else:
instance_cache = cls._async_clients()
if instance_cache is not None and io_loop in instance_cache:
return instance_cache[io_loop]
instance = super(AsyncHTTPClient, cls).__new__(cls, io_loop=io_loop, instance = super(AsyncHTTPClient, cls).__new__(cls, io_loop=io_loop,
**kwargs) **kwargs)
if not force_instance: # Make sure the instance knows which cache to remove itself from.
cls._async_clients()[io_loop] = instance # It can't simply call _async_clients() because we may be in
# __new__(AsyncHTTPClient) but instance.__class__ may be
# SimpleAsyncHTTPClient.
instance._instance_cache = instance_cache
if instance_cache is not None:
instance_cache[instance.io_loop] = instance
return instance return instance
def initialize(self, io_loop, defaults=None): def initialize(self, io_loop, defaults=None):
@ -172,9 +184,13 @@ class AsyncHTTPClient(Configurable):
``close()``. ``close()``.
""" """
if self._closed:
return
self._closed = True self._closed = True
if self._async_clients().get(self.io_loop) is self: if self._instance_cache is not None:
del self._async_clients()[self.io_loop] if self._instance_cache.get(self.io_loop) is not self:
raise RuntimeError("inconsistent AsyncHTTPClient cache")
del self._instance_cache[self.io_loop]
def fetch(self, request, callback=None, **kwargs): def fetch(self, request, callback=None, **kwargs):
"""Executes a request, asynchronously returning an `HTTPResponse`. """Executes a request, asynchronously returning an `HTTPResponse`.

View file

@ -45,8 +45,7 @@ import traceback
from tornado.concurrent import TracebackFuture, is_future from tornado.concurrent import TracebackFuture, is_future
from tornado.log import app_log, gen_log from tornado.log import app_log, gen_log
from tornado import stack_context from tornado import stack_context
from tornado.util import Configurable from tornado.util import Configurable, errno_from_exception, timedelta_to_seconds
from tornado.util import errno_from_exception
try: try:
import signal import signal
@ -433,7 +432,7 @@ class IOLoop(Configurable):
""" """
return time.time() return time.time()
def add_timeout(self, deadline, callback): def add_timeout(self, deadline, callback, *args, **kwargs):
"""Runs the ``callback`` at the time ``deadline`` from the I/O loop. """Runs the ``callback`` at the time ``deadline`` from the I/O loop.
Returns an opaque handle that may be passed to Returns an opaque handle that may be passed to
@ -442,13 +441,59 @@ class IOLoop(Configurable):
``deadline`` may be a number denoting a time (on the same ``deadline`` may be a number denoting a time (on the same
scale as `IOLoop.time`, normally `time.time`), or a scale as `IOLoop.time`, normally `time.time`), or a
`datetime.timedelta` object for a deadline relative to the `datetime.timedelta` object for a deadline relative to the
current time. current time. Since Tornado 4.0, `call_later` is a more
convenient alternative for the relative case since it does not
require a timedelta object.
Note that it is not safe to call `add_timeout` from other threads. Note that it is not safe to call `add_timeout` from other threads.
Instead, you must use `add_callback` to transfer control to the Instead, you must use `add_callback` to transfer control to the
`IOLoop`'s thread, and then call `add_timeout` from there. `IOLoop`'s thread, and then call `add_timeout` from there.
Subclasses of IOLoop must implement either `add_timeout` or
`call_at`; the default implementations of each will call
the other. `call_at` is usually easier to implement, but
subclasses that wish to maintain compatibility with Tornado
versions prior to 4.0 must use `add_timeout` instead.
.. versionchanged:: 4.0
Now passes through ``*args`` and ``**kwargs`` to the callback.
""" """
raise NotImplementedError() if isinstance(deadline, numbers.Real):
return self.call_at(deadline, callback, *args, **kwargs)
elif isinstance(deadline, datetime.timedelta):
return self.call_at(self.time() + timedelta_to_seconds(deadline),
callback, *args, **kwargs)
else:
raise TypeError("Unsupported deadline %r" % deadline)
def call_later(self, delay, callback, *args, **kwargs):
"""Runs the ``callback`` after ``delay`` seconds have passed.
Returns an opaque handle that may be passed to `remove_timeout`
to cancel. Note that unlike the `asyncio` method of the same
name, the returned object does not have a ``cancel()`` method.
See `add_timeout` for comments on thread-safety and subclassing.
.. versionadded:: 4.0
"""
self.call_at(self.time() + delay, callback, *args, **kwargs)
def call_at(self, when, callback, *args, **kwargs):
"""Runs the ``callback`` at the absolute time designated by ``when``.
``when`` must be a number using the same reference point as
`IOLoop.time`.
Returns an opaque handle that may be passed to `remove_timeout`
to cancel. Note that unlike the `asyncio` method of the same
name, the returned object does not have a ``cancel()`` method.
See `add_timeout` for comments on thread-safety and subclassing.
.. versionadded:: 4.0
"""
self.add_timeout(when, callback, *args, **kwargs)
def remove_timeout(self, timeout): def remove_timeout(self, timeout):
"""Cancels a pending timeout. """Cancels a pending timeout.
@ -813,8 +858,11 @@ class PollIOLoop(IOLoop):
def time(self): def time(self):
return self.time_func() return self.time_func()
def add_timeout(self, deadline, callback): def call_at(self, deadline, callback, *args, **kwargs):
timeout = _Timeout(deadline, stack_context.wrap(callback), self) timeout = _Timeout(
deadline,
functools.partial(stack_context.wrap(callback), *args, **kwargs),
self)
heapq.heappush(self._timeouts, timeout) heapq.heappush(self._timeouts, timeout)
return timeout return timeout
@ -869,24 +917,12 @@ class _Timeout(object):
__slots__ = ['deadline', 'callback', 'tiebreaker'] __slots__ = ['deadline', 'callback', 'tiebreaker']
def __init__(self, deadline, callback, io_loop): def __init__(self, deadline, callback, io_loop):
if isinstance(deadline, numbers.Real): if not isinstance(deadline, numbers.Real):
self.deadline = deadline
elif isinstance(deadline, datetime.timedelta):
now = io_loop.time()
try:
self.deadline = now + deadline.total_seconds()
except AttributeError: # py2.6
self.deadline = now + _Timeout.timedelta_to_seconds(deadline)
else:
raise TypeError("Unsupported deadline %r" % deadline) raise TypeError("Unsupported deadline %r" % deadline)
self.deadline = deadline
self.callback = callback self.callback = callback
self.tiebreaker = next(io_loop._timeout_counter) self.tiebreaker = next(io_loop._timeout_counter)
@staticmethod
def timedelta_to_seconds(td):
"""Equivalent to td.total_seconds() (introduced in python 2.7)."""
return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / float(10 ** 6)
# Comparison methods to sort by deadline, with object id as a tiebreaker # Comparison methods to sort by deadline, with object id as a tiebreaker
# to guarantee a consistent ordering. The heapq module uses __le__ # to guarantee a consistent ordering. The heapq module uses __le__
# in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons # in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons

View file

@ -57,12 +57,24 @@ except ImportError:
# some they differ. # some they differ.
_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN) _ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
if hasattr(errno, "WSAEWOULDBLOCK"):
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,)
# These errnos indicate that a connection has been abruptly terminated. # These errnos indicate that a connection has been abruptly terminated.
# They should be caught and handled less noisily than other errors. # They should be caught and handled less noisily than other errors.
_ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE, _ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE,
errno.ETIMEDOUT) errno.ETIMEDOUT)
if hasattr(errno, "WSAECONNRESET"):
_ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT)
# More non-portable errnos:
_ERRNO_INPROGRESS = (errno.EINPROGRESS,)
if hasattr(errno, "WSAEINPROGRESS"):
_ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,)
#######################################################
class StreamClosedError(IOError): class StreamClosedError(IOError):
"""Exception raised by `IOStream` methods when the stream is closed. """Exception raised by `IOStream` methods when the stream is closed.
@ -990,7 +1002,7 @@ class IOStream(BaseIOStream):
# returned immediately when attempting to connect to # returned immediately when attempting to connect to
# localhost, so handle them the same way as an error # localhost, so handle them the same way as an error
# reported later in _handle_connect. # reported later in _handle_connect.
if (errno_from_exception(e) != errno.EINPROGRESS and if (errno_from_exception(e) not in _ERRNO_INPROGRESS and
errno_from_exception(e) not in _ERRNO_WOULDBLOCK): errno_from_exception(e) not in _ERRNO_WOULDBLOCK):
gen_log.warning("Connect error on fd %s: %s", gen_log.warning("Connect error on fd %s: %s",
self.socket.fileno(), e) self.socket.fileno(), e)

View file

@ -57,6 +57,9 @@ u('foo').encode('idna')
# some they differ. # some they differ.
_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN) _ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
if hasattr(errno, "WSAEWOULDBLOCK"):
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,)
def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags=None): def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128, flags=None):
"""Creates listening sockets bound to the given port and address. """Creates listening sockets bound to the given port and address.

View file

@ -13,9 +13,9 @@ from __future__ import absolute_import, division, print_function, with_statement
import datetime import datetime
import functools import functools
# _Timeout is used for its timedelta_to_seconds method for py26 compatibility. from tornado.ioloop import IOLoop
from tornado.ioloop import IOLoop, _Timeout
from tornado import stack_context from tornado import stack_context
from tornado.util import timedelta_to_seconds
try: try:
# Import the real asyncio module for py33+ first. Older versions of the # Import the real asyncio module for py33+ first. Older versions of the
@ -109,15 +109,13 @@ class BaseAsyncIOLoop(IOLoop):
def stop(self): def stop(self):
self.asyncio_loop.stop() self.asyncio_loop.stop()
def add_timeout(self, deadline, callback): def call_at(self, when, callback, *args, **kwargs):
if isinstance(deadline, (int, float)): # asyncio.call_at supports *args but not **kwargs, so bind them here.
delay = max(deadline - self.time(), 0) # We do not synchronize self.time and asyncio_loop.time, so
elif isinstance(deadline, datetime.timedelta): # convert from absolute to relative.
delay = _Timeout.timedelta_to_seconds(deadline) return self.asyncio_loop.call_later(
else: max(0, when - self.time()), self._run_callback,
raise TypeError("Unsupported deadline %r", deadline) functools.partial(stack_context.wrap(callback), *args, **kwargs))
return self.asyncio_loop.call_later(delay, self._run_callback,
stack_context.wrap(callback))
def remove_timeout(self, timeout): def remove_timeout(self, timeout):
timeout.cancel() timeout.cancel()

View file

@ -68,6 +68,7 @@ from __future__ import absolute_import, division, print_function, with_statement
import datetime import datetime
import functools import functools
import numbers
import socket import socket
import twisted.internet.abstract import twisted.internet.abstract
@ -90,11 +91,7 @@ from tornado.log import app_log
from tornado.netutil import Resolver from tornado.netutil import Resolver
from tornado.stack_context import NullContext, wrap from tornado.stack_context import NullContext, wrap
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from tornado.util import timedelta_to_seconds
try:
long # py2
except NameError:
long = int # py3
@implementer(IDelayedCall) @implementer(IDelayedCall)
@ -475,14 +472,19 @@ class TwistedIOLoop(tornado.ioloop.IOLoop):
def stop(self): def stop(self):
self.reactor.crash() self.reactor.crash()
def add_timeout(self, deadline, callback): def add_timeout(self, deadline, callback, *args, **kwargs):
if isinstance(deadline, (int, long, float)): # This method could be simplified (since tornado 4.0) by
# overriding call_at instead of add_timeout, but we leave it
# for now as a test of backwards-compatibility.
if isinstance(deadline, numbers.Real):
delay = max(deadline - self.time(), 0) delay = max(deadline - self.time(), 0)
elif isinstance(deadline, datetime.timedelta): elif isinstance(deadline, datetime.timedelta):
delay = tornado.ioloop._Timeout.timedelta_to_seconds(deadline) delay = timedelta_to_seconds(deadline)
else: else:
raise TypeError("Unsupported deadline %r") raise TypeError("Unsupported deadline %r")
return self.reactor.callLater(delay, self._run_callback, wrap(callback)) return self.reactor.callLater(
delay, self._run_callback,
functools.partial(wrap(callback), *args, **kwargs))
def remove_timeout(self, timeout): def remove_timeout(self, timeout):
if timeout.active(): if timeout.active():

View file

@ -155,7 +155,7 @@ class TestIOLoop(AsyncTestCase):
def test_remove_timeout_after_fire(self): def test_remove_timeout_after_fire(self):
# It is not an error to call remove_timeout after it has run. # It is not an error to call remove_timeout after it has run.
handle = self.io_loop.add_timeout(self.io_loop.time(), self.stop()) handle = self.io_loop.add_timeout(self.io_loop.time(), self.stop)
self.wait() self.wait()
self.io_loop.remove_timeout(handle) self.io_loop.remove_timeout(handle)
@ -173,6 +173,18 @@ class TestIOLoop(AsyncTestCase):
self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop)) self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop))
self.wait() self.wait()
def test_timeout_with_arguments(self):
# This tests that all the timeout methods pass through *args correctly.
results = []
self.io_loop.add_timeout(self.io_loop.time(), results.append, 1)
self.io_loop.add_timeout(datetime.timedelta(seconds=0),
results.append, 2)
self.io_loop.call_at(self.io_loop.time(), results.append, 3)
self.io_loop.call_later(0, results.append, 4)
self.io_loop.call_later(0, self.stop)
self.wait()
self.assertEqual(results, [1, 2, 3, 4])
def test_close_file_object(self): def test_close_file_object(self):
"""When a file object is used instead of a numeric file descriptor, """When a file object is used instead of a numeric file descriptor,
the object should be closed (by IOLoop.close(all_fds=True), the object should be closed (by IOLoop.close(all_fds=True),

View file

@ -232,8 +232,11 @@ class TestIOStreamMixin(object):
self.assertFalse(self.connect_called) self.assertFalse(self.connect_called)
self.assertTrue(isinstance(stream.error, socket.error), stream.error) self.assertTrue(isinstance(stream.error, socket.error), stream.error)
if sys.platform != 'cygwin': if sys.platform != 'cygwin':
_ERRNO_CONNREFUSED = (errno.ECONNREFUSED,)
if hasattr(errno, "WSAECONNREFUSED"):
_ERRNO_CONNREFUSED += (errno.WSAECONNREFUSED,)
# cygwin's errnos don't match those used on native windows python # cygwin's errnos don't match those used on native windows python
self.assertEqual(stream.error.args[0], errno.ECONNREFUSED) self.assertTrue(stream.error.args[0] in _ERRNO_CONNREFUSED)
def test_gaierror(self): def test_gaierror(self):
# Test that IOStream sets its exc_info on getaddrinfo error # Test that IOStream sets its exc_info on getaddrinfo error

View file

@ -321,8 +321,10 @@ class SimpleHTTPClientTestMixin(object):
if sys.platform != 'cygwin': if sys.platform != 'cygwin':
# cygwin returns EPERM instead of ECONNREFUSED here # cygwin returns EPERM instead of ECONNREFUSED here
self.assertTrue(str(errno.ECONNREFUSED) in str(response.error), contains_errno = str(errno.ECONNREFUSED) in str(response.error)
response.error) if not contains_errno and hasattr(errno, "WSAECONNREFUSED"):
contains_errno = str(errno.WSAECONNREFUSED) in str(response.error)
self.assertTrue(contains_errno, response.error)
# This is usually "Connection refused". # This is usually "Connection refused".
# On windows, strerror is broken and returns "Unknown error". # On windows, strerror is broken and returns "Unknown error".
expected_message = os.strerror(errno.ECONNREFUSED) expected_message = os.strerror(errno.ECONNREFUSED)

View file

@ -35,11 +35,11 @@ class TestRequestHandler(RequestHandler):
logging.debug('in part3()') logging.debug('in part3()')
raise Exception('test exception') raise Exception('test exception')
def get_error_html(self, status_code, **kwargs): def write_error(self, status_code, **kwargs):
if 'exception' in kwargs and str(kwargs['exception']) == 'test exception': if 'exc_info' in kwargs and str(kwargs['exc_info'][1]) == 'test exception':
return 'got expected exception' self.write('got expected exception')
else: else:
return 'unexpected failure' self.write('unexpected failure')
class HTTPStackContextTest(AsyncHTTPTestCase): class HTTPStackContextTest(AsyncHTTPTestCase):

View file

@ -10,7 +10,7 @@ from tornado.template import DictLoader
from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test
from tornado.test.util import unittest from tornado.test.util import unittest
from tornado.util import u, bytes_type, ObjectDict, unicode_type from tornado.util import u, bytes_type, ObjectDict, unicode_type
from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish
import binascii import binascii
import contextlib import contextlib
@ -773,20 +773,6 @@ class ErrorResponseTest(WebTestCase):
else: else:
self.write("Status: %d" % status_code) self.write("Status: %d" % status_code)
class GetErrorHtmlHandler(RequestHandler):
def get(self):
if self.get_argument("status", None):
self.send_error(int(self.get_argument("status")))
else:
1 / 0
def get_error_html(self, status_code, **kwargs):
self.set_header("Content-Type", "text/plain")
if "exception" in kwargs:
self.write("Exception: %s" % sys.exc_info()[0].__name__)
else:
self.write("Status: %d" % status_code)
class FailedWriteErrorHandler(RequestHandler): class FailedWriteErrorHandler(RequestHandler):
def get(self): def get(self):
1 / 0 1 / 0
@ -796,7 +782,6 @@ class ErrorResponseTest(WebTestCase):
return [url("/default", DefaultHandler), return [url("/default", DefaultHandler),
url("/write_error", WriteErrorHandler), url("/write_error", WriteErrorHandler),
url("/get_error_html", GetErrorHtmlHandler),
url("/failed_write_error", FailedWriteErrorHandler), url("/failed_write_error", FailedWriteErrorHandler),
] ]
@ -820,16 +805,6 @@ class ErrorResponseTest(WebTestCase):
self.assertEqual(response.code, 503) self.assertEqual(response.code, 503)
self.assertEqual(b"Status: 503", response.body) self.assertEqual(b"Status: 503", response.body)
def test_get_error_html(self):
with ExpectLog(app_log, "Uncaught exception"):
response = self.fetch("/get_error_html")
self.assertEqual(response.code, 500)
self.assertEqual(b"Exception: ZeroDivisionError", response.body)
response = self.fetch("/get_error_html?status=503")
self.assertEqual(response.code, 503)
self.assertEqual(b"Status: 503", response.body)
def test_failed_write_error(self): def test_failed_write_error(self):
with ExpectLog(app_log, "Uncaught exception"): with ExpectLog(app_log, "Uncaught exception"):
response = self.fetch("/failed_write_error") response = self.fetch("/failed_write_error")
@ -2307,3 +2282,20 @@ class XSRFTest(SimpleHandlerTestCase):
body=urllib_parse.urlencode(dict(_xsrf=body_token)), body=urllib_parse.urlencode(dict(_xsrf=body_token)),
headers=self.cookie_headers(cookie_token)) headers=self.cookie_headers(cookie_token))
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
@wsgi_safe
class FinishExceptionTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
self.set_status(401)
self.set_header('WWW-Authenticate', 'Basic realm="something"')
self.write('authentication required')
raise Finish()
def test_finish_exception(self):
response = self.fetch('/')
self.assertEqual(response.code, 401)
self.assertEqual('Basic realm="something"',
response.headers.get('WWW-Authenticate'))
self.assertEqual(b'authentication required', response.body)

View file

@ -70,8 +70,8 @@ def get_unused_port():
only that a series of get_unused_port calls in a single process return only that a series of get_unused_port calls in a single process return
distinct ports. distinct ports.
**Deprecated**. Use bind_unused_port instead, which is guaranteed .. deprecated::
to find an unused port. Use bind_unused_port instead, which is guaranteed to find an unused port.
""" """
global _next_port global _next_port
port = _next_port port = _next_port

View file

@ -311,6 +311,11 @@ class ArgReplacer(object):
return old_value, args, kwargs return old_value, args, kwargs
def timedelta_to_seconds(td):
"""Equivalent to td.total_seconds() (introduced in python 2.7)."""
return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / float(10 ** 6)
def _websocket_mask_python(mask, data): def _websocket_mask_python(mask, data):
"""Websocket masking function. """Websocket masking function.

View file

@ -630,7 +630,6 @@ class RequestHandler(object):
self.set_status(status) self.set_status(status)
self.set_header("Location", urlparse.urljoin(utf8(self.request.uri), self.set_header("Location", urlparse.urljoin(utf8(self.request.uri),
utf8(url))) utf8(url)))
self.finish() self.finish()
def write(self, chunk): def write(self, chunk):
@ -944,26 +943,7 @@ class RequestHandler(object):
``kwargs["exc_info"]``. Note that this exception may not be ``kwargs["exc_info"]``. Note that this exception may not be
the "current" exception for purposes of methods like the "current" exception for purposes of methods like
``sys.exc_info()`` or ``traceback.format_exc``. ``sys.exc_info()`` or ``traceback.format_exc``.
For historical reasons, if a method ``get_error_html`` exists,
it will be used instead of the default ``write_error`` implementation.
``get_error_html`` returned a string instead of producing output
normally, and had different semantics for exception handling.
Users of ``get_error_html`` are encouraged to convert their code
to override ``write_error`` instead.
""" """
if hasattr(self, 'get_error_html'):
if 'exc_info' in kwargs:
exc_info = kwargs.pop('exc_info')
kwargs['exception'] = exc_info[1]
try:
# Put the traceback into sys.exc_info()
raise_exc_info(exc_info)
except Exception:
self.finish(self.get_error_html(status_code, **kwargs))
else:
self.finish(self.get_error_html(status_code, **kwargs))
return
if self.settings.get("serve_traceback") and "exc_info" in kwargs: if self.settings.get("serve_traceback") and "exc_info" in kwargs:
# in debug mode, try to send a traceback # in debug mode, try to send a traceback
self.set_header('Content-Type', 'text/plain') self.set_header('Content-Type', 'text/plain')
@ -1385,6 +1365,11 @@ class RequestHandler(object):
" (" + self.request.remote_ip + ")" " (" + self.request.remote_ip + ")"
def _handle_request_exception(self, e): def _handle_request_exception(self, e):
if isinstance(e, Finish):
# Not an error; just finish the request without logging.
if not self._finished:
self.finish()
return
self.log_exception(*sys.exc_info()) self.log_exception(*sys.exc_info())
if self._finished: if self._finished:
# Extra errors after the request has been finished should # Extra errors after the request has been finished should
@ -1558,7 +1543,7 @@ def removeslash(method):
if uri: # don't try to redirect '/' to '' if uri: # don't try to redirect '/' to ''
if self.request.query: if self.request.query:
uri += "?" + self.request.query uri += "?" + self.request.query
self.redirectTo(uri, permanent=True) self.redirect(uri, permanent=True)
return return
else: else:
raise HTTPError(404) raise HTTPError(404)
@ -1580,7 +1565,7 @@ def addslash(method):
uri = self.request.path + "/" uri = self.request.path + "/"
if self.request.query: if self.request.query:
uri += "?" + self.request.query uri += "?" + self.request.query
self.redirectTo(uri, permanent=True) self.redirect(uri, permanent=True)
return return
raise HTTPError(404) raise HTTPError(404)
return method(self, *args, **kwargs) return method(self, *args, **kwargs)
@ -1939,6 +1924,9 @@ class HTTPError(Exception):
`RequestHandler.send_error` since it automatically ends the `RequestHandler.send_error` since it automatically ends the
current function. current function.
To customize the response sent with an `HTTPError`, override
`RequestHandler.write_error`.
:arg int status_code: HTTP status code. Must be listed in :arg int status_code: HTTP status code. Must be listed in
`httplib.responses <http.client.responses>` unless the ``reason`` `httplib.responses <http.client.responses>` unless the ``reason``
keyword argument is given. keyword argument is given.
@ -1967,6 +1955,25 @@ class HTTPError(Exception):
return message return message
class Finish(Exception):
"""An exception that ends the request without producing an error response.
When `Finish` is raised in a `RequestHandler`, the request will end
(calling `RequestHandler.finish` if it hasn't already been called),
but the outgoing response will not be modified and the error-handling
methods (including `RequestHandler.write_error`) will not be called.
This can be a more convenient way to implement custom error pages
than overriding ``write_error`` (especially in library code)::
if self.current_user is None:
self.set_status(401)
self.set_header('WWW-Authenticate', 'Basic realm="something"')
raise Finish()
"""
pass
class MissingArgumentError(HTTPError): class MissingArgumentError(HTTPError):
"""Exception raised by `RequestHandler.get_argument`. """Exception raised by `RequestHandler.get_argument`.
@ -2494,9 +2501,9 @@ class FallbackHandler(RequestHandler):
class OutputTransform(object): class OutputTransform(object):
"""A transform modifies the result of an HTTP request (e.g., GZip encoding) """A transform modifies the result of an HTTP request (e.g., GZip encoding)
A new transform instance is created for every request. See the Applications are not expected to create their own OutputTransforms
GZipContentEncoding example below if you want to implement a or interact with them directly; the framework chooses which transforms
new Transform. (if any) to apply.
""" """
def __init__(self, request): def __init__(self, request):
pass pass
@ -2587,7 +2594,7 @@ def authenticated(method):
else: else:
next_url = self.request.uri next_url = self.request.uri
url += "?" + urlencode(dict(next=next_url)) url += "?" + urlencode(dict(next=next_url))
self.redirectTo(url) self.redirect(url)
return return
raise HTTPError(403) raise HTTPError(403)
return method(self, *args, **kwargs) return method(self, *args, **kwargs)

View file

@ -3,18 +3,17 @@
`WebSockets <http://dev.w3.org/html5/websockets/>`_ allow for bidirectional `WebSockets <http://dev.w3.org/html5/websockets/>`_ allow for bidirectional
communication between the browser and server. communication between the browser and server.
.. warning:: WebSockets are supported in the current versions of all major browsers,
although older versions that do not support WebSockets are still in use
(refer to http://caniuse.com/websockets for details).
The WebSocket protocol was recently finalized as `RFC 6455 This module implements the final version of the WebSocket protocol as
<http://tools.ietf.org/html/rfc6455>`_ and is not yet supported in defined in `RFC 6455 <http://tools.ietf.org/html/rfc6455>`_. Certain
all browsers. Refer to http://caniuse.com/websockets for details browser versions (notably Safari 5.x) implemented an earlier draft of
on compatibility. In addition, during development the protocol the protocol (known as "draft 76") and are not compatible with this module.
went through several incompatible versions, and some browsers only
support older versions. By default this module only supports the .. versionchanged:: 4.0
latest version of the protocol, but optional support for an older Removed support for the draft 76 protocol version.
version (known as "draft 76" or "hixie-76") can be enabled by
overriding `WebSocketHandler.allow_draft76` (see that method's
documentation for caveats).
""" """
from __future__ import absolute_import, division, print_function, with_statement from __future__ import absolute_import, division, print_function, with_statement
@ -22,11 +21,9 @@ from __future__ import absolute_import, division, print_function, with_statement
import base64 import base64
import collections import collections
import functools
import hashlib import hashlib
import os import os
import struct import struct
import time
import tornado.escape import tornado.escape
import tornado.web import tornado.web
@ -38,7 +35,7 @@ from tornado.iostream import StreamClosedError
from tornado.log import gen_log, app_log from tornado.log import gen_log, app_log
from tornado import simple_httpclient from tornado import simple_httpclient
from tornado.tcpclient import TCPClient from tornado.tcpclient import TCPClient
from tornado.util import bytes_type, unicode_type, _websocket_mask from tornado.util import bytes_type, _websocket_mask
try: try:
from urllib.parse import urlparse # py2 from urllib.parse import urlparse # py2
@ -161,10 +158,6 @@ class WebSocketHandler(tornado.web.RequestHandler):
if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"): if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
self.ws_connection = WebSocketProtocol13(self) self.ws_connection = WebSocketProtocol13(self)
self.ws_connection.accept_connection() self.ws_connection.accept_connection()
elif (self.allow_draft76() and
"Sec-WebSocket-Version" not in self.request.headers):
self.ws_connection = WebSocketProtocol76(self)
self.ws_connection.accept_connection()
else: else:
self.stream.write(tornado.escape.utf8( self.stream.write(tornado.escape.utf8(
"HTTP/1.1 426 Upgrade Required\r\n" "HTTP/1.1 426 Upgrade Required\r\n"
@ -256,9 +249,6 @@ class WebSocketHandler(tornado.web.RequestHandler):
closing. These values are made available to the client, but are closing. These values are made available to the client, but are
not otherwise interpreted by the websocket protocol. not otherwise interpreted by the websocket protocol.
The ``code`` and ``reason`` arguments are ignored in the "draft76"
protocol version.
.. versionchanged:: 4.0 .. versionchanged:: 4.0
Added the ``code`` and ``reason`` arguments. Added the ``code`` and ``reason`` arguments.
@ -296,21 +286,6 @@ class WebSocketHandler(tornado.web.RequestHandler):
# Check to see that origin matches host directly, including ports # Check to see that origin matches host directly, including ports
return origin == host return origin == host
def allow_draft76(self):
"""Override to enable support for the older "draft76" protocol.
The draft76 version of the websocket protocol is disabled by
default due to security concerns, but it can be enabled by
overriding this method to return True.
Connections using the draft76 protocol do not support the
``binary=True`` flag to `write_message`.
Support for the draft76 protocol is deprecated and will be
removed in a future version of Tornado.
"""
return False
def set_nodelay(self, value): def set_nodelay(self, value):
"""Set the no-delay flag for this stream. """Set the no-delay flag for this stream.
@ -327,18 +302,6 @@ class WebSocketHandler(tornado.web.RequestHandler):
""" """
self.stream.set_nodelay(value) self.stream.set_nodelay(value)
def get_websocket_scheme(self):
"""Return the url scheme used for this request, either "ws" or "wss".
This is normally decided by HTTPServer, but applications
may wish to override this if they are using an SSL proxy
that does not provide the X-Scheme header as understood
by HTTPServer.
Note that this is only used by the draft76 protocol.
"""
return "wss" if self.request.protocol == "https" else "ws"
def on_connection_close(self): def on_connection_close(self):
if self.ws_connection: if self.ws_connection:
self.ws_connection.on_connection_close() self.ws_connection.on_connection_close()
@ -392,175 +355,6 @@ class WebSocketProtocol(object):
self.close() # let the subclass cleanup self.close() # let the subclass cleanup
class WebSocketProtocol76(WebSocketProtocol):
"""Implementation of the WebSockets protocol, version hixie-76.
This class provides basic functionality to process WebSockets requests as
specified in
http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
"""
def __init__(self, handler):
WebSocketProtocol.__init__(self, handler)
self.challenge = None
self._waiting = None
def accept_connection(self):
try:
self._handle_websocket_headers()
except ValueError:
gen_log.debug("Malformed WebSocket request received")
self._abort()
return
scheme = self.handler.get_websocket_scheme()
# draft76 only allows a single subprotocol
subprotocol_header = ''
subprotocol = self.request.headers.get("Sec-WebSocket-Protocol", None)
if subprotocol:
selected = self.handler.select_subprotocol([subprotocol])
if selected:
assert selected == subprotocol
subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
# Write the initial headers before attempting to read the challenge.
# This is necessary when using proxies (such as HAProxy), which
# need to see the Upgrade headers before passing through the
# non-HTTP traffic that follows.
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
"Upgrade: WebSocket\r\n"
"Connection: Upgrade\r\n"
"Server: TornadoServer/%(version)s\r\n"
"Sec-WebSocket-Origin: %(origin)s\r\n"
"Sec-WebSocket-Location: %(scheme)s://%(host)s%(uri)s\r\n"
"%(subprotocol)s"
"\r\n" % (dict(
version=tornado.version,
origin=self.request.headers["Origin"],
scheme=scheme,
host=self.request.host,
uri=self.request.uri,
subprotocol=subprotocol_header))))
self.stream.read_bytes(8, self._handle_challenge)
def challenge_response(self, challenge):
"""Generates the challenge response that's needed in the handshake
The challenge parameter should be the raw bytes as sent from the
client.
"""
key_1 = self.request.headers.get("Sec-Websocket-Key1")
key_2 = self.request.headers.get("Sec-Websocket-Key2")
try:
part_1 = self._calculate_part(key_1)
part_2 = self._calculate_part(key_2)
except ValueError:
raise ValueError("Invalid Keys/Challenge")
return self._generate_challenge_response(part_1, part_2, challenge)
def _handle_challenge(self, challenge):
try:
challenge_response = self.challenge_response(challenge)
except ValueError:
gen_log.debug("Malformed key data in WebSocket request")
self._abort()
return
self._write_response(challenge_response)
def _write_response(self, challenge):
self.stream.write(challenge)
self._run_callback(self.handler.open, *self.handler.open_args,
**self.handler.open_kwargs)
self._receive_message()
def _handle_websocket_headers(self):
"""Verifies all invariant- and required headers
If a header is missing or have an incorrect value ValueError will be
raised
"""
fields = ("Origin", "Host", "Sec-Websocket-Key1",
"Sec-Websocket-Key2")
if not all(map(lambda f: self.request.headers.get(f), fields)):
raise ValueError("Missing/Invalid WebSocket headers")
def _calculate_part(self, key):
"""Processes the key headers and calculates their key value.
Raises ValueError when feed invalid key."""
# pyflakes complains about variable reuse if both of these lines use 'c'
number = int(''.join(c for c in key if c.isdigit()))
spaces = len([c2 for c2 in key if c2.isspace()])
try:
key_number = number // spaces
except (ValueError, ZeroDivisionError):
raise ValueError
return struct.pack(">I", key_number)
def _generate_challenge_response(self, part_1, part_2, part_3):
m = hashlib.md5()
m.update(part_1)
m.update(part_2)
m.update(part_3)
return m.digest()
def _receive_message(self):
self.stream.read_bytes(1, self._on_frame_type)
def _on_frame_type(self, byte):
frame_type = ord(byte)
if frame_type == 0x00:
self.stream.read_until(b"\xff", self._on_end_delimiter)
elif frame_type == 0xff:
self.stream.read_bytes(1, self._on_length_indicator)
else:
self._abort()
def _on_end_delimiter(self, frame):
if not self.client_terminated:
self._run_callback(self.handler.on_message,
frame[:-1].decode("utf-8", "replace"))
if not self.client_terminated:
self._receive_message()
def _on_length_indicator(self, byte):
if ord(byte) != 0x00:
self._abort()
return
self.client_terminated = True
self.close()
def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket."""
if binary:
raise ValueError(
"Binary messages not supported by this version of websockets")
if isinstance(message, unicode_type):
message = message.encode("utf-8")
assert isinstance(message, bytes_type)
self.stream.write(b"\x00" + message + b"\xff")
def write_ping(self, data):
"""Send ping frame."""
raise ValueError("Ping messages not supported by this version of websockets")
def close(self, code=None, reason=None):
"""Closes the WebSocket connection."""
if not self.server_terminated:
if not self.stream.closed():
self.stream.write("\xff\x00")
self.server_terminated = True
if self.client_terminated:
if self._waiting is not None:
self.stream.io_loop.remove_timeout(self._waiting)
self._waiting = None
self.stream.close()
elif self._waiting is None:
self._waiting = self.stream.io_loop.add_timeout(
time.time() + 5, self._abort)
class WebSocketProtocol13(WebSocketProtocol): class WebSocketProtocol13(WebSocketProtocol):
"""Implementation of the WebSocket protocol from RFC 6455. """Implementation of the WebSocket protocol from RFC 6455.