diff --git a/SickBeard.py b/SickBeard.py index 0a2d3f8c..ded3d2ff 100755 --- a/SickBeard.py +++ b/SickBeard.py @@ -76,6 +76,7 @@ from sickbeard.exceptions import ex from lib.configobj import ConfigObj throwaway = datetime.datetime.strptime('20110101', '%Y%m%d') +rollback_loaded = None signal.signal(signal.SIGINT, sickbeard.sig_handler) signal.signal(signal.SIGTERM, sickbeard.sig_handler) @@ -153,6 +154,19 @@ class SickGear(object): return '\n'.join(help_msg) + @staticmethod + def execute_rollback(mo, max_v): + global rollback_loaded + try: + if None is rollback_loaded: + rollback_loaded = db.get_rollback_module() + if None is not rollback_loaded: + rollback_loaded.__dict__[mo]().run(max_v) + else: + print(u'ERROR: Could not download Rollback Module.') + except (StandardError, Exception): + pass + def start(self): # do some preliminary stuff sickbeard.MY_FULLNAME = os.path.normpath(os.path.abspath(__file__)) @@ -324,14 +338,28 @@ class SickGear(object): print('Stack Size %s not set: %s' % (stack_size, e.message)) # check all db versions - for d, min_v, max_v, mo in [ - ('failed.db', sickbeard.failed_db.MIN_DB_VERSION, sickbeard.failed_db.MAX_DB_VERSION, 'FailedDb'), - ('cache.db', sickbeard.cache_db.MIN_DB_VERSION, sickbeard.cache_db.MAX_DB_VERSION, 'CacheDb'), - ('sickbeard.db', sickbeard.mainDB.MIN_DB_VERSION, sickbeard.mainDB.MAX_DB_VERSION, 'MainDb') + for d, min_v, max_v, base_v, mo in [ + ('failed.db', sickbeard.failed_db.MIN_DB_VERSION, sickbeard.failed_db.MAX_DB_VERSION, sickbeard.failed_db.TEST_BASE_VERSION, 'FailedDb'), + ('cache.db', sickbeard.cache_db.MIN_DB_VERSION, sickbeard.cache_db.MAX_DB_VERSION, sickbeard.cache_db.TEST_BASE_VERSION, 'CacheDb'), + ('sickbeard.db', sickbeard.mainDB.MIN_DB_VERSION, sickbeard.mainDB.MAX_DB_VERSION, sickbeard.mainDB.TEST_BASE_VERSION, 'MainDb') ]: cur_db_version = db.DBConnection(d).checkDBVersion() - if cur_db_version > 0: + # handling of standalone TEST db versions + if cur_db_version >= 100000 and cur_db_version != max_v: + print('Your [%s] database version (%s) is a test db version and doesn\'t match SickGear required ' + 'version (%s), downgrading to production db' % (d, cur_db_version, max_v)) + self.execute_rollback(mo, max_v) + cur_db_version = db.DBConnection(d).checkDBVersion() + if cur_db_version >= 100000: + print(u'Rollback to production failed.') + sys.exit(u'If you have used other forks, your database may be unusable due to their changes') + if 100000 <= max_v and None is not base_v: + max_v = base_v # set max_v to the needed base production db for test_db + print(u'Rollback to production of [%s] successful.' % d) + + # handling of production db versions + if 0 < cur_db_version < 100000: if cur_db_version < min_v: print(u'Your [%s] database version (%s) is too old to migrate from with this version of SickGear' % (d, cur_db_version)) @@ -341,19 +369,16 @@ class SickGear(object): print(u'Your [%s] database version (%s) has been incremented past' u' what this version of SickGear supports. Trying to rollback now. Please wait...' % (d, cur_db_version)) - try: - rollback_loaded = db.get_rollback_module() - if None is not rollback_loaded: - rollback_loaded.__dict__[mo]().run(max_v) - else: - print(u'ERROR: Could not download Rollback Module.') - except (StandardError, Exception): - pass + self.execute_rollback(mo, max_v) if db.DBConnection(d).checkDBVersion() > max_v: print(u'Rollback failed.') sys.exit(u'If you have used other forks, your database may be unusable due to their changes') print(u'Rollback of [%s] successful.' % d) + # free memory + global rollback_loaded + rollback_loaded = None + # Initialize the config and our threads sickbeard.initialize(console_logging=self.console_logging) diff --git a/sickbeard/databases/cache_db.py b/sickbeard/databases/cache_db.py index f9bd2863..de3dc252 100644 --- a/sickbeard/databases/cache_db.py +++ b/sickbeard/databases/cache_db.py @@ -22,6 +22,7 @@ import re MIN_DB_VERSION = 1 MAX_DB_VERSION = 4 +TEST_BASE_VERSION = None # the base production db version, only needed for TEST db versions (>=100000) # Add new migrations at the bottom of the list; subclass the previous migration. diff --git a/sickbeard/databases/failed_db.py b/sickbeard/databases/failed_db.py index 7d78abd0..c78d6650 100644 --- a/sickbeard/databases/failed_db.py +++ b/sickbeard/databases/failed_db.py @@ -21,6 +21,7 @@ from sickbeard.common import Quality MIN_DB_VERSION = 1 MAX_DB_VERSION = 1 +TEST_BASE_VERSION = None # the base production db version, only needed for TEST db versions (>=100000) # Add new migrations at the bottom of the list; subclass the previous migration. class InitialSchema(db.SchemaUpgrade): diff --git a/sickbeard/databases/mainDB.py b/sickbeard/databases/mainDB.py index 099b4c34..b49da757 100644 --- a/sickbeard/databases/mainDB.py +++ b/sickbeard/databases/mainDB.py @@ -28,6 +28,7 @@ from sickbeard.name_parser.parser import NameParser, InvalidNameException, Inval MIN_DB_VERSION = 9 # oldest db version we support migrating from MAX_DB_VERSION = 20008 +TEST_BASE_VERSION = None # the base production db version, only needed for TEST db versions (>=100000) class MainSanityCheck(db.DBSanityCheck): diff --git a/tests/db_tests.py b/tests/db_tests.py index d77898eb..a64c1ed1 100644 --- a/tests/db_tests.py +++ b/tests/db_tests.py @@ -20,6 +20,7 @@ from __future__ import print_function import unittest import test_lib as test +from sickbeard import cache_db, mainDB, failed_db class DBBasicTests(test.SickbeardTestDBCase): @@ -28,9 +29,16 @@ class DBBasicTests(test.SickbeardTestDBCase): super(DBBasicTests, self).setUp() self.db = test.db.DBConnection() + def is_testdb(self, version): + if isinstance(version, (int, long)): + return 100000 <= version + def test_select(self): self.db.select('SELECT * FROM tv_episodes WHERE showid = ? AND location != ""', [0000]) self.db.close() + self.assertEqual(cache_db.TEST_BASE_VERSION is not None, self.is_testdb(cache_db.MAX_DB_VERSION)) + self.assertEqual(mainDB.TEST_BASE_VERSION is not None, self.is_testdb(mainDB.MAX_DB_VERSION)) + self.assertEqual(failed_db.TEST_BASE_VERSION is not None, self.is_testdb(failed_db.MAX_DB_VERSION)) if __name__ == '__main__': print('==================')