diff --git a/gui/slick/js/ajaxNotifications.js b/gui/slick/js/ajaxNotifications.js
index 8f1a12a5..b87f4af7 100644
--- a/gui/slick/js/ajaxNotifications.js
+++ b/gui/slick/js/ajaxNotifications.js
@@ -1,24 +1,39 @@
-var message_url = sbRoot + '/ui/get_messages';
-$.pnotify.defaults.pnotify_width = "340px";
-$.pnotify.defaults.pnotify_history = false;
-$.pnotify.defaults.pnotify_delay = 4000;
+var message_url = sbRoot + '/ui/get_messages/';
+$.pnotify.defaults.width = "400px";
+$.pnotify.defaults.styling = "jqueryui";
+$.pnotify.defaults.history = false;
+$.pnotify.defaults.shadow = false;
+$.pnotify.defaults.delay = 4000;
+$.pnotify.defaults.maxonscreen = 5;
 
 function check_notifications() {
-    $.getJSON(message_url, function(data){
-        $.each(data, function(name,data){
-            $.pnotify({
-                pnotify_type: data.type,
-                pnotify_hide: data.type == 'notice',
-                pnotify_title: data.title,
-                pnotify_text: data.message
+    var poll_interval = 5000;
+    $.ajax({
+        url: message_url,
+        success: function (data) {
+            poll_interval = 5000;
+            $.each(data, function (name, data) {
+                $.pnotify({
+                    type: data.type,
+                    hide: data.type == 'notice',
+                    title: data.title,
+                    text: data.message
+                });
             });
-        });
+        },
+        error: function () {
+            poll_interval = 15000;
+        },
+        type: "GET",
+        dataType: "json",
+        complete: function () {
+            setTimeout(check_notifications, poll_interval);
+        },
+        timeout: 15000 // timeout every 15 secs
     });
-
-    setTimeout(check_notifications, 3000)
 }
 
-$(document).ready(function(){
+$(document).ready(function () {
     check_notifications();
diff --git a/sickbeard/__init__.py b/sickbeard/__init__.py
index 0c95bb38..d551e829 100644
--- a/sickbeard/__init__.py
+++ b/sickbeard/__init__.py
@@ -756,10 +756,9 @@ def initialize(consoleLogging=True):
         USE_PUSHOVER = bool(check_setting_int(CFG, 'Pushover', 'use_pushover', 0))
         PUSHOVER_NOTIFY_ONSNATCH = bool(check_setting_int(CFG, 'Pushover', 'pushover_notify_onsnatch', 0))
         PUSHOVER_NOTIFY_ONDOWNLOAD = bool(check_setting_int(CFG, 'Pushover', 'pushover_notify_ondownload', 0))
-        PUSHOVER_NOTIFY_ONSUBTITLEDOWNLOAD = bool(
-            check_setting_int(CFG, 'Pushover', 'pushover_notify_onsubtitledownload', 0))
+        PUSHOVER_NOTIFY_ONSUBTITLEDOWNLOAD = bool(check_setting_int(CFG, 'Pushover', 'pushover_notify_onsubtitledownload', 0))
         PUSHOVER_USERKEY = check_setting_str(CFG, 'Pushover', 'pushover_userkey', '')
-        PUSHOVER_APIKEY = check_setting_str(CFG, 'Pushover', 'pushover_apikey', '')
+        PUSHOVER_APIKEY = check_setting_str(CFG, 'Pushover', 'pushover_apikey', '')
         USE_LIBNOTIFY = bool(check_setting_int(CFG, 'Libnotify', 'use_libnotify', 0))
         LIBNOTIFY_NOTIFY_ONSNATCH = bool(check_setting_int(CFG, 'Libnotify', 'libnotify_notify_onsnatch', 0))
         LIBNOTIFY_NOTIFY_ONDOWNLOAD = bool(check_setting_int(CFG, 'Libnotify', 'libnotify_notify_ondownload', 0))
diff --git a/sickbeard/helpers.py b/sickbeard/helpers.py
index 05bef4bc..e0cb730a 100644
--- a/sickbeard/helpers.py
+++ b/sickbeard/helpers.py
@@ -263,9 +263,9 @@ def download_file(url, filename):
 
 def findCertainShow(showList, indexerid=None):
     if indexerid:
-        results = filter(lambda x: x.indexerid == indexerid, showList)
+        results = filter(lambda x: int(x.indexerid) == int(indexerid), showList)
     else:
-        results = filter(lambda x: x.indexerid == indexerid, showList)
+        results = filter(lambda x: int(x.indexerid) == int(indexerid), showList)
 
     if len(results) == 0:
         return None
diff --git a/sickbeard/providers/womble.py b/sickbeard/providers/womble.py
index 64d32508..8dbfe9b7 100644
--- a/sickbeard/providers/womble.py
+++ b/sickbeard/providers/womble.py
@@ -21,6 +21,7 @@ import generic
 
 from sickbeard import logger
 from sickbeard import tvcache
+from sickbeard.exceptions import AuthException
 
 
 class WombleProvider(generic.NZBProvider):
@@ -40,13 +41,40 @@ class WombleCache(tvcache.TVCache):
         # only poll Womble's Index every 15 minutes max
         self.minTime = 15
 
-    def _getRSSData(self):
-        url = self.provider.url + 'rss/?sec=TV-x264&fr=false'
-        logger.log(u"Womble's Index cache update URL: " + url, logger.DEBUG)
-        return self.getRSSFeed(url)
+    def updateCache(self):
+
+        # delete anything older than 7 days
+        logger.log(u"Clearing " + self.provider.name + " cache")
+        self._clearCache()
+
+        data = None
+        if not self.shouldUpdate():
+            for url in [self.provider.url + 'rss/?sec=tv-sd&fr=false', self.provider.url + 'rss/?sec=tv-hd&fr=false']:
+                logger.log(u"Womble's Index cache update URL: " + url, logger.DEBUG)
+                data = self.getRSSFeed(url)
+
+        # As long as we got something from the provider we count it as an update
+        if not data:
+            return []
+
+        # By now we know we've got data and no auth errors, all we need to do is put it in the database
+        cl = []
+        for item in data.entries:
+
+            ci = self._parseItem(item)
+            if ci is not None:
+                cl.append(ci)
+
+        if cl:
+            myDB = self._getDB()
+            myDB.mass_action(cl)
+
+        # set last updated
+        if data:
+            self.setLastUpdate()
 
     def _checkAuth(self, data):
         return data != 'Invalid Link'
-
 provider = WombleProvider()
+
diff --git a/sickbeard/rssfeeds.py b/sickbeard/rssfeeds.py
new file mode 100644
index 00000000..71c2623c
--- /dev/null
+++ b/sickbeard/rssfeeds.py
@@ -0,0 +1,62 @@
+import os
+import threading
+import urllib
+import urlparse
+import re
+import time
+import sickbeard
+
+from sickbeard import logger
+from sickbeard import encodingKludge as ek
+from sickbeard.exceptions import ex
+from lib.shove import Shove
+from lib.feedcache import cache
+
+feed_lock = threading.Lock()
+
+class RSSFeeds:
+    def __init__(self, db_name):
+        try:
+            self.fs = self.fs = Shove('sqlite:///' + ek.ek(os.path.join, sickbeard.CACHE_DIR, db_name + '.db'), compress=True)
+            self.fc = cache.Cache(self.fs)
+        except Exception, e:
+            logger.log(u"RSS error: " + ex(e), logger.ERROR)
+            raise
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, tb):
+        self.fc = None
+        self.fs.close()
+
+    def clearCache(self, age=None):
+        if not self.fc:
+            return
+
+        self.fc.purge(age)
+
+    def getRSSFeed(self, url, post_data=None):
+        if not self.fc:
+            return
+
+        with feed_lock:
+            parsed = list(urlparse.urlparse(url))
+            parsed[2] = re.sub("/{2,}", "/", parsed[2])  # replace two or more / with one
+
+            if post_data:
+                url += urllib.urlencode(post_data)
+
+            feed = self.fc.fetch(url)
+            if not feed:
+                logger.log(u"RSS Error loading URL: " + url, logger.ERROR)
+                return
+            elif 'error' in feed.feed:
+                logger.log(u"RSS ERROR:[%s] CODE:[%s]" % (feed.feed['error']['description'], feed.feed['error']['code']),
+                           logger.DEBUG)
+                return
+            elif not feed.entries:
+                logger.log(u"No RSS items found using URL: " + url, logger.WARNING)
+                return
+
+            return feed
\ No newline at end of file
diff --git a/sickbeard/tvcache.py b/sickbeard/tvcache.py
index 5de6cb68..16879349 100644
--- a/sickbeard/tvcache.py
+++ b/sickbeard/tvcache.py
@@ -22,29 +22,21 @@ import os
 import time
 import datetime
-import urllib
-import urlparse
-import re
 import threading
 import sickbeard
 
-from lib.shove import Shove
-from lib.feedcache import cache
-
 from sickbeard import db
 from sickbeard import logger
-from sickbeard.common import Quality, cpu_presets
+from sickbeard.common import Quality
 from sickbeard import helpers, show_name_helpers
 from sickbeard.exceptions import MultipleShowObjectsException
 from sickbeard.exceptions import AuthException
-from sickbeard import encodingKludge as ek
-
 from name_parser.parser import NameParser, InvalidNameException
+from sickbeard.rssfeeds import RSSFeeds
 
 cache_lock = threading.Lock()
 
-
 class CacheDBConnection(db.DBConnection):
     def __init__(self, providerName):
         db.DBConnection.__init__(self, "cache.db")
@@ -87,13 +79,15 @@ class TVCache():
         return CacheDBConnection(self.providerID)
 
     def _clearCache(self):
-        if not self.shouldClearCache():
-            return
+        if self.shouldClearCache():
+            curDate = datetime.date.today() - datetime.timedelta(weeks=1)
 
-        curDate = datetime.date.today() - datetime.timedelta(weeks=1)
+            myDB = self._getDB()
+            myDB.action("DELETE FROM [" + self.providerID + "] WHERE time < ?", [int(time.mktime(curDate.timetuple()))])
 
-        myDB = self._getDB()
-        myDB.action("DELETE FROM [" + self.providerID + "] WHERE time < ?", [int(time.mktime(curDate.timetuple()))])
+        # clear RSS Feed cache
+        with RSSFeeds(self.providerID) as feed:
+            feed.clearCache(int(time.mktime(curDate.timetuple())))
 
     def _getRSSData(self):
@@ -126,9 +120,8 @@ class TVCache():
             return []
 
         if self._checkAuth(data):
-            items = data.entries
             cl = []
-            for item in items:
+            for item in data.entries:
                 ci = self._parseItem(item)
                 if ci is not None:
                     cl.append(ci)
@@ -143,34 +136,10 @@ class TVCache():
 
         return []
 
-    def getRSSFeed(self, url, post_data=None, request_headers=None):
-        # create provider storaqe cache
-        storage = Shove('sqlite:///' + ek.ek(os.path.join, sickbeard.CACHE_DIR, self.provider.name) + '.db')
-        fc = cache.Cache(storage)
-
-        parsed = list(urlparse.urlparse(url))
-        parsed[2] = re.sub("/{2,}", "/", parsed[2])  # replace two or more / with one
-
-        if post_data:
-            url += urllib.urlencode(post_data)
-
-        f = fc.fetch(url, request_headers=request_headers)
-
-        if not f:
-            logger.log(u"Error loading " + self.providerID + " URL: " + url, logger.ERROR)
-            return None
-        elif 'error' in f.feed:
-            logger.log(u"Newznab ERROR:[%s] CODE:[%s]" % (f.feed['error']['description'], f.feed['error']['code']),
-                       logger.DEBUG)
-            return None
-        elif not f.entries:
-            logger.log(u"No items found on " + self.providerID + " using URL: " + url, logger.WARNING)
-            return None
-
-        storage.close()
-
-        return f
-
+    def getRSSFeed(self, url, post_data=None):
+        with RSSFeeds(self.providerID) as feed:
+            data = feed.getRSSFeed(url, post_data)
+            return data
 
     def _translateTitle(self, title):
         return title.replace(' ', '.')
diff --git a/sickbeard/webserve.py b/sickbeard/webserve.py
index fd061f3d..2980e389 100644
--- a/sickbeard/webserve.py
+++ b/sickbeard/webserve.py
@@ -19,6 +19,7 @@ from __future__ import with_statement
 
 import base64
 import inspect
+import traceback
 import urlparse
 import zipfile
 
@@ -136,7 +137,9 @@ class MainHandler(RequestHandler):
         super(MainHandler, self).__init__(application, request, **kwargs)
         global req_headers
 
-        sickbeard.REMOTE_IP = self.request.remote_ip
+        sickbeard.REMOTE_IP = self.request.headers.get('X-Forwarded-For',
+                                                       self.request.headers.get('X-Real-Ip', self.request.remote_ip))
+
         req_headers = self.request.headers
 
     def http_error_401_handler(self):
@@ -158,11 +161,12 @@ class MainHandler(RequestHandler):
         return self.redirectTo('/home/')
 
     def write_error(self, status_code, **kwargs):
-        if status_code == 404:
-            return self.redirectTo('/home/')
-        elif status_code == 401:
+        if status_code == 401:
             self.finish(self.http_error_401_handler())
+        elif status_code == 404:
+            self.redirectTo('/home/')
         else:
+            logger.log(traceback.format_exc(), logger.DEBUG)
             super(MainHandler, self).write_error(status_code, **kwargs)
 
     def _dispatch(self):
@@ -209,22 +213,17 @@ class MainHandler(RequestHandler):
         raise HTTPError(404)
 
     def redirectTo(self, url):
-        self._transforms = []
-
         url = urlparse.urljoin(sickbeard.WEB_ROOT, url)
         logger.log(u"Redirecting to: " + url, logger.DEBUG)
-        self.redirect(url, status=303)
+        self._transforms = []
+        self.redirect(url)
 
     def get(self, *args, **kwargs):
-        response = self._dispatch()
-        if response:
-            self.finish(response)
+        self.write(self._dispatch())
 
     def post(self, *args, **kwargs):
-        response = self._dispatch()
-        if response:
-            self.finish(response)
+        self._dispatch()
 
     def robots_txt(self, *args, **kwargs):
         """ Keep web crawlers out """
@@ -456,13 +455,13 @@ class MainHandler(RequestHandler):
 
     browser = WebFileBrowser
 
-
 class PageTemplate(Template):
     def __init__(self, *args, **KWs):
+        global req_headers
+
         KWs['file'] = os.path.join(sickbeard.PROG_DIR, "gui/" + sickbeard.GUI_NAME + "/interfaces/default/", KWs['file'])
         super(PageTemplate, self).__init__(*args, **KWs)
 
-        global req_headers
         self.sbRoot = sickbeard.WEB_ROOT
         self.sbHttpPort = sickbeard.WEB_PORT
@@ -495,7 +494,7 @@ class PageTemplate(Template):
             {'title': 'Manage', 'key': 'manage'},
             {'title': 'Config', 'key': 'config'},
             {'title': logPageTitle, 'key': 'errorlogs'},
-        ]
+            ]
 
 
 class IndexerWebUI(MainHandler):
@@ -512,9 +511,7 @@ class IndexerWebUI(MainHandler):
 
 
 def _munge(string):
-    to_return = unicode(string).encode('utf-8', 'xmlcharrefreplace')
-    return to_return
-
+    return unicode(string).encode('utf-8', 'xmlcharrefreplace')
 
 def _genericMessage(subject, message):
     t = PageTemplate(file="genericMessage.tmpl")
@@ -4296,21 +4293,21 @@ class Home(MainHandler):
 
         return json.dumps({'result': 'failure'})
 
-
 class UI(MainHandler):
-    def add_message(self, *args, **kwargs):
+    def add_message(self):
+
         ui.notifications.message('Test 1', 'This is test number 1')
         ui.notifications.error('Test 2', 'This is test number 2')
 
         return "ok"
 
-    def get_messages(self, *args, **kwargs):
+    def get_messages(self):
         messages = {}
         cur_notification_num = 1
         for cur_notification in ui.notifications.get_notifications():
             messages['notification-' + str(cur_notification_num)] = {'title': cur_notification.title,
-                                                                     'message': cur_notification.message,
-                                                                     'type': cur_notification.type}
+                                                                         'message': cur_notification.message,
+                                                                         'type': cur_notification.type}
             cur_notification_num += 1
 
-        return json.dumps(messages)
+        return json.dumps(messages)
\ No newline at end of file
diff --git a/tornado/__init__.py b/tornado/__init__.py
index 81900d20..5e90b770 100644
--- a/tornado/__init__.py
+++ b/tornado/__init__.py
@@ -25,5 +25,5 @@ from __future__ import absolute_import, division, print_function, with_statement
 # is zero for an official release, positive for a development branch,
 # or negative for a release candidate or beta (after the base version
 # number has been incremented)
-version = "4.0.dev1"
-version_info = (4, 0, 0, -100)
+version = "4.0b1"
+version_info = (4, 0, 0, -99)
diff --git a/tornado/auth.py b/tornado/auth.py
index f8dadb66..7bd3fa1e 100644
--- a/tornado/auth.py
+++ b/tornado/auth.py
@@ -883,9 +883,10 @@ class FriendFeedMixin(OAuthMixin):
 class GoogleMixin(OpenIdMixin, OAuthMixin):
     """Google Open ID / OAuth authentication.
 
-    *Deprecated:* New applications should use `GoogleOAuth2Mixin`
-    below instead of this class. As of May 19, 2014, Google has stopped
-    supporting registration-free authentication.
+    .. deprecated:: 4.0
+       New applications should use `GoogleOAuth2Mixin`
+       below instead of this class. As of May 19, 2014, Google has stopped
+       supporting registration-free authentication.
 
     No application registration is necessary to use Google for
     authentication or to access Google resources on behalf of a user.
@@ -1053,9 +1054,10 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
 class FacebookMixin(object):
     """Facebook Connect authentication.
 
-    *Deprecated:* New applications should use `FacebookGraphMixin`
-    below instead of this class. This class does not support the
-    Future-based interface seen on other classes in this module.
+    .. deprecated:: 1.1
+       New applications should use `FacebookGraphMixin`
+       below instead of this class. This class does not support the
+       Future-based interface seen on other classes in this module.
 
     To authenticate with Facebook, register your application with
     Facebook at http://www.facebook.com/developers/apps.php. Then
diff --git a/tornado/curl_httpclient.py b/tornado/curl_httpclient.py
index c190ac91..ae4471fd 100644
--- a/tornado/curl_httpclient.py
+++ b/tornado/curl_httpclient.py
@@ -51,18 +51,6 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
         self._fds = {}
         self._timeout = None
 
-        try:
-            self._socket_action = self._multi.socket_action
-        except AttributeError:
-            # socket_action is found in pycurl since 7.18.2 (it's been
-            # in libcurl longer than that but wasn't accessible to
-            # python).
-            gen_log.warning("socket_action method missing from pycurl; "
-                            "falling back to socket_all. Upgrading "
-                            "libcurl and pycurl will improve performance")
-            self._socket_action = \
-                lambda fd, action: self._multi.socket_all()
-
         # libcurl has bugs that sometimes cause it to not report all
         # relevant file descriptors and timeouts to TIMERFUNCTION/
         # SOCKETFUNCTION.  Mitigate the effects of such bugs by
@@ -142,7 +130,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
                 action |= pycurl.CSELECT_OUT
             while True:
                 try:
-                    ret, num_handles = self._socket_action(fd, action)
+                    ret, num_handles = self._multi.socket_action(fd, action)
                 except pycurl.error as e:
                     ret = e.args[0]
                 if ret != pycurl.E_CALL_MULTI_PERFORM:
@@ -155,7 +143,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
             self._timeout = None
         while True:
             try:
-                ret, num_handles = self._socket_action(
+                ret, num_handles = self._multi.socket_action(
                     pycurl.SOCKET_TIMEOUT, 0)
             except pycurl.error as e:
                 ret = e.args[0]
@@ -223,11 +211,6 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
             "callback": callback,
             "curl_start_time": time.time(),
         }
-        # Disable IPv6 to mitigate the effects of this bug
-        # on curl versions <= 7.21.0
-        # http://sourceforge.net/tracker/?func=detail&aid=3017819&group_id=976&atid=100976
-        if pycurl.version_info()[2] <= 0x71500:  # 7.21.0
-            curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
         _curl_setup_request(curl, request, curl.info["buffer"],
                             curl.info["headers"])
         self._multi.add_handle(curl)
@@ -383,7 +366,6 @@ def _curl_setup_request(curl, request, buffer, headers):
     if request.allow_ipv6 is False:
         # Curl behaves reasonably when DNS resolution gives an ipv6 address
        # that we can't reach, so allow ipv6 unless the user asks to disable.
-        # (but see version check in _process_queue above)
         curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
     else:
         curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_WHATEVER)
diff --git a/tornado/httpclient.py b/tornado/httpclient.py
index 48731c15..8418b5b2 100644
--- a/tornado/httpclient.py
+++ b/tornado/httpclient.py
@@ -22,14 +22,17 @@ to switch to ``curl_httpclient`` for reasons such as the following:
 
 * ``curl_httpclient`` was the default prior to Tornado 2.0.
 
-Note that if you are using ``curl_httpclient``, it is highly recommended that
-you use a recent version of ``libcurl`` and ``pycurl``.  Currently the minimum
-supported version is 7.18.2, and the recommended version is 7.21.1 or newer.
-It is highly recommended that your ``libcurl`` installation is built with
-asynchronous DNS resolver (threaded or c-ares), otherwise you may encounter
-various problems with request timeouts (for more information, see
+Note that if you are using ``curl_httpclient``, it is highly
+recommended that you use a recent version of ``libcurl`` and
+``pycurl``.  Currently the minimum supported version of libcurl is
+7.21.1, and the minimum version of pycurl is 7.18.2.  It is highly
+recommended that your ``libcurl`` installation is built with
+asynchronous DNS resolver (threaded or c-ares), otherwise you may
+encounter various problems with request timeouts (for more
+information, see
 http://curl.haxx.se/libcurl/c/curl_easy_setopt.html#CURLOPTCONNECTTIMEOUTMS
 and comments in curl_httpclient.py).
+
 """
 
 from __future__ import absolute_import, division, print_function, with_statement
@@ -144,12 +147,21 @@ class AsyncHTTPClient(Configurable):
 
     def __new__(cls, io_loop=None, force_instance=False, **kwargs):
         io_loop = io_loop or IOLoop.current()
-        if io_loop in cls._async_clients() and not force_instance:
-            return cls._async_clients()[io_loop]
+        if force_instance:
+            instance_cache = None
+        else:
+            instance_cache = cls._async_clients()
+        if instance_cache is not None and io_loop in instance_cache:
+            return instance_cache[io_loop]
         instance = super(AsyncHTTPClient, cls).__new__(cls, io_loop=io_loop,
                                                        **kwargs)
-        if not force_instance:
-            cls._async_clients()[io_loop] = instance
+        # Make sure the instance knows which cache to remove itself from.
+        # It can't simply call _async_clients() because we may be in
+        # __new__(AsyncHTTPClient) but instance.__class__ may be
+        # SimpleAsyncHTTPClient.
+        instance._instance_cache = instance_cache
+        if instance_cache is not None:
+            instance_cache[instance.io_loop] = instance
         return instance
 
     def initialize(self, io_loop, defaults=None):
@@ -172,9 +184,13 @@ class AsyncHTTPClient(Configurable):
         ``close()``.
 
         """
+        if self._closed:
+            return
         self._closed = True
-        if self._async_clients().get(self.io_loop) is self:
-            del self._async_clients()[self.io_loop]
+        if self._instance_cache is not None:
+            if self._instance_cache.get(self.io_loop) is not self:
+                raise RuntimeError("inconsistent AsyncHTTPClient cache")
+            del self._instance_cache[self.io_loop]
 
     def fetch(self, request, callback=None, **kwargs):
         """Executes a request, asynchronously returning an `HTTPResponse`.
diff --git a/tornado/ioloop.py b/tornado/ioloop.py
index 3477684c..da9b7dbd 100644
--- a/tornado/ioloop.py
+++ b/tornado/ioloop.py
@@ -45,8 +45,7 @@ import traceback
 from tornado.concurrent import TracebackFuture, is_future
 from tornado.log import app_log, gen_log
 from tornado import stack_context
-from tornado.util import Configurable
-from tornado.util import errno_from_exception
+from tornado.util import Configurable, errno_from_exception, timedelta_to_seconds
 
 try:
     import signal
@@ -433,7 +432,7 @@ class IOLoop(Configurable):
         """
         return time.time()
 
-    def add_timeout(self, deadline, callback):
+    def add_timeout(self, deadline, callback, *args, **kwargs):
         """Runs the ``callback`` at the time ``deadline`` from the I/O loop.
 
         Returns an opaque handle that may be passed to
@@ -442,13 +441,59 @@ class IOLoop(Configurable):
         ``deadline`` may be a number denoting a time (on the same
         scale as `IOLoop.time`, normally `time.time`), or a
         `datetime.timedelta` object for a deadline relative to the
-        current time.
+        current time.  Since Tornado 4.0, `call_later` is a more
+        convenient alternative for the relative case since it does not
+        require a timedelta object.
 
         Note that it is not safe to call `add_timeout` from other threads.
         Instead, you must use `add_callback` to transfer control to the
         `IOLoop`'s thread, and then call `add_timeout` from there.
+
+        Subclasses of IOLoop must implement either `add_timeout` or
+        `call_at`; the default implementations of each will call
+        the other.  `call_at` is usually easier to implement, but
+        subclasses that wish to maintain compatibility with Tornado
+        versions prior to 4.0 must use `add_timeout` instead.
+
+        .. versionchanged:: 4.0
+           Now passes through ``*args`` and ``**kwargs`` to the callback.
         """
-        raise NotImplementedError()
+        if isinstance(deadline, numbers.Real):
+            return self.call_at(deadline, callback, *args, **kwargs)
+        elif isinstance(deadline, datetime.timedelta):
+            return self.call_at(self.time() + timedelta_to_seconds(deadline),
+                                callback, *args, **kwargs)
+        else:
+            raise TypeError("Unsupported deadline %r" % deadline)
+
+    def call_later(self, delay, callback, *args, **kwargs):
+        """Runs the ``callback`` after ``delay`` seconds have passed.
+
+        Returns an opaque handle that may be passed to `remove_timeout`
+        to cancel.  Note that unlike the `asyncio` method of the same
+        name, the returned object does not have a ``cancel()`` method.
+
+        See `add_timeout` for comments on thread-safety and subclassing.
+
+        .. versionadded:: 4.0
+        """
+        self.call_at(self.time() + delay, callback, *args, **kwargs)
+
+    def call_at(self, when, callback, *args, **kwargs):
+        """Runs the ``callback`` at the absolute time designated by ``when``.
+
+        ``when`` must be a number using the same reference point as
+        `IOLoop.time`.
+
+        Returns an opaque handle that may be passed to `remove_timeout`
+        to cancel.  Note that unlike the `asyncio` method of the same
+        name, the returned object does not have a ``cancel()`` method.
+
+        See `add_timeout` for comments on thread-safety and subclassing.
+
+        .. versionadded:: 4.0
+        """
+        self.add_timeout(when, callback, *args, **kwargs)
 
     def remove_timeout(self, timeout):
         """Cancels a pending timeout.
@@ -606,7 +651,7 @@ class PollIOLoop(IOLoop):
         self._thread_ident = None
         self._blocking_signal_threshold = None
         self._timeout_counter = itertools.count()
-
+
         # Create a pipe that we send bogus data to when we want to wake
         # the I/O loop when it is idle
         self._waker = Waker()
@@ -813,8 +858,11 @@ class PollIOLoop(IOLoop):
     def time(self):
         return self.time_func()
 
-    def add_timeout(self, deadline, callback):
-        timeout = _Timeout(deadline, stack_context.wrap(callback), self)
+    def call_at(self, deadline, callback, *args, **kwargs):
+        timeout = _Timeout(
+            deadline,
+            functools.partial(stack_context.wrap(callback), *args, **kwargs),
+            self)
         heapq.heappush(self._timeouts, timeout)
         return timeout
@@ -869,24 +917,12 @@ class _Timeout(object):
     __slots__ = ['deadline', 'callback', 'tiebreaker']
 
     def __init__(self, deadline, callback, io_loop):
-        if isinstance(deadline, numbers.Real):
-            self.deadline = deadline
-        elif isinstance(deadline, datetime.timedelta):
-            now = io_loop.time()
-            try:
-                self.deadline = now + deadline.total_seconds()
-            except AttributeError:  # py2.6
-                self.deadline = now + _Timeout.timedelta_to_seconds(deadline)
-        else:
+        if not isinstance(deadline, numbers.Real):
             raise TypeError("Unsupported deadline %r" % deadline)
+        self.deadline = deadline
         self.callback = callback
         self.tiebreaker = next(io_loop._timeout_counter)
 
-    @staticmethod
-    def timedelta_to_seconds(td):
-        """Equivalent to td.total_seconds() (introduced in python 2.7)."""
-        return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / float(10 ** 6)
-
     # Comparison methods to sort by deadline, with object id as a tiebreaker
     # to guarantee a consistent ordering.  The heapq module uses __le__
     # in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons
diff --git a/tornado/iostream.py b/tornado/iostream.py
index 8b614258..3ebcd586 100644
--- a/tornado/iostream.py
+++ b/tornado/iostream.py
@@ -57,12 +57,24 @@ except ImportError:
 # some they differ.
 _ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
 
+if hasattr(errno, "WSAEWOULDBLOCK"):
+    _ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,)
+
 # These errnos indicate that a connection has been abruptly terminated.
 # They should be caught and handled less noisily than other errors.
 _ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE,
                     errno.ETIMEDOUT)
 
+if hasattr(errno, "WSAECONNRESET"):
+    _ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT)
+
+# More non-portable errnos:
+_ERRNO_INPROGRESS = (errno.EINPROGRESS,)
+
+if hasattr(errno, "WSAEINPROGRESS"):
+    _ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,)
+
+#######################################################
 
 class StreamClosedError(IOError):
     """Exception raised by `IOStream` methods when the stream is closed.
@@ -990,7 +1002,7 @@ class IOStream(BaseIOStream):
                 # returned immediately when attempting to connect to
                 # localhost, so handle them the same way as an error
                 # reported later in _handle_connect.
-                if (errno_from_exception(e) != errno.EINPROGRESS and
+                if (errno_from_exception(e) not in _ERRNO_INPROGRESS and
                         errno_from_exception(e) not in _ERRNO_WOULDBLOCK):
                     gen_log.warning("Connect error on fd %s: %s",
                                     self.socket.fileno(), e)
diff --git a/tornado/netutil.py b/tornado/netutil.py
index a9e05d1e..336c8062 100644
--- a/tornado/netutil.py
+++ b/tornado/netutil.py
@@ -57,6 +57,9 @@ u('foo').encode('idna')
 # some they differ.
 _ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
 
+if hasattr(errno, "WSAEWOULDBLOCK"):
+    _ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,)
+
 
 def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128,
                  flags=None):
     """Creates listening sockets bound to the given port and address.
diff --git a/tornado/platform/asyncio.py b/tornado/platform/asyncio.py
index 6518dea5..b40f0141 100644
--- a/tornado/platform/asyncio.py
+++ b/tornado/platform/asyncio.py
@@ -13,9 +13,9 @@ from __future__ import absolute_import, division, print_function, with_statement
 import datetime
 import functools
 
-# _Timeout is used for its timedelta_to_seconds method for py26 compatibility.
-from tornado.ioloop import IOLoop, _Timeout
+from tornado.ioloop import IOLoop
 from tornado import stack_context
+from tornado.util import timedelta_to_seconds
 
 try:
     # Import the real asyncio module for py33+ first.  Older versions of the
@@ -109,15 +109,13 @@ class BaseAsyncIOLoop(IOLoop):
     def stop(self):
         self.asyncio_loop.stop()
 
-    def add_timeout(self, deadline, callback):
-        if isinstance(deadline, (int, float)):
-            delay = max(deadline - self.time(), 0)
-        elif isinstance(deadline, datetime.timedelta):
-            delay = _Timeout.timedelta_to_seconds(deadline)
-        else:
-            raise TypeError("Unsupported deadline %r", deadline)
-        return self.asyncio_loop.call_later(delay, self._run_callback,
-                                            stack_context.wrap(callback))
+    def call_at(self, when, callback, *args, **kwargs):
+        # asyncio.call_at supports *args but not **kwargs, so bind them here.
+        # We do not synchronize self.time and asyncio_loop.time, so
+        # convert from absolute to relative.
+        return self.asyncio_loop.call_later(
+            max(0, when - self.time()), self._run_callback,
+            functools.partial(stack_context.wrap(callback), *args, **kwargs))
 
     def remove_timeout(self, timeout):
         timeout.cancel()
diff --git a/tornado/platform/twisted.py b/tornado/platform/twisted.py
index 18263dd9..b271dfce 100644
--- a/tornado/platform/twisted.py
+++ b/tornado/platform/twisted.py
@@ -68,6 +68,7 @@ from __future__ import absolute_import, division, print_function, with_statement
 
 import datetime
 import functools
+import numbers
 import socket
 
 import twisted.internet.abstract
@@ -90,11 +91,7 @@ from tornado.log import app_log
 from tornado.netutil import Resolver
 from tornado.stack_context import NullContext, wrap
 from tornado.ioloop import IOLoop
-
-try:
-    long  # py2
-except NameError:
-    long = int  # py3
+from tornado.util import timedelta_to_seconds
 
 
 @implementer(IDelayedCall)
@@ -475,14 +472,19 @@ class TwistedIOLoop(tornado.ioloop.IOLoop):
     def stop(self):
         self.reactor.crash()
 
-    def add_timeout(self, deadline, callback):
-        if isinstance(deadline, (int, long, float)):
+    def add_timeout(self, deadline, callback, *args, **kwargs):
+        # This method could be simplified (since tornado 4.0) by
+        # overriding call_at instead of add_timeout, but we leave it
+        # for now as a test of backwards-compatibility.
+        if isinstance(deadline, numbers.Real):
             delay = max(deadline - self.time(), 0)
         elif isinstance(deadline, datetime.timedelta):
-            delay = tornado.ioloop._Timeout.timedelta_to_seconds(deadline)
+            delay = timedelta_to_seconds(deadline)
         else:
             raise TypeError("Unsupported deadline %r")
-        return self.reactor.callLater(delay, self._run_callback, wrap(callback))
+        return self.reactor.callLater(
+            delay, self._run_callback,
+            functools.partial(wrap(callback), *args, **kwargs))
 
     def remove_timeout(self, timeout):
         if timeout.active():
diff --git a/tornado/test/ioloop_test.py b/tornado/test/ioloop_test.py
index e4f07338..e21d5d4c 100644
--- a/tornado/test/ioloop_test.py
+++ b/tornado/test/ioloop_test.py
@@ -155,7 +155,7 @@ class TestIOLoop(AsyncTestCase):
 
     def test_remove_timeout_after_fire(self):
         # It is not an error to call remove_timeout after it has run.
-        handle = self.io_loop.add_timeout(self.io_loop.time(), self.stop())
+        handle = self.io_loop.add_timeout(self.io_loop.time(), self.stop)
         self.wait()
         self.io_loop.remove_timeout(handle)
 
@@ -173,6 +173,18 @@ class TestIOLoop(AsyncTestCase):
         self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop))
         self.wait()
 
+    def test_timeout_with_arguments(self):
+        # This tests that all the timeout methods pass through *args correctly.
+        results = []
+        self.io_loop.add_timeout(self.io_loop.time(), results.append, 1)
+        self.io_loop.add_timeout(datetime.timedelta(seconds=0),
+                                 results.append, 2)
+        self.io_loop.call_at(self.io_loop.time(), results.append, 3)
+        self.io_loop.call_later(0, results.append, 4)
+        self.io_loop.call_later(0, self.stop)
+        self.wait()
+        self.assertEqual(results, [1, 2, 3, 4])
+
     def test_close_file_object(self):
         """When a file object is used instead of a numeric file descriptor,
         the object should be closed (by IOLoop.close(all_fds=True),
diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py
index e9d241a5..01b0d95a 100644
--- a/tornado/test/iostream_test.py
+++ b/tornado/test/iostream_test.py
@@ -232,8 +232,11 @@ class TestIOStreamMixin(object):
         self.assertFalse(self.connect_called)
         self.assertTrue(isinstance(stream.error, socket.error), stream.error)
         if sys.platform != 'cygwin':
+            _ERRNO_CONNREFUSED = (errno.ECONNREFUSED,)
+            if hasattr(errno, "WSAECONNREFUSED"):
+                _ERRNO_CONNREFUSED += (errno.WSAECONNREFUSED,)
             # cygwin's errnos don't match those used on native windows python
-            self.assertEqual(stream.error.args[0], errno.ECONNREFUSED)
+            self.assertTrue(stream.error.args[0] in _ERRNO_CONNREFUSED)
 
     def test_gaierror(self):
         # Test that IOStream sets its exc_info on getaddrinfo error
diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py
index 2ba9f75d..f17da7e0 100644
--- a/tornado/test/simple_httpclient_test.py
+++ b/tornado/test/simple_httpclient_test.py
@@ -321,8 +321,10 @@ class SimpleHTTPClientTestMixin(object):
 
         if sys.platform != 'cygwin':
             # cygwin returns EPERM instead of ECONNREFUSED here
-            self.assertTrue(str(errno.ECONNREFUSED) in str(response.error),
-                            response.error)
+            contains_errno = str(errno.ECONNREFUSED) in str(response.error)
+            if not contains_errno and hasattr(errno, "WSAECONNREFUSED"):
+                contains_errno = str(errno.WSAECONNREFUSED) in str(response.error)
+            self.assertTrue(contains_errno, response.error)
             # This is usually "Connection refused".
             # On windows, strerror is broken and returns "Unknown error".
             expected_message = os.strerror(errno.ECONNREFUSED)
diff --git a/tornado/test/stack_context_test.py b/tornado/test/stack_context_test.py
index d65a5b21..853260e3 100644
--- a/tornado/test/stack_context_test.py
+++ b/tornado/test/stack_context_test.py
@@ -35,11 +35,11 @@ class TestRequestHandler(RequestHandler):
             logging.debug('in part3()')
             raise Exception('test exception')
 
-    def get_error_html(self, status_code, **kwargs):
-        if 'exception' in kwargs and str(kwargs['exception']) == 'test exception':
-            return 'got expected exception'
+    def write_error(self, status_code, **kwargs):
+        if 'exc_info' in kwargs and str(kwargs['exc_info'][1]) == 'test exception':
+            self.write('got expected exception')
         else:
-            return 'unexpected failure'
+            self.write('unexpected failure')
 
 
 class HTTPStackContextTest(AsyncHTTPTestCase):
diff --git a/tornado/test/web_test.py b/tornado/test/web_test.py
index cbb62b9b..15b2fb5f 100644
--- a/tornado/test/web_test.py
+++ b/tornado/test/web_test.py
@@ -10,7 +10,7 @@ from tornado.template import DictLoader
 from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test
 from tornado.test.util import unittest
 from tornado.util import u, bytes_type, ObjectDict, unicode_type
-from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body
+from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish
 
 import binascii
 import contextlib
@@ -773,20 +773,6 @@ class ErrorResponseTest(WebTestCase):
                 else:
                     self.write("Status: %d" % status_code)
 
-        class GetErrorHtmlHandler(RequestHandler):
-            def get(self):
-                if self.get_argument("status", None):
-                    self.send_error(int(self.get_argument("status")))
-                else:
-                    1 / 0
-
-            def get_error_html(self, status_code, **kwargs):
-                self.set_header("Content-Type", "text/plain")
-                if "exception" in kwargs:
-                    self.write("Exception: %s" % sys.exc_info()[0].__name__)
-                else:
-                    self.write("Status: %d" % status_code)
-
         class FailedWriteErrorHandler(RequestHandler):
             def get(self):
                 1 / 0
@@ -796,7 +782,6 @@ class ErrorResponseTest(WebTestCase):
 
         return [url("/default", DefaultHandler),
                 url("/write_error", WriteErrorHandler),
-                url("/get_error_html", GetErrorHtmlHandler),
                 url("/failed_write_error", FailedWriteErrorHandler),
                 ]
 
@@ -820,16 +805,6 @@ class ErrorResponseTest(WebTestCase):
             self.assertEqual(response.code, 503)
             self.assertEqual(b"Status: 503", response.body)
 
-    def test_get_error_html(self):
-        with ExpectLog(app_log, "Uncaught exception"):
-            response = self.fetch("/get_error_html")
-        self.assertEqual(response.code, 500)
-        self.assertEqual(b"Exception: ZeroDivisionError", response.body)
-
-        response = self.fetch("/get_error_html?status=503")
-        self.assertEqual(response.code, 503)
-        self.assertEqual(b"Status: 503", response.body)
-
     def test_failed_write_error(self):
         with ExpectLog(app_log, "Uncaught exception"):
             response = self.fetch("/failed_write_error")
@@ -2307,3 +2282,20 @@ class XSRFTest(SimpleHandlerTestCase):
             body=urllib_parse.urlencode(dict(_xsrf=body_token)),
             headers=self.cookie_headers(cookie_token))
         self.assertEqual(response.code, 200)
+
+
+@wsgi_safe
+class FinishExceptionTest(SimpleHandlerTestCase):
+    class Handler(RequestHandler):
+        def get(self):
+            self.set_status(401)
+            self.set_header('WWW-Authenticate', 'Basic realm="something"')
+            self.write('authentication required')
+            raise Finish()
+
+    def test_finish_exception(self):
+        response = self.fetch('/')
+        self.assertEqual(response.code, 401)
+        self.assertEqual('Basic realm="something"',
+                         response.headers.get('WWW-Authenticate'))
+        self.assertEqual(b'authentication required', response.body)
diff --git a/tornado/testing.py b/tornado/testing.py
index b1564aa6..b4bfb274 100644
--- a/tornado/testing.py
+++ b/tornado/testing.py
@@ -70,8 +70,8 @@ def get_unused_port():
     only that a series of get_unused_port calls in a single process return
     distinct ports.
 
-    **Deprecated**.  Use bind_unused_port instead, which is guaranteed
-    to find an unused port.
+    .. deprecated::
+       Use bind_unused_port instead, which is guaranteed to find an unused port.
     """
     global _next_port
     port = _next_port
diff --git a/tornado/util.py b/tornado/util.py
index 49eea2c3..b6e06c67 100644
--- a/tornado/util.py
+++ b/tornado/util.py
@@ -311,6 +311,11 @@ class ArgReplacer(object):
         return old_value, args, kwargs
 
 
+def timedelta_to_seconds(td):
+    """Equivalent to td.total_seconds() (introduced in python 2.7)."""
+    return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / float(10 ** 6)
+
+
 def _websocket_mask_python(mask, data):
     """Websocket masking function.
 
diff --git a/tornado/web.py b/tornado/web.py
index 4884dd67..506caae7 100644
--- a/tornado/web.py
+++ b/tornado/web.py
@@ -630,7 +630,6 @@ class RequestHandler(object):
         self.set_status(status)
         self.set_header("Location", urlparse.urljoin(utf8(self.request.uri),
                                                      utf8(url)))
-
         self.finish()
 
     def write(self, chunk):
@@ -944,26 +943,7 @@ class RequestHandler(object):
         ``kwargs["exc_info"]``.  Note that this exception may not be
         the "current" exception for purposes of methods like
         ``sys.exc_info()`` or ``traceback.format_exc``.
-
-        For historical reasons, if a method ``get_error_html`` exists,
-        it will be used instead of the default ``write_error`` implementation.
-        ``get_error_html`` returned a string instead of producing output
-        normally, and had different semantics for exception handling.
-        Users of ``get_error_html`` are encouraged to convert their code
-        to override ``write_error`` instead.
         """
-        if hasattr(self, 'get_error_html'):
-            if 'exc_info' in kwargs:
-                exc_info = kwargs.pop('exc_info')
-                kwargs['exception'] = exc_info[1]
-                try:
-                    # Put the traceback into sys.exc_info()
-                    raise_exc_info(exc_info)
-                except Exception:
-                    self.finish(self.get_error_html(status_code, **kwargs))
-            else:
-                self.finish(self.get_error_html(status_code, **kwargs))
-            return
         if self.settings.get("serve_traceback") and "exc_info" in kwargs:
             # in debug mode, try to send a traceback
             self.set_header('Content-Type', 'text/plain')
@@ -1385,6 +1365,11 @@ class RequestHandler(object):
             " (" + self.request.remote_ip + ")"
 
     def _handle_request_exception(self, e):
+        if isinstance(e, Finish):
+            # Not an error; just finish the request without logging.
+            if not self._finished:
+                self.finish()
+            return
         self.log_exception(*sys.exc_info())
         if self._finished:
             # Extra errors after the request has been finished should
@@ -1558,7 +1543,7 @@ def removeslash(method):
             if uri:  # don't try to redirect '/' to ''
                 if self.request.query:
                     uri += "?" + self.request.query
-                self.redirectTo(uri, permanent=True)
+                self.redirect(uri, permanent=True)
                 return
         else:
             raise HTTPError(404)
@@ -1580,7 +1565,7 @@ def addslash(method):
                 uri = self.request.path + "/"
                 if self.request.query:
                     uri += "?" + self.request.query
-                self.redirectTo(uri, permanent=True)
+                self.redirect(uri, permanent=True)
                 return
             raise HTTPError(404)
         return method(self, *args, **kwargs)
@@ -1939,6 +1924,9 @@ class HTTPError(Exception):
     `RequestHandler.send_error` since it automatically ends the
     current function.
 
+    To customize the response sent with an `HTTPError`, override
+    `RequestHandler.write_error`.
+
    :arg int status_code: HTTP status code.  Must be listed in
        `httplib.responses ` unless the ``reason`` keyword argument is given.
@@ -1967,6 +1955,25 @@ class HTTPError(Exception):
             return message
 
 
+class Finish(Exception):
+    """An exception that ends the request without producing an error response.
+
+    When `Finish` is raised in a `RequestHandler`, the request will end
+    (calling `RequestHandler.finish` if it hasn't already been called),
+    but the outgoing response will not be modified and the error-handling
+    methods (including `RequestHandler.write_error`) will not be called.
+
+    This can be a more convenient way to implement custom error pages
+    than overriding ``write_error`` (especially in library code)::
+
+        if self.current_user is None:
+            self.set_status(401)
+            self.set_header('WWW-Authenticate', 'Basic realm="something"')
+            raise Finish()
+    """
+    pass
+
+
 class MissingArgumentError(HTTPError):
     """Exception raised by `RequestHandler.get_argument`.
 
@@ -2494,9 +2501,9 @@ class FallbackHandler(RequestHandler):
 class OutputTransform(object):
     """A transform modifies the result of an HTTP request (e.g., GZip encoding)
 
-    A new transform instance is created for every request. See the
-    GZipContentEncoding example below if you want to implement a
-    new Transform.
+    Applications are not expected to create their own OutputTransforms
+    or interact with them directly; the framework chooses which transforms
+    (if any) to apply.
     """
     def __init__(self, request):
         pass
@@ -2587,7 +2594,7 @@ def authenticated(method):
                 else:
                     next_url = self.request.uri
                 url += "?" + urlencode(dict(next=next_url))
-            self.redirectTo(url)
+            self.redirect(url)
             return
         raise HTTPError(403)
     return method(self, *args, **kwargs)
diff --git a/tornado/websocket.py b/tornado/websocket.py
index 19196b88..b3349988 100644
--- a/tornado/websocket.py
+++ b/tornado/websocket.py
@@ -3,18 +3,17 @@
 `WebSockets `_ allow for bidirectional
 communication between the browser and server.
 
-.. warning::
+WebSockets are supported in the current versions of all major browsers,
+although older versions that do not support WebSockets are still in use
+(refer to http://caniuse.com/websockets for details).
 
-   The WebSocket protocol was recently finalized as `RFC 6455
-   `_ and is not yet supported in
-   all browsers.  Refer to http://caniuse.com/websockets for details
-   on compatibility.  In addition, during development the protocol
-   went through several incompatible versions, and some browsers only
-   support older versions.  By default this module only supports the
-   latest version of the protocol, but optional support for an older
-   version (known as "draft 76" or "hixie-76") can be enabled by
-   overriding `WebSocketHandler.allow_draft76` (see that method's
-   documentation for caveats).
+This module implements the final version of the WebSocket protocol as
+defined in `RFC 6455 `_.  Certain
+browser versions (notably Safari 5.x) implemented an earlier draft of
+the protocol (known as "draft 76") and are not compatible with this module.
+
+.. versionchanged:: 4.0
+   Removed support for the draft 76 protocol version.
""" from __future__ import absolute_import, division, print_function, with_statement @@ -22,11 +21,9 @@ from __future__ import absolute_import, division, print_function, with_statement import base64 import collections -import functools import hashlib import os import struct -import time import tornado.escape import tornado.web @@ -38,7 +35,7 @@ from tornado.iostream import StreamClosedError from tornado.log import gen_log, app_log from tornado import simple_httpclient from tornado.tcpclient import TCPClient -from tornado.util import bytes_type, unicode_type, _websocket_mask +from tornado.util import bytes_type, _websocket_mask try: from urllib.parse import urlparse # py2 @@ -161,10 +158,6 @@ class WebSocketHandler(tornado.web.RequestHandler): if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"): self.ws_connection = WebSocketProtocol13(self) self.ws_connection.accept_connection() - elif (self.allow_draft76() and - "Sec-WebSocket-Version" not in self.request.headers): - self.ws_connection = WebSocketProtocol76(self) - self.ws_connection.accept_connection() else: self.stream.write(tornado.escape.utf8( "HTTP/1.1 426 Upgrade Required\r\n" @@ -256,9 +249,6 @@ class WebSocketHandler(tornado.web.RequestHandler): closing. These values are made available to the client, but are not otherwise interpreted by the websocket protocol. - The ``code`` and ``reason`` arguments are ignored in the "draft76" - protocol version. - .. versionchanged:: 4.0 Added the ``code`` and ``reason`` arguments. @@ -296,21 +286,6 @@ class WebSocketHandler(tornado.web.RequestHandler): # Check to see that origin matches host directly, including ports return origin == host - def allow_draft76(self): - """Override to enable support for the older "draft76" protocol. - - The draft76 version of the websocket protocol is disabled by - default due to security concerns, but it can be enabled by - overriding this method to return True. - - Connections using the draft76 protocol do not support the - ``binary=True`` flag to `write_message`. - - Support for the draft76 protocol is deprecated and will be - removed in a future version of Tornado. - """ - return False - def set_nodelay(self, value): """Set the no-delay flag for this stream. @@ -327,18 +302,6 @@ class WebSocketHandler(tornado.web.RequestHandler): """ self.stream.set_nodelay(value) - def get_websocket_scheme(self): - """Return the url scheme used for this request, either "ws" or "wss". - - This is normally decided by HTTPServer, but applications - may wish to override this if they are using an SSL proxy - that does not provide the X-Scheme header as understood - by HTTPServer. - - Note that this is only used by the draft76 protocol. - """ - return "wss" if self.request.protocol == "https" else "ws" - def on_connection_close(self): if self.ws_connection: self.ws_connection.on_connection_close() @@ -392,175 +355,6 @@ class WebSocketProtocol(object): self.close() # let the subclass cleanup -class WebSocketProtocol76(WebSocketProtocol): - """Implementation of the WebSockets protocol, version hixie-76. 
-
-    This class provides basic functionality to process WebSockets requests as
-    specified in
-    http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
-    """
-    def __init__(self, handler):
-        WebSocketProtocol.__init__(self, handler)
-        self.challenge = None
-        self._waiting = None
-
-    def accept_connection(self):
-        try:
-            self._handle_websocket_headers()
-        except ValueError:
-            gen_log.debug("Malformed WebSocket request received")
-            self._abort()
-            return
-
-        scheme = self.handler.get_websocket_scheme()
-
-        # draft76 only allows a single subprotocol
-        subprotocol_header = ''
-        subprotocol = self.request.headers.get("Sec-WebSocket-Protocol", None)
-        if subprotocol:
-            selected = self.handler.select_subprotocol([subprotocol])
-            if selected:
-                assert selected == subprotocol
-                subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
-
-        # Write the initial headers before attempting to read the challenge.
-        # This is necessary when using proxies (such as HAProxy), which
-        # need to see the Upgrade headers before passing through the
-        # non-HTTP traffic that follows.
-        self.stream.write(tornado.escape.utf8(
-            "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
-            "Upgrade: WebSocket\r\n"
-            "Connection: Upgrade\r\n"
-            "Server: TornadoServer/%(version)s\r\n"
-            "Sec-WebSocket-Origin: %(origin)s\r\n"
-            "Sec-WebSocket-Location: %(scheme)s://%(host)s%(uri)s\r\n"
-            "%(subprotocol)s"
-            "\r\n" % (dict(
-                version=tornado.version,
-                origin=self.request.headers["Origin"],
-                scheme=scheme,
-                host=self.request.host,
-                uri=self.request.uri,
-                subprotocol=subprotocol_header))))
-        self.stream.read_bytes(8, self._handle_challenge)
-
-    def challenge_response(self, challenge):
-        """Generates the challenge response that's needed in the handshake
-
-        The challenge parameter should be the raw bytes as sent from the
-        client.
-        """
-        key_1 = self.request.headers.get("Sec-Websocket-Key1")
-        key_2 = self.request.headers.get("Sec-Websocket-Key2")
-        try:
-            part_1 = self._calculate_part(key_1)
-            part_2 = self._calculate_part(key_2)
-        except ValueError:
-            raise ValueError("Invalid Keys/Challenge")
-        return self._generate_challenge_response(part_1, part_2, challenge)
-
-    def _handle_challenge(self, challenge):
-        try:
-            challenge_response = self.challenge_response(challenge)
-        except ValueError:
-            gen_log.debug("Malformed key data in WebSocket request")
-            self._abort()
-            return
-        self._write_response(challenge_response)
-
-    def _write_response(self, challenge):
-        self.stream.write(challenge)
-        self._run_callback(self.handler.open, *self.handler.open_args,
-                           **self.handler.open_kwargs)
-        self._receive_message()
-
-    def _handle_websocket_headers(self):
-        """Verifies all invariant- and required headers
-
-        If a header is missing or have an incorrect value ValueError will be
-        raised
-        """
-        fields = ("Origin", "Host", "Sec-Websocket-Key1",
-                  "Sec-Websocket-Key2")
-        if not all(map(lambda f: self.request.headers.get(f), fields)):
-            raise ValueError("Missing/Invalid WebSocket headers")
-
-    def _calculate_part(self, key):
-        """Processes the key headers and calculates their key value.
-
-        Raises ValueError when feed invalid key."""
-        # pyflakes complains about variable reuse if both of these lines use 'c'
-        number = int(''.join(c for c in key if c.isdigit()))
-        spaces = len([c2 for c2 in key if c2.isspace()])
-        try:
-            key_number = number // spaces
-        except (ValueError, ZeroDivisionError):
-            raise ValueError
-        return struct.pack(">I", key_number)
-
-    def _generate_challenge_response(self, part_1, part_2, part_3):
-        m = hashlib.md5()
-        m.update(part_1)
-        m.update(part_2)
-        m.update(part_3)
-        return m.digest()
-
-    def _receive_message(self):
-        self.stream.read_bytes(1, self._on_frame_type)
-
-    def _on_frame_type(self, byte):
-        frame_type = ord(byte)
-        if frame_type == 0x00:
-            self.stream.read_until(b"\xff", self._on_end_delimiter)
-        elif frame_type == 0xff:
-            self.stream.read_bytes(1, self._on_length_indicator)
-        else:
-            self._abort()
-
-    def _on_end_delimiter(self, frame):
-        if not self.client_terminated:
-            self._run_callback(self.handler.on_message,
-                               frame[:-1].decode("utf-8", "replace"))
-        if not self.client_terminated:
-            self._receive_message()
-
-    def _on_length_indicator(self, byte):
-        if ord(byte) != 0x00:
-            self._abort()
-            return
-        self.client_terminated = True
-        self.close()
-
-    def write_message(self, message, binary=False):
-        """Sends the given message to the client of this Web Socket."""
-        if binary:
-            raise ValueError(
-                "Binary messages not supported by this version of websockets")
-        if isinstance(message, unicode_type):
-            message = message.encode("utf-8")
-        assert isinstance(message, bytes_type)
-        self.stream.write(b"\x00" + message + b"\xff")
-
-    def write_ping(self, data):
-        """Send ping frame."""
-        raise ValueError("Ping messages not supported by this version of websockets")
-
-    def close(self, code=None, reason=None):
-        """Closes the WebSocket connection."""
-        if not self.server_terminated:
-            if not self.stream.closed():
-                self.stream.write("\xff\x00")
-            self.server_terminated = True
-        if self.client_terminated:
-            if self._waiting is not None:
-                self.stream.io_loop.remove_timeout(self._waiting)
-                self._waiting = None
-            self.stream.close()
-        elif self._waiting is None:
-            self._waiting = self.stream.io_loop.add_timeout(
-                time.time() + 5, self._abort)
-
-
 class WebSocketProtocol13(WebSocketProtocol):
     """Implementation of the WebSocket protocol from RFC 6455.