SickGear/sickgear/db.py

866 lines
33 KiB
Python

#
# This file is part of SickGear.
#
# SickGear is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# SickGear is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with SickGear. If not, see <http://www.gnu.org/licenses/>.
import datetime
import itertools
import os.path
import re
import sqlite3
import threading
import time
from exceptions_helper import ex
import sickgear
from . import logger, sgdatetime
from .sgdatetime import SGDatetime
from sg_helpers import make_path, compress_file, remove_file_perm, scantree
from _23 import scandir
from six import iterkeys, iteritems, itervalues
# noinspection PyUnreachableCode
if False:
    # never executed: only makes the names used in the `# type:` comments resolvable for IDEs
    # noinspection PyUnresolvedReferences
    from typing import Any, AnyStr, Dict, List, Optional, Tuple, Union

# held by DBConnection while it executes statements or takes a backup
db_lock = threading.Lock()

# capabilities of the sqlite library this interpreter is linked against
db_support_multiple_insert = sqlite3.sqlite_version_info >= (3, 7, 11)  # type: bool
db_support_column_rename = sqlite3.sqlite_version_info >= (3, 25, 0)  # type: bool
# NOTE(review): SQLite documents UPSERT as available from 3.24.0, the 3.25.0 floor is stricter - confirm intent
db_support_upsert = sqlite3.sqlite_version_info >= (3, 25, 0)  # type: bool
# Connection.backup needs both the python binding and SQLite 3.6.11 or higher
db_supports_backup = hasattr(sqlite3.Connection, 'backup') and sqlite3.sqlite_version_info >= (3, 6, 11)  # type: bool
def db_filename(filename='sickbeard.db', suffix=None):
    # type: (AnyStr, Optional[AnyStr]) -> AnyStr
    """
    Build the full path of a database file inside the SickGear data dir.

    @param filename: The sqlite database filename to use, sickbeard.db when not given
    @param suffix: Optional suffix that is appended after a '.',
    i.e. suffix='v0' turns dbfile.db into dbfile.db.v0
    @return: sickgear.DATA_DIR joined with the (optionally suffixed) filename
    """
    name = filename if not suffix else '%s.%s' % (filename, suffix)
    return os.path.join(sickgear.DATA_DIR, name)
def mass_upsert_sql(table_name, value_dict, key_dict, sanitise=True):
    # type: (AnyStr, Dict, Dict, bool) -> List[List]
    """
    Build an UPDATE followed by a conditional INSERT that together act as an upsert.

    use with cl.extend(mass_upsert_sql(tableName, valueDict, keyDict))

    The INSERT only adds a row when the preceding UPDATE changed nothing (`changes() = 0`),
    so both statements must be run in this order on the same connection.

    :param table_name: table name
    :param value_dict: dict of values to be set {'table_fieldname': value}
    :param key_dict: dict of restrains for update {'table_fieldname': value}
    :param sanitise: True to remove k, v pairs in keyDict from valueDict as they must not exist in both.
    This option has a performance hit, so it's best to remove key_dict keys from value_dict and set this False instead.
    :type sanitise: Boolean
    :return: list of 2 sql commands, each a list of [sql, args]
    """
    # sanity: remove k, v pairs in keyDict from valueDict
    if sanitise:
        value_dict = {k: v for k, v in value_dict.items() if k not in key_dict}

    columns = list(value_dict) + list(key_dict)
    values = list(value_dict.values()) + list(key_dict.values())

    # noinspection SqlResolve
    update_sql = 'UPDATE [%s] SET %s WHERE %s' % (
        table_name,
        ', '.join('%s = ?' % k for k in value_dict),
        ' AND '.join('%s = ?' % k for k in key_dict))

    # values are bound as parameters (same as the UPDATE) so that None is stored as NULL and
    # numbers keep their type, instead of everything being embedded as a quoted text literal
    # noinspection SqlResolve
    insert_sql = 'INSERT INTO [%s] (%s) SELECT %s WHERE changes() = 0' % (
        table_name, ', '.join('%s' % c for c in columns), ', '.join(['?'] * len(columns)))

    return [[update_sql, values], [insert_sql, list(values)]]
class DBConnection(object):
    """
    Wrapper around a sqlite3 connection to one database file in the SickGear data dir.

    Statements run through action() and mass_action() are executed while holding the
    module level `db_lock` and are retried up to 5 times when sqlite reports that the
    database is locked or cannot be opened.
    """

    def __init__(self, filename='sickbeard.db', row_type=None, **kwargs):
        # type: (AnyStr, Optional[AnyStr], Dict) -> None
        """
        :param filename: database filename, resolved to a path with db_filename()
        :param row_type: 'dict' to get rows as plain dicts, anything else gives sqlite3.Row rows
        :param kwargs: accepted but not used here
        """
        from . import helpers

        self.new_db = False
        db_src = db_filename(filename)
        if not os.path.isfile(db_src):
            # no db file yet, seed it from a sickrage.db if one exists in the data dir
            db_alt = db_filename('sickrage.db')
            if os.path.isfile(db_alt):
                helpers.copy_file(db_alt, db_src)

        self.filename = filename
        self.connection = sqlite3.connect(db_src, timeout=20)
        self.connection.row_factory = self._dict_factory if 'dict' == row_type else sqlite3.Row

    def backup_db(self, target, backup_filename=None):
        # type: (AnyStr, AnyStr) -> Tuple[bool, AnyStr]
        """
        backup the db to target dir + optional filename

        Availability: SQLite 3.6.11 or higher
        New in version 3.7

        :param target: target dir
        :param backup_filename: optional backup filename (default is the source name)
        :return: success, message
        """
        if not db_supports_backup:
            logger.debug('this python sqlite3 version doesn\'t support backups')
            return False, 'this python sqlite3 version doesn\'t support backups'

        if not os.path.isdir(target):
            logger.error('Backup target invalid')
            return False, 'Backup target invalid'

        target_db = os.path.join(target, self.filename if None is backup_filename else backup_filename)
        if os.path.exists(target_db):
            # never overwrite an existing backup
            logger.error('Backup target file already exists')
            return False, 'Backup target file already exists'

        # noinspection PyUnusedLocal
        def progress(status, remaining, total):
            logger.debug('Copied %s of %s pages...' % (total - remaining, total))

        backup_con = None
        try:
            # copy into this DB
            backup_con = sqlite3.connect(target_db, timeout=20)
            with backup_con:
                with db_lock:
                    self.connection.backup(backup_con, progress=progress)
            logger.debug('%s backup successful' % self.filename)
        except sqlite3.Error as error:
            logger.error("Error while taking backup: %s" % ex(error))
            return False, 'Backup failed'
        finally:
            if backup_con:
                try:
                    backup_con.close()
                except (BaseException, Exception):
                    pass

        return True, 'Backup successful'

    def check_db_version(self):
        # type: (...) -> int
        """
        Read the schema version of this database.

        A db without a db_version table but with a non-zero `PRAGMA user_version` is converted:
        the pragma is reset to 0 and its value is stored in a newly created db_version table.

        :return: version number, or 0 if there is none or reading it failed
        """
        try:
            if self.has_table('db_version'):
                result = self.select('SELECT db_version FROM db_version')
            else:
                version = self.select('PRAGMA user_version')[0]['user_version']
                if version:
                    self.action('PRAGMA user_version = 0')
                    self.action('CREATE TABLE db_version (db_version INTEGER);')
                    # bind the value instead of formatting it into the statement
                    self.action('INSERT INTO db_version (db_version) VALUES (?);', [version])
                return version
        except (BaseException, Exception):
            return 0

        if result:
            version = int(result[0]['db_version'])
            if 10000 > version and self.has_column('db_version', 'db_minor_version'):
                # versions below 10000 may carry a separate minor version column
                # noinspection SqlResolve
                minor = self.select('SELECT db_minor_version FROM db_version')
                return version * 100 + int(minor[0]['db_minor_version'])
            return version
        return 0

    def mass_action(self, queries, log_transaction=False):
        # type: (List[Union[List[AnyStr], Tuple[AnyStr, List], Tuple[AnyStr]]], bool) -> Optional[List]
        """
        Execute several statements in one transaction.

        :param queries: list of [sql] or [sql, args] items
        :param log_transaction: True to log every statement at DB level
        :return: list with the fetched rows of every statement, [] for an empty queries list
        or when all attempts failed, None if queries is None
        """
        from . import helpers

        with db_lock:
            if None is queries:
                return
            if not queries:
                return []

            attempt = 0
            sql_result = []
            affected = 0
            while 5 > attempt:
                try:
                    cursor = self.connection.cursor()
                    for cur_query in queries:
                        if log_transaction:
                            logger.log(cur_query[0] if 1 == len(cur_query)
                                       else '%s with args %s' % tuple(cur_query), logger.DB)
                        sql_result.append(cursor.execute(*tuple(cur_query)).fetchall())
                        affected += abs(cursor.rowcount)

                    self.connection.commit()
                    if 0 < affected:
                        logger.debug(f'Transaction with {len(queries)} queries executed affected at least {affected:d}'
                                     f' row{helpers.maybe_plural(affected)}')
                    return sql_result
                except sqlite3.OperationalError as e:
                    # the transaction is rolled back, so results and the row count of this attempt are void
                    sql_result = []
                    affected = 0
                    if self.connection:
                        self.connection.rollback()
                    if not self.action_error(e):
                        raise
                    attempt += 1
                except sqlite3.DatabaseError as e:
                    if self.connection:
                        self.connection.rollback()
                    logger.error(f'Fatal error executing query: {ex(e)}')
                    raise

            return sql_result

    @staticmethod
    def action_error(e):
        """
        Log an operational db error and decide if the statement may be retried.

        :param e: the caught sqlite3 error
        :return: True (after sleeping 1 second) for a locked or unopenable database, otherwise False
        """
        if 'unable to open database file' in e.args[0] or 'database is locked' in e.args[0]:
            logger.warning(f'DB error: {ex(e)}')
            time.sleep(1)
            return True
        logger.error(f'DB error: {ex(e)}')
        return False

    def action(self, query, args=None):
        # type: (AnyStr, Optional[Union[List, Tuple]]) -> Optional[sqlite3.Cursor]
        """
        Execute a single statement and commit.

        :param query: sql statement
        :param args: optional values to bind to the statement
        :return: the cursor of the executed statement, None if query is None or all attempts failed
        """
        with db_lock:
            if None is query:
                return

            sql_result = None
            attempt = 0
            while 5 > attempt:
                try:
                    if None is args:
                        logger.log('%s: %s' % (self.filename, query), logger.DB)
                        sql_result = self.connection.execute(query)
                    else:
                        logger.log('%s: %s with args %s' % (self.filename, query, str(args)), logger.DB)
                        sql_result = self.connection.execute(query, args)
                    self.connection.commit()
                    # get out of the connection attempt loop since we were successful
                    break
                except sqlite3.OperationalError as e:
                    if not self.action_error(e):
                        raise
                    attempt += 1
                except sqlite3.DatabaseError as e:
                    logger.error(f'Fatal error executing query: {ex(e)}')
                    raise

            return sql_result

    def select(self, query, args=None):
        # type: (AnyStr, Optional[Union[List, Tuple]]) -> List
        """
        Execute a statement and fetch all of its rows.

        :param query: sql statement
        :param args: optional values to bind to the statement
        :return: list of rows, [] when action() gave no cursor
        """
        cursor = self.action(query, args)
        if None is cursor:
            # action() returns None for a None query or when every attempt failed
            return []
        return cursor.fetchall()

    def upsert(self, table_name, value_dict, key_dict):
        # type: (AnyStr, Dict, Dict) -> None
        """
        Update the rows matching key_dict, or insert a new row if the update changed nothing.

        :param table_name: table name
        :param value_dict: dict of values to be set {'table_fieldname': value}
        :param key_dict: dict of restrains for update {'table_fieldname': value}
        """
        changes_before = self.connection.total_changes
        bound_values = list(value_dict.values()) + list(key_dict.values())

        # noinspection SqlResolve
        query = 'UPDATE [%s] SET %s WHERE %s' % (
            table_name,
            ', '.join(['%s = ?' % k for k in value_dict]),
            ' AND '.join(['%s = ?' % k for k in key_dict]))
        self.action(query, bound_values)

        if self.connection.total_changes == changes_before:
            # nothing was updated, so the row does not exist yet
            # noinspection SqlResolve
            query = 'INSERT INTO [' + table_name + ']' \
                    + ' (%s)' % ', '.join(list(value_dict) + list(key_dict)) \
                    + ' VALUES (%s)' % ', '.join(['?'] * (len(value_dict) + len(key_dict)))
            self.action(query, bound_values)

    def table_info(self, table_name):
        # type: (AnyStr) -> Dict[AnyStr, Dict[AnyStr, AnyStr]]
        """
        :param table_name: table name
        :return: dict of {column name: {'type': column type}} for the table
        """
        # FIXME ? binding is not supported here, but I cannot find a way to escape a string manually
        sql_result = self.select('PRAGMA table_info([%s])' % table_name)
        return {cur_column['name']: {'type': cur_column['type']} for cur_column in sql_result}

    # http://stackoverflow.com/questions/3300464/how-can-i-get-dict-from-sqlite-query
    @staticmethod
    def _dict_factory(cursor, row):
        """Row factory that maps the column names of cursor.description to the values of row."""
        return {col[0]: row[idx] for idx, col in enumerate(cursor.description)}

    def has_table(self, table_name):
        # type: (AnyStr) -> bool
        """Return True if sqlite_master holds an entry with this name."""
        return 0 < len(self.select('SELECT 1 FROM sqlite_master WHERE name = ?;', (table_name,)))

    def has_column(self, table_name, column):
        # type: (AnyStr, AnyStr) -> bool
        """Return True if the table has a column with this name."""
        return column in self.table_info(table_name)

    def has_index(self, table_name, index):
        # type: (AnyStr, AnyStr) -> bool
        """Return True if the index list of the table contains an index with this name."""
        sql_results = self.select('PRAGMA index_list([%s])' % table_name)
        return any(result['name'] == index for result in sql_results)

    def remove_index(self, table, name):
        # type: (AnyStr, AnyStr) -> None
        """Drop the index if it exists on the table."""
        if self.has_index(table, name):
            self.action('DROP INDEX' + ' [%s]' % name)

    def remove_table(self, name):
        # type: (AnyStr) -> None
        """Drop the table if it exists."""
        if self.has_table(name):
            self.action('DROP TABLE' + ' [%s]' % name)

    def has_flag(self, flag_name):
        # type: (AnyStr) -> bool
        """Return True if the flags table contains this flag."""
        sql_result = self.select('SELECT flag FROM flags WHERE flag = ?', [flag_name])
        return 0 < len(sql_result)

    def add_flag(self, flag_name):
        # type: (AnyStr) -> bool
        """
        Add a flag if it is not already set

        :param flag_name: Name of flag
        :return: True if this call added the flag, False if it was already set
        """
        has_flag = self.has_flag(flag_name)
        if not has_flag:
            self.action('INSERT INTO flags (flag) VALUES (?)', [flag_name])
        return not has_flag

    def remove_flag(self, flag_name):
        # type: (AnyStr) -> bool
        """
        Remove a flag if it is set

        :param flag_name: Name of flag
        :return: True if this call removed the flag, False if it was not set
        """
        has_flag = self.has_flag(flag_name)
        if has_flag:
            self.action('DELETE FROM flags WHERE flag = ?', [flag_name])
        return has_flag

    def toggle_flag(self, flag_name):
        # type: (AnyStr) -> bool
        """
        Add or remove a flag

        :param flag_name: Name of flag
        :return: True if this call added the flag, False if flag is removed
        """
        if self.remove_flag(flag_name):
            return False
        self.add_flag(flag_name)
        return True

    def set_flag(self, flag_name, state=True):
        # type: (AnyStr, bool) -> bool
        """
        Set state of flag

        :param flag_name: Name of flag
        :param state: If true, create flag otherwise remove flag
        :return: True if this call changed the flag (it was added or removed),
        False if the flag already was in the requested state
        """
        return (self.add_flag, self.remove_flag)[not bool(state)](flag_name)

    def close(self):
        """Close database connection"""
        if None is not getattr(self, 'connection', None):
            self.connection.close()
        self.connection = None

    def upgrade_log(self, to_log, log_level=logger.MESSAGE):
        # type: (AnyStr, int) -> None
        """Pass an upgrade progress message for this db file to logger.load_log."""
        logger.load_log('Upgrading %s' % self.filename, to_log, log_level)
def sanity_check_db(connection, sanity_check):
    """
    Instantiate a sanity check for the connection and run it.

    :param connection: db connection handed to the sanity check
    :param sanity_check: callable (class) that takes the connection and returns an object with check()
    """
    checker = sanity_check(connection)
    checker.check()
class DBSanityCheck(object):
    """Base sanity check that keeps the db connection and checks nothing itself."""

    def __init__(self, connection):
        """
        :param connection: db connection, stored as self.connection
        """
        self.connection = connection

    def check(self):
        """Does nothing in this base class."""
        return None
def upgrade_database(connection, schema):
    """
    Run the upgrade classes rooted at schema against the connection.

    :param connection: db connection to upgrade
    :param schema: upgrade class to process, its subclasses are processed as well
    """
    logger.log('Checking database structure...', logger.MESSAGE)
    connection.is_upgrading = False
    current_version = connection.check_db_version()
    # a version of 0 marks a new db
    connection.new_db = 0 == current_version
    _process_upgrade(connection, schema)
    if not connection.is_upgrading:
        return
    # at least one upgrade was required and has set is_upgrading
    connection.upgrade_log('Finished')
def _pretty_name(class_name):
# type: (AnyStr) -> AnyStr
return ' '.join([x.group() for x in re.finditer('([A-Z])([a-z0-9]+)', class_name)])
def _restore_database(filename, version):
    """
    Restore the versioned backup of a database file.

    :param filename: database filename
    :param version: db version that the backup file was made for
    :return: True if restore_versioned_file reported success, otherwise False
    """
    logger.log('Restoring database before trying upgrade again')
    versioned_file = db_filename(filename=filename, suffix='v%s' % version)
    if sickgear.helpers.restore_versioned_file(versioned_file, version):
        return True
    logger.log_error_and_exit('Database restore failed, abort upgrading database')
    return False
def _process_upgrade(connection, upgrade_class):
    """
    Run one upgrade class if its test() says it is required, then recurse into its subclasses.

    :param connection: db connection to upgrade
    :param upgrade_class: upgrade class, instantiated with the connection
    """
    class_name = upgrade_class.__name__
    instance = upgrade_class(connection)
    logger.debug('Checking %s database upgrade' % _pretty_name(class_name))

    if instance.test():
        logger.debug('%s upgrade not required' % class_name)
    else:
        connection.is_upgrading = True
        connection.upgrade_log(getattr(upgrade_class, 'pretty_name', None) or _pretty_name(class_name))
        logger.log('Database upgrade required: %s' % _pretty_name(class_name), logger.MESSAGE)
        db_version = connection.check_db_version()
        try:
            # only do backup if it's not a new db
            if 0 < db_version:
                backup_database(connection, connection.filename, db_version)
            instance.execute()
            cleanup_old_db_backups(connection.filename)
        except (BaseException, Exception):
            # attempting to restore previous DB backup and perform upgrade
            if not db_version:
                logger.log_error_and_exit(
                    'Database upgrade failed, can\'t determine old db version, not restoring.')
            else:
                # close db before attempting restore
                connection.close()
                if _restore_database(connection.filename, db_version):
                    logger.log_error_and_exit('Successfully restored database version: %s' % db_version)
                else:
                    logger.log_error_and_exit('Failed to restore database version: %s' % db_version)
        logger.debug('%s upgrade completed' % class_name)

    for sub_class in upgrade_class.__subclasses__():
        _process_upgrade(connection, sub_class)
# Base migration class. All future DB changes should be subclassed from this class
class SchemaUpgrade(object):
    """Helpers for schema migrations, all db access goes through the given connection."""

    def __init__(self, connection, **kwargs):
        """
        :param connection: db connection used by all helpers
        :param kwargs: accepted but not used here
        """
        self.connection = connection

    def has_table(self, table_name):
        """Return True if sqlite_master holds an entry with this name."""
        return 0 < len(self.connection.select('SELECT 1 FROM sqlite_master WHERE name = ?;', (table_name,)))

    def has_column(self, table_name, column):
        """Return True if the table has a column with this name."""
        return column in self.connection.table_info(table_name)

    def list_tables(self):
        # type: (...) -> List[AnyStr]
        """
        returns list of all table names in db
        """
        return [s['name'] for s in self.connection.select('SELECT name FROM main.sqlite_master WHERE type = ?;',
                                                          ['table'])]

    def list_indexes(self):
        # type: (...) -> List[AnyStr]
        """
        returns list of all index names in db
        """
        return [s['name'] for s in self.connection.select('SELECT name FROM main.sqlite_master WHERE type = ?;',
                                                          ['index'])]

    # noinspection SqlResolve
    def add_column(self, table, column, data_type='NUMERIC', default=0, set_default=False):
        """
        Add a column to a table and set it to default in all existing rows.

        :param table: table name
        :param column: name of the new column
        :param data_type: column type
        :param default: value written to the new column of all existing rows
        :param set_default: True to also declare default as the DEFAULT of the column
        """
        self.connection.action('ALTER TABLE [%s] ADD %s %s%s' %
                               (table, column, data_type, ('', ' DEFAULT "%s"' % default)[set_default]))
        self.connection.action('UPDATE [%s] SET %s = ?' % (table, column), (default,))

    # noinspection SqlResolve
    def add_columns(self, table, column_list=None):
        # type: (AnyStr, List) -> None
        """
        Add several columns to a table in one transaction.

        :param table: table name
        :param column_list: list of columns, each item is either a column name or a list/tuple of
        (name[, data type[, default]]). The type defaults to NUMERIC. Only when a default is given
        is it declared as DEFAULT of the column and written to all existing rows.
        Anything that is not a list is ignored.
        """
        if not isinstance(column_list, list):
            return

        sql = []
        for col in column_list:
            is_list = isinstance(col, (list, tuple))
            list_len = 0 if not is_list else len(col)
            column = col if not is_list else col[0]
            data_type = 'NUMERIC' if not is_list or 2 > list_len else col[1]
            default = 0 if not is_list or 3 > list_len else col[2]
            sql.append(['ALTER TABLE [%s] ADD %s %s%s' %
                        (table, column, data_type, '' if list_len < 3 else
                         ' DEFAULT %s' % ('""' if 'TEXT' == data_type and '' == default else default))])
            if 2 < list_len:
                sql.append(['UPDATE [%s] SET %s = ?' % (table, column), (default,)])

        if sql:
            self.connection.mass_action(sql)

    def drop_columns(self, table, column):
        # type: (AnyStr, Union[AnyStr, List[AnyStr]]) -> None
        """
        Remove one or more columns by rebuilding the table without them.

        Only name, type and primary key of the kept columns are carried over to the rebuilt table,
        indexes of the old table are recreated from their stored sql.

        :param table: table name
        :param column: a column name or a list of column names to remove
        """
        # get old table columns and store the ones we want to keep
        result = self.connection.select('pragma table_info([%s])' % table)
        columns_list = column if isinstance(column, list) else [column]
        kept_columns = [col for col in result if col['name'] not in columns_list]

        kept_columns_names = []
        final = []
        pk = []

        # copy the old table schema, column by column
        for cur_column in kept_columns:
            kept_columns_names.append(cur_column['name'])
            cl = [cur_column['name'], cur_column['type']]

            # To be implemented if ever required
            # if cur_column['dflt_value']:
            #     cl.append(str(cur_column['dflt_value']))
            #
            # if cur_column['notnull']:
            #     cl.append(cur_column['notnull'])

            if 0 != int(cur_column['pk']):
                pk.append(cur_column['name'])

            final.append(' '.join(cl))

        # join all the table column creation fields
        final = ', '.join(final)
        kept_columns_names = ', '.join(kept_columns_names)

        # generate sql for the new table creation
        if 0 == len(pk):
            sql = 'CREATE TABLE [%s_new] (%s)' % (table, final)
        else:
            pk = ', '.join(pk)
            sql = 'CREATE TABLE [%s_new] (%s, PRIMARY KEY(%s))' % (table, final, pk)

        # create new temporary table and copy the old table data across, barring the removed column
        self.connection.action(sql)
        # noinspection SqlResolve
        self.connection.action('INSERT INTO [%s_new] SELECT %s FROM [%s]' % (table, kept_columns_names, table))

        # copy the old indexes from the old table
        result = self.connection.select("SELECT sql FROM sqlite_master WHERE tbl_name=? AND type='index'", [table])

        # remove the old table and rename the new table to take its place
        # noinspection SqlResolve
        self.connection.action('DROP TABLE [%s]' % table)
        # noinspection SqlResolve
        self.connection.action('ALTER TABLE [%s_new] RENAME TO [%s]' % (table, table))

        # write any indexes to the new table
        for index in result:
            self.connection.action(index['sql'])

        # vacuum the db as we will have a lot of space to reclaim after dropping tables
        self.connection.action('VACUUM')

    def call_check_db_version(self):
        """Return the db version as reported by the connection."""
        return self.connection.check_db_version()

    def inc_db_version(self):
        """
        Increase the stored db version by one.

        :return: the new version
        """
        new_version = self.call_check_db_version() + 1
        # noinspection SqlConstantCondition
        self.connection.action('UPDATE db_version SET db_version = ? WHERE 1=1', [new_version])
        return new_version

    def set_db_version(self, new_version, check_db_version=True):
        """
        Store a db version.

        :param new_version: version to store
        :param check_db_version: True to read the version back after storing it
        :return: the version read back from the db, or the falsy check_db_version value if no read back is wanted
        """
        # noinspection SqlConstantCondition
        self.connection.action('UPDATE db_version SET db_version = ? WHERE 1=1', [new_version])
        return check_db_version and self.call_check_db_version()

    def do_query(self, queries):
        """
        Execute statements, skipping a DROP TABLE of a missing table and a CREATE TABLE of an existing one.

        :param queries: a single sql string, an iterable of sql strings, or a list of lists of sql strings
        """
        if isinstance(queries, str):
            # a bare statement must be wrapped, list() would split it into single characters
            queries = [queries]
        elif not isinstance(queries, list):
            queries = list(queries)

        if not queries:
            return

        if isinstance(queries[0], list):
            # flatten a list of statement lists
            queries = [item for sublist in queries for item in sublist]

        for query in queries:
            tbl_name = re.findall(r'(?i)DROP.*?TABLE.*?\[?([^\s\]]+)', query)
            if tbl_name and not self.has_table(tbl_name[0]):
                continue
            tbl_name = re.findall(r'(?i)CREATE.*?TABLE.*?\s([^\s(]+)\s*\(', query)
            if tbl_name and self.has_table(tbl_name[0]):
                continue
            self.connection.action(query)

    def finish(self, tbl_dropped=False):
        """
        Complete a migration step by increasing the db version.

        :param tbl_dropped: True to VACUUM the db first
        """
        if tbl_dropped:
            self.connection.action('VACUUM')

        self.inc_db_version()

    def upgrade_log(self, *args, **kwargs):
        """Forward to upgrade_log of the connection."""
        self.connection.upgrade_log(*args, **kwargs)
def migration_code(my_db):
    """
    Upgrade the main database step by step until it is at sickgear.mainDB.MAX_DB_VERSION.

    Each step is the migration class that the schema table below maps to the current db version,
    its execute() returns the db version to continue from.

    :param my_db: connection to the database to upgrade
    """
    main_db = sickgear.mainDB
    schema = {
        0: main_db.InitialSchema,
        9: main_db.AddSizeAndSceneNameFields,
        10: main_db.RenameSeasonFolders,
        11: main_db.Add1080pAndRawHDQualities,
        12: main_db.AddShowidTvdbidIndex,
        13: main_db.AddLastUpdateTVDB,
        14: main_db.AddDBIncreaseTo15,
        15: main_db.AddIMDbInfo,
        16: main_db.AddProperNamingSupport,
        17: main_db.AddEmailSubscriptionTable,
        18: main_db.AddProperSearch,
        19: main_db.AddDvdOrderOption,
        20: main_db.AddSubtitlesSupport,
        21: main_db.ConvertTVShowsToIndexerScheme,
        22: main_db.ConvertTVEpisodesToIndexerScheme,
        23: main_db.ConvertIMDBInfoToIndexerScheme,
        24: main_db.ConvertInfoToIndexerScheme,
        25: main_db.AddArchiveFirstMatchOption,
        26: main_db.AddSceneNumbering,
        27: main_db.ConvertIndexerToInteger,
        28: main_db.AddRequireAndIgnoreWords,
        29: main_db.AddSportsOption,
        30: main_db.AddSceneNumberingToTvEpisodes,
        31: main_db.AddAnimeTVShow,
        32: main_db.AddAbsoluteNumbering,
        33: main_db.AddSceneAbsoluteNumbering,
        34: main_db.AddAnimeAllowlistBlocklist,
        35: main_db.AddSceneAbsoluteNumbering2,
        36: main_db.AddXemRefresh,
        37: main_db.AddSceneToTvShows,
        38: main_db.AddIndexerMapping,
        39: main_db.AddVersionToTvEpisodes,

        40: main_db.BumpDatabaseVersion,
        41: main_db.Migrate41,
        42: main_db.Migrate41,
        43: main_db.Migrate43,
        44: main_db.Migrate43,

        4301: main_db.Migrate4301,
        4302: main_db.Migrate4302,
        4400: main_db.Migrate4302,

        5816: main_db.MigrateUpstream,
        5817: main_db.MigrateUpstream,
        5818: main_db.MigrateUpstream,

        10000: main_db.SickGearDatabaseVersion,
        10001: main_db.RemoveDefaultEpStatusFromTvShows,
        10002: main_db.RemoveMinorDBVersion,
        10003: main_db.RemoveMetadataSub,

        20000: main_db.DBIncreaseTo20001,
        20001: main_db.AddTvShowOverview,
        20002: main_db.AddTvShowTags,
        20003: main_db.ChangeMapIndexer,
        20004: main_db.AddShowNotFoundCounter,
        20005: main_db.AddFlagTable,
        20006: main_db.DBIncreaseTo20007,
        20007: main_db.AddWebdlTypesTable,
        20008: main_db.AddWatched,
        20009: main_db.AddPrune,
        20010: main_db.AddIndexerToTables,
        20011: main_db.AddShowExludeGlobals,
        20012: main_db.RenameAllowBlockListTables,
        20013: main_db.AddHistoryHideColumn,
        20014: main_db.ChangeShowData,
        20015: main_db.ChangeTmdbID,
        # 20002: sickgear.mainDB.AddCoolSickGearFeature3,
    }

    db_version = my_db.check_db_version()
    my_db.new_db = 0 == db_version
    logger.debug(f'Detected database version: v{db_version}')

    if db_version not in schema:
        # no migration starts at this version: the db is either current or cannot be upgraded
        if db_version == sickgear.mainDB.MAX_DB_VERSION:
            logger.log('Database schema is up-to-date, no upgrade required')
        elif 10000 > db_version:
            logger.log_error_and_exit('SickGear does not currently support upgrading from this database version')
        else:
            logger.log_error_and_exit('Invalid database version')
        return

    my_db.upgrade_log('Upgrading')
    while db_version < sickgear.mainDB.MAX_DB_VERSION:
        migration = schema[db_version]
        if None is migration:  # skip placeholders used when multi PRs are updating DB
            db_version += 1
            continue
        try:
            update = migration(my_db)
            db_version = update.execute()
            cleanup_old_db_backups(my_db.filename)
        except (BaseException, Exception) as e:
            my_db.close()
            logger.error(f'Failed to update database with error: {ex(e)} attempting recovery...')

            if not _restore_database(my_db.filename, db_version):
                logger.log_error_and_exit(f'Failed to restore database version: {db_version}')
            else:
                # initialize the main SB database
                logger.log_error_and_exit(f'Successfully restored database version: {db_version}')
    my_db.upgrade_log('Finished')
def cleanup_old_db_backups(filename):
    """
    Delete backup files of a database, best effort: every error is ignored.

    A file is removed if its name contains the db filename and ends in
    `.db.r<digits>` or `.db.v<digits>.r<digits>`.

    :param filename: database filename, the data dir is searched if it has no dir part
    """
    try:
        folder, db_name = os.path.split(filename)
        folder = folder or sickgear.DATA_DIR
        for entry in scandir(folder):
            is_backup = (entry.is_file() and db_name in entry.name
                         and re.search(r'\.db(\.v\d+)?\.r\d+$', entry.name))
            if not is_backup:
                continue
            try:
                os.unlink(entry.path)
            except (BaseException, Exception):
                pass
    except (BaseException, Exception):
        pass
def backup_database(db_connection, filename, version):
    """
    Make a versioned backup of a database file before it is upgraded.

    :param db_connection: connection of the db, nothing is backed up if its new_db is set
    :param filename: database filename
    :param version: current db version, passed on to backup_versioned_file
    """
    if db_connection.new_db:
        logger.debug('new db, no backup required')
        return

    logger.log('Backing up database before upgrade')
    backed_up = sickgear.helpers.backup_versioned_file(db_filename(filename), version)
    if backed_up:
        logger.log('Proceeding with upgrade')
    else:
        logger.log_error_and_exit('Database backup failed, abort upgrading database')
def get_rollback_module():
    """
    Download the rollback module source and load it as a module named 'DbRollback'.

    Each url is tried in up to 3 rounds with a 30 second pause between rounds. A download is
    only accepted if it starts with the expected header line.

    :return: the loaded module, or None if nothing valid was fetched or loading failed
    """
    import types
    from . import helpers

    module_urls = [
        'https://raw.githubusercontent.com/SickGear/sickgear.extdata/main/SickGear/Rollback/rollback_sg.py']
    max_attempts = 3

    try:
        hdr = '# SickGear Rollback Module'
        module = ''
        fetched = False

        for attempt in range(1, max_attempts + 1):
            for url in module_urls:
                try:
                    module = helpers.get_url(url)
                    if module and module.startswith(hdr):
                        fetched = True
                        break
                except (BaseException, Exception):
                    continue
            if fetched:
                break
            # pause before the next round, but don't delay giving up after the last one
            if attempt < max_attempts:
                time.sleep(30)

        if fetched:
            loaded = types.ModuleType('DbRollback')
            # SECURITY NOTE(review): this executes code downloaded at runtime. The only validation
            # is the header line check above, so trust rests entirely on the https source.
            exec(module, loaded.__dict__)
            return loaded

    except (BaseException, Exception):
        pass

    return None
def delete_old_db_backups(target):
    # type: (AnyStr) -> None
    """
    remove old db backups (> BACKUP_DB_MAX_COUNT)

    Per database, the newest files by modification time are kept: only one if
    sickgear.BACKUP_DB_ONEDAY is set, otherwise sickgear.BACKUP_DB_MAX_COUNT.

    :param target: backup folder to check
    """
    keep_count = sickgear.BACKUP_DB_MAX_COUNT if not sickgear.BACKUP_DB_ONEDAY else 1

    for db_name in ['sickbeard', 'cache', 'failed']:
        backups = list(scantree(target, include=db_name, filter_kind=False))
        if len(backups) <= keep_count:
            continue
        # newest first, everything after the first keep_count entries is removed
        backups.sort(key=lambda entry: entry.stat(follow_symlinks=False).st_mtime, reverse=True)
        for stale in backups[keep_count:]:
            remove_file_perm(stale.path)
def backup_all_dbs(target, compress=True, prefer_7z=True):
    # type: (AnyStr, bool, bool) -> Tuple[bool, AnyStr]
    """
    backups all dbs to specified dir

    optional compress with zip or 7z (python 3 only, external lib py7zr required)
    7z falls back to zip if py7zr is not available

    The time of the last successful backup is kept in the lastUpdate table of cache.db,
    a new backup is refused if the last one is less than 23 hours old.

    :param target: target folder for backup db
    :param compress: compress db backups
    :param prefer_7z: prefer 7z compression if available
    :return: success, message
    """
    if not make_path(target):
        logger.error('Failed to create db backup dir')
        return False, 'Failed to create db backup dir'

    my_db = DBConnection('cache.db')
    try:
        last_backup = my_db.select('SELECT time FROM lastUpdate WHERE provider = ?', ['sickgear_db_backup'])
        if last_backup:
            now_stamp = SGDatetime.timestamp_near()
            the_time = int(last_backup[0]['time'])
            # only backup every 23 hours
            if now_stamp - the_time < 60 * 60 * 23:
                return False, 'Too early to backup db again'

        now = SGDatetime.now()
        d = SGDatetime.sbfdate(now, d_preset='%Y-%m-%d')
        t = SGDatetime.sbftime(now, t_preset='%H-%M')
        ds = '%s_%s' % (d, t)

        for cur_db in ['sickbeard', 'cache', 'failed']:
            name = '%s_%s.db' % (cur_db, ds)
            db_conn = DBConnection('%s.db' % cur_db)
            try:
                success, msg = db_conn.backup_db(target=target, backup_filename=name)
            finally:
                # the connection is only needed for the backup itself
                db_conn.close()
            if not success:
                return False, msg

            if compress:
                full_path = os.path.join(target, name)
                if not compress_file(full_path, '%s.db' % cur_db, prefer_7z=prefer_7z):
                    return False, 'Failure to compress backup'

        delete_old_db_backups(target)
        my_db.upsert('lastUpdate',
                     {'time': int(time.mktime(now.timetuple()))},
                     {'provider': 'sickgear_db_backup'})
    finally:
        # release the cache db handle on every exit path, including the early returns
        my_db.close()

    logger.log('successfully backed up all dbs')
    return True, 'successfully backed up all dbs'