# # This file is part of SickGear. # # SickGear is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # SickGear is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with SickGear. If not, see . from __future__ import with_statement import datetime import itertools import os.path import re import sqlite3 import threading import time from exceptions_helper import ex import sickgear from . import logger, sgdatetime from .sgdatetime import timestamp_near from sg_helpers import make_path, compress_file, remove_file_perm, scantree from _23 import filter_iter, filter_list, list_values, scandir from six import iterkeys, iteritems, itervalues # noinspection PyUnreachableCode if False: from typing import Any, AnyStr, Dict, List, Optional, Tuple, Union db_lock = threading.Lock() db_support_multiple_insert = (3, 7, 11) <= sqlite3.sqlite_version_info # type: bool db_support_column_rename = (3, 25, 0) <= sqlite3.sqlite_version_info # type: bool db_support_upsert = (3, 25, 0) <= sqlite3.sqlite_version_info # type: bool db_supports_backup = hasattr(sqlite3.Connection, 'backup') and (3, 6, 11) <= sqlite3.sqlite_version_info # type: bool def dbFilename(filename='sickbeard.db', suffix=None): # type: (AnyStr, Optional[AnyStr]) -> AnyStr """ @param filename: The sqlite database filename to use. If not specified, will be made to be sickbeard.db @param suffix: The suffix to append to the filename. A '.' will be added automatically, i.e. suffix='v0' will make dbfile.db.v0 @return: the correct location of the database file. """ if suffix: filename = '%s.%s' % (filename, suffix) return os.path.join(sickgear.DATA_DIR, filename) def mass_upsert_sql(table_name, value_dict, key_dict, sanitise=True): # type: (AnyStr, Dict, Dict, bool) -> List[List[AnyStr]] """ use with cl.extend(mass_upsert_sql(tableName, valueDict, keyDict)) :param table_name: table name :param value_dict: dict of values to be set {'table_fieldname': value} :param key_dict: dict of restrains for update {'table_fieldname': value} :param sanitise: True to remove k, v pairs in keyDict from valueDict as they must not exist in both. This option has a performance hit so it's best to remove key_dict keys from value_dict and set this False instead. :type sanitise: Boolean :return: list of 2 sql command """ cl = [] gen_params = (lambda my_dict: [x + ' = ?' for x in iterkeys(my_dict)]) # sanity: remove k, v pairs in keyDict from valueDict if sanitise: value_dict = dict(filter_iter(lambda k: k[0] not in key_dict, iteritems(value_dict))) # noinspection SqlResolve cl.append(['UPDATE [%s] SET %s WHERE %s' % (table_name, ', '.join(gen_params(value_dict)), ' AND '.join(gen_params(key_dict))), list_values(value_dict) + list_values(key_dict)]) # noinspection SqlResolve cl.append(['INSERT INTO [' + table_name + '] (' + ', '.join(["'%s'" % ('%s' % v).replace("'", "''") for v in itertools.chain(iterkeys(value_dict), iterkeys(key_dict))]) + ')' + ' SELECT ' + ', '.join(["'%s'" % ('%s' % v).replace("'", "''") for v in itertools.chain(itervalues(value_dict), itervalues(key_dict))]) + ' WHERE changes() = 0']) return cl class DBConnection(object): def __init__(self, filename='sickbeard.db', row_type=None, **kwargs): # type: (AnyStr, Optional[AnyStr], Dict) -> None from . import helpers self.new_db = False db_src = dbFilename(filename) if not os.path.isfile(db_src): db_alt = dbFilename('sickrage.db') if os.path.isfile(db_alt): helpers.copy_file(db_alt, db_src) self.filename = filename self.connection = sqlite3.connect(db_src, 20) if 'dict' == row_type: self.connection.row_factory = self._dict_factory else: self.connection.row_factory = sqlite3.Row def backup_db(self, target, backup_filename=None): # type: (AnyStr, AnyStr) -> Tuple[bool, AnyStr] """ backup the db to target dir + optional filename Availability: SQLite 3.6.11 or higher New in version 3.7 :param target: target dir :param backup_filename: optional backup filename (default is the source name) :return: success, message """ if not db_supports_backup: logger.log('this python sqlite3 version doesn\'t support backups', logger.DEBUG) return False, 'this python sqlite3 version doesn\'t support backups' if not os.path.isdir(target): logger.log('Backup target invalid', logger.ERROR) return False, 'Backup target invalid' target_db = os.path.join(target, (backup_filename, self.filename)[None is backup_filename]) if os.path.exists(target_db): logger.log('Backup target file already exists', logger.ERROR) return False, 'Backup target file already exists' def progress(status, remaining, total): logger.log('Copied %s of %s pages...' % (total - remaining, total), logger.DEBUG) backup_con = None try: # copy into this DB backup_con = sqlite3.connect(target_db, 20) with backup_con: with db_lock: self.connection.backup(backup_con, progress=progress) logger.log('%s backup successful' % self.filename, logger.DEBUG) except sqlite3.Error as error: logger.log("Error while taking backup: %s" % ex(error), logger.ERROR) return False, 'Backup failed' finally: if backup_con: try: backup_con.close() except (BaseException, Exception): pass return True, 'Backup successful' def checkDBVersion(self): # type: (...) -> int try: if self.hasTable('db_version'): result = self.select('SELECT db_version FROM db_version') else: version = self.select('PRAGMA user_version')[0]['user_version'] if version: self.action('PRAGMA user_version = 0') self.action('CREATE TABLE db_version (db_version INTEGER);') self.action('INSERT INTO db_version (db_version) VALUES (%s);' % version) return version except (BaseException, Exception): return 0 if result: version = int(result[0]['db_version']) if 10000 > version and self.hasColumn('db_version', 'db_minor_version'): # noinspection SqlResolve minor = self.select('SELECT db_minor_version FROM db_version') return version * 100 + int(minor[0]['db_minor_version']) return version return 0 def mass_action(self, queries, log_transaction=False): # type: (List[Union[List[AnyStr], Tuple[AnyStr, List], Tuple[AnyStr]]], bool) -> Optional[List, sqlite3.Cursor] from . import helpers with db_lock: if None is queries: return if not queries: return [] attempt = 0 sql_result = [] affected = 0 while 5 > attempt: try: cursor = self.connection.cursor() if not log_transaction: for cur_query in queries: sql_result.append(cursor.execute(*tuple(cur_query)).fetchall()) affected += abs(cursor.rowcount) else: for cur_query in queries: logger.log(cur_query[0] if 1 == len(cur_query) else '%s with args %s' % tuple(cur_query), logger.DB) sql_result.append(cursor.execute(*tuple(cur_query)).fetchall()) affected += abs(cursor.rowcount) self.connection.commit() if 0 < affected: logger.debug(u'Transaction with %s queries executed affected at least %i row%s' % ( len(queries), affected, helpers.maybe_plural(affected))) return sql_result except sqlite3.OperationalError as e: sql_result = [] if self.connection: self.connection.rollback() if not self.action_error(e): raise attempt += 1 except sqlite3.DatabaseError as e: if self.connection: self.connection.rollback() logger.error(u'Fatal error executing query: ' + ex(e)) raise return sql_result @staticmethod def action_error(e): if 'unable to open database file' in e.args[0] or 'database is locked' in e.args[0]: logger.log(u'DB error: ' + ex(e), logger.WARNING) time.sleep(1) return True logger.log(u'DB error: ' + ex(e), logger.ERROR) def action(self, query, args=None): # type: (AnyStr, Optional[List, Tuple]) -> Optional[Union[List, sqlite3.Cursor]] with db_lock: if None is query: return sql_result = None attempt = 0 while 5 > attempt: try: if None is args: logger.log('%s: %s' % (self.filename, query), logger.DB) sql_result = self.connection.execute(query) else: logger.log('%s: %s with args %s' % (self.filename, query, str(args)), logger.DB) sql_result = self.connection.execute(query, args) self.connection.commit() # get out of the connection attempt loop since we were successful break except sqlite3.OperationalError as e: if not self.action_error(e): raise attempt += 1 except sqlite3.DatabaseError as e: logger.log(u'Fatal error executing query: ' + ex(e), logger.ERROR) raise return sql_result def select(self, query, args=None): # type: (AnyStr, Optional[List, Tuple]) -> List sql_results = self.action(query, args).fetchall() if None is sql_results: return [] return sql_results def upsert(self, table_name, value_dict, key_dict): # type: (AnyStr, Dict, Dict) -> None changes_before = self.connection.total_changes gen_params = (lambda my_dict: [x + ' = ?' for x in iterkeys(my_dict)]) # noinspection SqlResolve query = 'UPDATE [%s] SET %s WHERE %s' % ( table_name, ', '.join(gen_params(value_dict)), ' AND '.join(gen_params(key_dict))) self.action(query, list_values(value_dict) + list_values(key_dict)) if self.connection.total_changes == changes_before: # noinspection SqlResolve query = 'INSERT INTO [' + table_name + ']' \ + ' (%s)' % ', '.join(itertools.chain(iterkeys(value_dict), iterkeys(key_dict))) \ + ' VALUES (%s)' % ', '.join(['?'] * (len(value_dict) + len(key_dict))) self.action(query, list_values(value_dict) + list_values(key_dict)) def tableInfo(self, table_name): # type: (AnyStr) -> Dict[AnyStr, Dict[AnyStr, AnyStr]] # FIXME ? binding is not supported here, but I cannot find a way to escape a string manually sql_result = self.select('PRAGMA table_info([%s])' % table_name) columns = {} for cur_column in sql_result: columns[cur_column['name']] = {'type': cur_column['type']} return columns # http://stackoverflow.com/questions/3300464/how-can-i-get-dict-from-sqlite-query @staticmethod def _dict_factory(cursor, row): d = {} for idx, col in enumerate(cursor.description): d[col[0]] = row[idx] return d def hasTable(self, table_name): # type: (AnyStr) -> bool return 0 < len(self.select('SELECT 1 FROM sqlite_master WHERE name = ?;', (table_name,))) def hasColumn(self, table_name, column): # type: (AnyStr, AnyStr) -> bool return column in self.tableInfo(table_name) def hasIndex(self, table_name, index): # type: (AnyStr, AnyStr) -> bool sqlResults = self.select('PRAGMA index_list([%s])' % table_name) for result in sqlResults: if result['name'] == index: return True return False def removeIndex(self, table, name): # type: (AnyStr, AnyStr) -> None if self.hasIndex(table, name): self.action('DROP INDEX' + ' [%s]' % name) def removeTable(self, name): # type: (AnyStr) -> None if self.hasTable(name): self.action('DROP TABLE' + ' [%s]' % name) # noinspection SqlResolve def addColumn(self, table, column, data_type='NUMERIC', default=0): # type: (AnyStr, AnyStr, AnyStr, Any) -> None self.action('ALTER TABLE [%s] ADD %s %s' % (table, column, data_type)) self.action('UPDATE [%s] SET %s = ?' % (table, column), (default,)) def has_flag(self, flag_name): # type: (AnyStr) -> bool sql_result = self.select('SELECT flag FROM flags WHERE flag = ?', [flag_name]) return 0 < len(sql_result) def add_flag(self, flag_name): # type: (AnyStr) -> bool has_flag = self.has_flag(flag_name) if not has_flag: self.action('INSERT INTO flags (flag) VALUES (?)', [flag_name]) return not has_flag def remove_flag(self, flag_name): # type: (AnyStr) -> bool has_flag = self.has_flag(flag_name) if has_flag: self.action('DELETE FROM flags WHERE flag = ?', [flag_name]) return has_flag def toggle_flag(self, flag_name): # type: (AnyStr) -> bool """ Add or remove a flag :param flag_name: Name of flag :return: True if this call added the flag, False if flag is removed """ if self.remove_flag(flag_name): return False self.add_flag(flag_name) return True def set_flag(self, flag_name, state=True): # type: (AnyStr, bool) -> bool """ Set state of flag :param flag_name: Name of flag :param state: If true, create flag otherwise remove flag :return: Previous state of flag """ return (self.add_flag, self.remove_flag)[not bool(state)](flag_name) def close(self): """Close database connection""" if None is not getattr(self, 'connection', None): self.connection.close() self.connection = None def upgrade_log(self, to_log, log_level=logger.MESSAGE): # type: (AnyStr, int) -> None logger.load_log('Upgrading %s' % self.filename, to_log, log_level) def sanityCheckDatabase(connection, sanity_check): sanity_check(connection).check() class DBSanityCheck(object): def __init__(self, connection): self.connection = connection def check(self): pass def upgradeDatabase(connection, schema): logger.log(u'Checking database structure...', logger.MESSAGE) connection.is_upgrading = False connection.new_db = 0 == connection.checkDBVersion() _processUpgrade(connection, schema) if connection.is_upgrading: connection.upgrade_log('Finished') def prettyName(class_name): # type: (AnyStr) -> AnyStr return ' '.join([x.group() for x in re.finditer('([A-Z])([a-z0-9]+)', class_name)]) def restoreDatabase(filename, version): logger.log(u'Restoring database before trying upgrade again') if not sickgear.helpers.restore_versioned_file(dbFilename(filename=filename, suffix='v%s' % version), version): logger.log_error_and_exit(u'Database restore failed, abort upgrading database') return False return True def _processUpgrade(connection, upgrade_class): instance = upgrade_class(connection) logger.log('Checking %s database upgrade' % prettyName(upgrade_class.__name__), logger.DEBUG) if not instance.test(): connection.is_upgrading = True connection.upgrade_log(getattr(upgrade_class, 'pretty_name', None) or prettyName(upgrade_class.__name__)) logger.log('Database upgrade required: %s' % prettyName(upgrade_class.__name__), logger.MESSAGE) db_version = connection.checkDBVersion() try: # only do backup if it's not a new db 0 < db_version and backup_database(connection, connection.filename, db_version) instance.execute() cleanup_old_db_backups(connection.filename) except (BaseException, Exception): # attempting to restore previous DB backup and perform upgrade if db_version: # close db before attempting restore connection.close() if restoreDatabase(connection.filename, db_version): logger.log_error_and_exit('Successfully restored database version: %s' % db_version) else: logger.log_error_and_exit('Failed to restore database version: %s' % db_version) else: logger.log_error_and_exit('Database upgrade failed, can\'t determine old db version, not restoring.') logger.log('%s upgrade completed' % upgrade_class.__name__, logger.DEBUG) else: logger.log('%s upgrade not required' % upgrade_class.__name__, logger.DEBUG) for upgradeSubClass in upgrade_class.__subclasses__(): _processUpgrade(connection, upgradeSubClass) # Base migration class. All future DB changes should be subclassed from this class class SchemaUpgrade(object): def __init__(self, connection, **kwargs): self.connection = connection def hasTable(self, table_name): return 0 < len(self.connection.select('SELECT 1 FROM sqlite_master WHERE name = ?;', (table_name,))) def hasColumn(self, table_name, column): return column in self.connection.tableInfo(table_name) def list_tables(self): # type: (...) -> List[AnyStr] """ returns list of all table names in db """ return [s['name'] for s in self.connection.select('SELECT name FROM main.sqlite_master WHERE type = ?;', ['table'])] def list_indexes(self): # type: (...) -> List[AnyStr] """ returns list of all index names in db """ return [s['name'] for s in self.connection.select('SELECT name FROM main.sqlite_master WHERE type = ?;', ['index'])] # noinspection SqlResolve def addColumn(self, table, column, data_type='NUMERIC', default=0, set_default=False): self.connection.action('ALTER TABLE [%s] ADD %s %s%s' % (table, column, data_type, ('', ' DEFAULT "%s"' % default)[set_default])) self.connection.action('UPDATE [%s] SET %s = ?' % (table, column), (default,)) # noinspection SqlResolve def addColumns(self, table, column_list=None): # type: (AnyStr, List) -> None if isinstance(column_list, list): sql = [] for col in column_list: is_list = isinstance(col, (list, tuple)) list_len = 0 if not is_list else len(col) column = col if not is_list else col[0] data_type = 'NUMERIC' if not is_list or 2 > list_len else col[1] default = 0 if not is_list or 3 > list_len else col[2] sql.append(['ALTER TABLE [%s] ADD %s %s%s' % (table, column, data_type, '' if list_len < 3 else ' DEFAULT %s' % ('""' if 'TEXT' == data_type and '' == default else default))]) if 2 < list_len: sql.append(['UPDATE [%s] SET %s = ?' % (table, column), (default,)]) if sql: self.connection.mass_action(sql) def dropColumn(self, table, columns): # type: (AnyStr, AnyStr) -> None self.drop_columns(table, columns) def drop_columns(self, table, column): # type: (AnyStr, Union[AnyStr, List[AnyStr]]) -> None # get old table columns and store the ones we want to keep result = self.connection.select('pragma table_info([%s])' % table) columns_list = ([column], column)[isinstance(column, list)] keptColumns = filter_list(lambda col: col['name'] not in columns_list, result) keptColumnsNames = [] final = [] pk = [] # copy the old table schema, column by column for column in keptColumns: keptColumnsNames.append(column['name']) cl = [column['name'], column['type']] ''' To be implemented if ever required if column['dflt_value']: cl.append(str(column['dflt_value'])) if column['notnull']: cl.append(column['notnull']) ''' if 0 != int(column['pk']): pk.append(column['name']) b = ' '.join(cl) final.append(b) # join all the table column creation fields final = ', '.join(final) keptColumnsNames = ', '.join(keptColumnsNames) # generate sql for the new table creation if 0 == len(pk): sql = 'CREATE TABLE [%s_new] (%s)' % (table, final) else: pk = ', '.join(pk) sql = 'CREATE TABLE [%s_new] (%s, PRIMARY KEY(%s))' % (table, final, pk) # create new temporary table and copy the old table data across, barring the removed column self.connection.action(sql) # noinspection SqlResolve self.connection.action('INSERT INTO [%s_new] SELECT %s FROM [%s]' % (table, keptColumnsNames, table)) # copy the old indexes from the old table result = self.connection.select("SELECT sql FROM sqlite_master WHERE tbl_name=? AND type='index'", [table]) # remove the old table and rename the new table to take it's place # noinspection SqlResolve self.connection.action('DROP TABLE [%s]' % table) # noinspection SqlResolve self.connection.action('ALTER TABLE [%s_new] RENAME TO [%s]' % (table, table)) # write any indexes to the new table if 0 < len(result): for index in result: self.connection.action(index['sql']) # vacuum the db as we will have a lot of space to reclaim after dropping tables self.connection.action('VACUUM') def checkDBVersion(self): return self.connection.checkDBVersion() def incDBVersion(self): new_version = self.checkDBVersion() + 1 # noinspection SqlConstantCondition self.connection.action('UPDATE db_version SET db_version = ? WHERE 1=1', [new_version]) return new_version def setDBVersion(self, new_version, check_db_version=True): # noinspection SqlConstantCondition self.connection.action('UPDATE db_version SET db_version = ? WHERE 1=1', [new_version]) return check_db_version and self.checkDBVersion() def listTables(self): return self.list_tables() def do_query(self, queries): if not isinstance(queries, list): queries = list(queries) elif isinstance(queries[0], list): queries = [item for sublist in queries for item in sublist] for query in queries: tbl_name = re.findall(r'(?i)DROP.*?TABLE.*?\[?([^\s\]]+)', query) if tbl_name and not self.hasTable(tbl_name[0]): continue tbl_name = re.findall(r'(?i)CREATE.*?TABLE.*?\s([^\s(]+)\s*\(', query) if tbl_name and self.hasTable(tbl_name[0]): continue self.connection.action(query) def finish(self, tbl_dropped=False): if tbl_dropped: self.connection.action('VACUUM') self.incDBVersion() def upgrade_log(self, *args, **kwargs): self.connection.upgrade_log(*args, **kwargs) def MigrationCode(my_db): schema = { 0: sickgear.mainDB.InitialSchema, 9: sickgear.mainDB.AddSizeAndSceneNameFields, 10: sickgear.mainDB.RenameSeasonFolders, 11: sickgear.mainDB.Add1080pAndRawHDQualities, 12: sickgear.mainDB.AddShowidTvdbidIndex, 13: sickgear.mainDB.AddLastUpdateTVDB, 14: sickgear.mainDB.AddDBIncreaseTo15, 15: sickgear.mainDB.AddIMDbInfo, 16: sickgear.mainDB.AddProperNamingSupport, 17: sickgear.mainDB.AddEmailSubscriptionTable, 18: sickgear.mainDB.AddProperSearch, 19: sickgear.mainDB.AddDvdOrderOption, 20: sickgear.mainDB.AddSubtitlesSupport, 21: sickgear.mainDB.ConvertTVShowsToIndexerScheme, 22: sickgear.mainDB.ConvertTVEpisodesToIndexerScheme, 23: sickgear.mainDB.ConvertIMDBInfoToIndexerScheme, 24: sickgear.mainDB.ConvertInfoToIndexerScheme, 25: sickgear.mainDB.AddArchiveFirstMatchOption, 26: sickgear.mainDB.AddSceneNumbering, 27: sickgear.mainDB.ConvertIndexerToInteger, 28: sickgear.mainDB.AddRequireAndIgnoreWords, 29: sickgear.mainDB.AddSportsOption, 30: sickgear.mainDB.AddSceneNumberingToTvEpisodes, 31: sickgear.mainDB.AddAnimeTVShow, 32: sickgear.mainDB.AddAbsoluteNumbering, 33: sickgear.mainDB.AddSceneAbsoluteNumbering, 34: sickgear.mainDB.AddAnimeAllowlistBlocklist, 35: sickgear.mainDB.AddSceneAbsoluteNumbering2, 36: sickgear.mainDB.AddXemRefresh, 37: sickgear.mainDB.AddSceneToTvShows, 38: sickgear.mainDB.AddIndexerMapping, 39: sickgear.mainDB.AddVersionToTvEpisodes, 40: sickgear.mainDB.BumpDatabaseVersion, 41: sickgear.mainDB.Migrate41, 42: sickgear.mainDB.Migrate41, 43: sickgear.mainDB.Migrate43, 44: sickgear.mainDB.Migrate43, 4301: sickgear.mainDB.Migrate4301, 4302: sickgear.mainDB.Migrate4302, 4400: sickgear.mainDB.Migrate4302, 5816: sickgear.mainDB.MigrateUpstream, 5817: sickgear.mainDB.MigrateUpstream, 5818: sickgear.mainDB.MigrateUpstream, 10000: sickgear.mainDB.SickGearDatabaseVersion, 10001: sickgear.mainDB.RemoveDefaultEpStatusFromTvShows, 10002: sickgear.mainDB.RemoveMinorDBVersion, 10003: sickgear.mainDB.RemoveMetadataSub, 20000: sickgear.mainDB.DBIncreaseTo20001, 20001: sickgear.mainDB.AddTvShowOverview, 20002: sickgear.mainDB.AddTvShowTags, 20003: sickgear.mainDB.ChangeMapIndexer, 20004: sickgear.mainDB.AddShowNotFoundCounter, 20005: sickgear.mainDB.AddFlagTable, 20006: sickgear.mainDB.DBIncreaseTo20007, 20007: sickgear.mainDB.AddWebdlTypesTable, 20008: sickgear.mainDB.AddWatched, 20009: sickgear.mainDB.AddPrune, 20010: sickgear.mainDB.AddIndexerToTables, 20011: sickgear.mainDB.AddShowExludeGlobals, 20012: sickgear.mainDB.RenameAllowBlockListTables, 20013: sickgear.mainDB.AddHistoryHideColumn, 20014: sickgear.mainDB.ChangeShowData, 20015: sickgear.mainDB.ChangeTmdbID, # 20002: sickgear.mainDB.AddCoolSickGearFeature3, } db_version = my_db.checkDBVersion() my_db.new_db = 0 == db_version logger.log(u'Detected database version: v%s' % db_version, logger.DEBUG) if not (db_version in schema): if db_version == sickgear.mainDB.MAX_DB_VERSION: logger.log(u'Database schema is up-to-date, no upgrade required') elif 10000 > db_version: logger.log_error_and_exit(u'SickGear does not currently support upgrading from this database version') else: logger.log_error_and_exit(u'Invalid database version') else: my_db.upgrade_log('Upgrading') while db_version < sickgear.mainDB.MAX_DB_VERSION: if None is schema[db_version]: # skip placeholders used when multi PRs are updating DB db_version += 1 continue try: update = schema[db_version](my_db) db_version = update.execute() cleanup_old_db_backups(my_db.filename) except (BaseException, Exception) as e: my_db.close() logger.log(u'Failed to update database with error: %s attempting recovery...' % ex(e), logger.ERROR) if restoreDatabase(my_db.filename, db_version): # initialize the main SB database logger.log_error_and_exit(u'Successfully restored database version: %s' % db_version) else: logger.log_error_and_exit(u'Failed to restore database version: %s' % db_version) my_db.upgrade_log('Finished') def cleanup_old_db_backups(filename): try: d, filename = os.path.split(filename) if not d: d = sickgear.DATA_DIR for f in filter_iter(lambda fn: fn.is_file() and filename in fn.name and re.search(r'\.db(\.v\d+)?\.r\d+$', fn.name), scandir(d)): try: os.unlink(f.path) except (BaseException, Exception): pass except (BaseException, Exception): pass def backup_database(db_connection, filename, version): if db_connection.new_db: logger.debug('new db, no backup required') return logger.log(u'Backing up database before upgrade') if not sickgear.helpers.backup_versioned_file(dbFilename(filename), version): logger.log_error_and_exit(u'Database backup failed, abort upgrading database') else: logger.log(u'Proceeding with upgrade') def get_rollback_module(): import types from . import helpers module_urls = [ 'https://raw.githubusercontent.com/SickGear/sickgear.extdata/main/SickGear/Rollback/rollback_sg.py'] try: hdr = '# SickGear Rollback Module' module = '' fetched = False for t in range(1, 4): for url in module_urls: try: module = helpers.get_url(url) if module and module.startswith(hdr): fetched = True break except (BaseException, Exception): continue if fetched: break time.sleep(30) if fetched: loaded = types.ModuleType('DbRollback') exec(module, loaded.__dict__) return loaded except (BaseException, Exception): pass def delete_old_db_backups(target): # type: (AnyStr) -> None """ remove old db backups (> BACKUP_DB_MAX_COUNT) :param target: backup folder to check """ use_count = (1, sickgear.BACKUP_DB_MAX_COUNT)[not sickgear.BACKUP_DB_ONEDAY] for include in ['sickbeard', 'cache', 'failed']: file_list = [f for f in scantree(target, include=include, filter_kind=False)] if use_count < len(file_list): file_list.sort(key=lambda _f: _f.stat(follow_symlinks=False).st_mtime, reverse=True) for direntry in file_list[use_count:]: remove_file_perm(direntry.path) def backup_all_dbs(target, compress=True, prefer_7z=True): # type: (AnyStr, bool, bool) -> Tuple[bool, AnyStr] """ backups all dbs to specified dir optional compress with zip or 7z (python 3 only, external lib py7zr required) 7z falls back to zip if py7zr is not available :param target: target folder to backup to :param compress: compress db backups :param prefer_7z: prefer 7z compression if available :return: success, message """ if not make_path(target): logger.log('Failed to create db backup dir', logger.ERROR) return False, 'Failed to create db backup dir' my_db = DBConnection('cache.db') last_backup = my_db.select('SELECT time FROM lastUpdate WHERE provider = ?', ['sickgear_db_backup']) if last_backup: now_stamp = int(timestamp_near(datetime.datetime.now())) the_time = int(last_backup[0]['time']) # only backup every 23 hours if now_stamp - the_time < 60 * 60 * 23: return False, 'Too early to backup db again' now = sgdatetime.SGDatetime.now() d = sgdatetime.SGDatetime.sbfdate(now, d_preset='%Y-%m-%d') t = sgdatetime.SGDatetime.sbftime(now, t_preset='%H-%M') ds = '%s_%s' % (d, t) for cur_db in ['sickbeard', 'cache', 'failed']: db_conn = DBConnection('%s.db' % cur_db) name = '%s_%s.db' % (cur_db, ds) success, msg = db_conn.backup_db(target=target, backup_filename=name) if not success: return False, msg if compress: full_path = os.path.join(target, name) if not compress_file(full_path, '%s.db' % cur_db, prefer_7z=prefer_7z): return False, 'Failure to compress backup' delete_old_db_backups(target) my_db.upsert('lastUpdate', {'time': int(time.mktime(now.timetuple()))}, {'provider': 'sickgear_db_backup'}) logger.log('successfully backed up all dbs') return True, 'successfully backed up all dbs'