SickGear/sickbeard/db.py

507 lines
18 KiB
Python
Raw Normal View History

# Author: Nic Wolfe <nic@wolfeden.ca>
# URL: http://code.google.com/p/sickbeard/
#
# This file is part of SickGear.
#
# SickGear is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# SickGear is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with SickGear. If not, see <http://www.gnu.org/licenses/>.
from __future__ import with_statement
import os.path
import re
import sqlite3
2014-06-05 16:22:54 +00:00
import time
2014-06-05 19:52:30 +00:00
import threading
2014-06-05 19:52:30 +00:00
import sickbeard
from sickbeard import encodingKludge as ek
from sickbeard import logger
from sickbeard.exceptions import ex
db_lock = threading.Lock()
def dbFilename(filename="sickbeard.db", suffix=None):
"""
@param filename: The sqlite database filename to use. If not specified,
will be made to be sickbeard.db
@param suffix: The suffix to append to the filename. A '.' will be added
automatically, i.e. suffix='v0' will make dbfile.db.v0
@return: the correct location of the database file.
"""
if suffix:
filename = "%s.%s" % (filename, suffix)
return ek.ek(os.path.join, sickbeard.DATA_DIR, filename)
class DBConnection(object):
def __init__(self, filename="sickbeard.db", suffix=None, row_type=None):
self.filename = filename
self.suffix = suffix
self.row_type = row_type
self.connection = None
try:
self.reconnect()
except Exception as e:
logger.log(u"DB error: " + ex(e), logger.ERROR)
raise
def reconnect(self):
"""Closes the existing database connection and re-opens it."""
self.close()
self.connection = sqlite3.connect(dbFilename(self.filename, self.suffix), 20, check_same_thread=False)
self.connection.isolation_level = None
if self.row_type == "dict":
self.connection.row_factory = self._dict_factory
else:
self.connection.row_factory = sqlite3.Row
def __del__(self):
self.close()
def _cursor(self):
"""Returns the cursor; reconnects if disconnected."""
if self.connection is None: self.reconnect()
return self.connection.cursor()
def execute(self, query, args=None, fetchall=False, fetchone=False):
"""Executes the given query, returning the lastrowid from the query."""
cursor = self._cursor()
try:
if fetchall:
return self._execute(cursor, query, args).fetchall()
elif fetchone:
return self._execute(cursor, query, args).fetchone()
else:
return self._execute(cursor, query, args)
finally:
cursor.close()
def _execute(self, cursor, query, args):
try:
if args == None:
return cursor.execute(query)
return cursor.execute(query, args)
except sqlite3.OperationalError as e:
logger.log(u"DB error: " + ex(e), logger.ERROR)
self.close()
raise
def checkDBVersion(self):
result = None
try:
if self.hasTable('db_version'):
result = self.select("SELECT db_version FROM db_version")
except:
2014-06-30 15:57:32 +00:00
return 0
if result:
return int(result[0]["db_version"])
else:
return 0
def mass_action(self, querylist, logTransaction=False, fetchall=False):
with db_lock:
# remove None types
querylist = [i for i in querylist if i != None]
if querylist == None:
return
sqlResult = []
attempt = 0
while attempt < 5:
try:
2014-06-05 16:22:54 +00:00
for qu in querylist:
if len(qu) == 1:
if logTransaction:
logger.log(qu[0], logger.DEBUG)
sqlResult.append(self.execute(qu[0], fetchall=fetchall))
2014-06-05 16:22:54 +00:00
elif len(qu) > 1:
if logTransaction:
logger.log(qu[0] + " with args " + str(qu[1]), logger.DEBUG)
sqlResult.append(self.execute(qu[0], qu[1], fetchall=fetchall))
2014-06-05 19:52:30 +00:00
logger.log(u"Transaction with " + str(len(querylist)) + u" queries executed", logger.DEBUG)
# finished
break
except sqlite3.OperationalError, e:
sqlResult = []
if self.connection:
self.connection.rollback()
if "unable to open database file" in e.args[0] or "database is locked" in e.args[0]:
logger.log(u"DB error: " + ex(e), logger.WARNING)
attempt += 1
2014-06-30 15:57:32 +00:00
time.sleep(1)
else:
logger.log(u"DB error: " + ex(e), logger.ERROR)
raise
except sqlite3.DatabaseError, e:
sqlResult = []
if self.connection:
self.connection.rollback()
logger.log(u"Fatal error executing query: " + ex(e), logger.ERROR)
raise
#time.sleep(0.02)
2014-06-30 15:57:32 +00:00
return sqlResult
def action(self, query, args=None, fetchall=False, fetchone=False):
with db_lock:
if query == None:
return
sqlResult = None
attempt = 0
while attempt < 5:
try:
if args == None:
logger.log(self.filename + ": " + query, logger.DB)
else:
logger.log(self.filename + ": " + query + " with args " + str(args), logger.DB)
sqlResult = self.execute(query, args, fetchall=fetchall, fetchone=fetchone)
# get out of the connection attempt loop since we were successful
break
except sqlite3.OperationalError, e:
if "unable to open database file" in e.args[0] or "database is locked" in e.args[0]:
logger.log(u"DB error: " + ex(e), logger.WARNING)
attempt += 1
2014-06-30 15:57:32 +00:00
time.sleep(1)
else:
logger.log(u"DB error: " + ex(e), logger.ERROR)
raise
except sqlite3.DatabaseError, e:
logger.log(u"Fatal error executing query: " + ex(e), logger.ERROR)
raise
#time.sleep(0.02)
2014-06-30 15:57:32 +00:00
return sqlResult
def select(self, query, args=None):
sqlResults = self.action(query, args, fetchall=True)
if sqlResults == None:
return []
return sqlResults
def selectOne(self, query, args=None):
sqlResults = self.action(query, args, fetchone=True)
if sqlResults == None:
return []
return sqlResults
def upsert(self, tableName, valueDict, keyDict):
2014-06-05 09:19:48 +00:00
changesBefore = self.connection.total_changes
genParams = lambda myDict: [x + " = ?" for x in myDict.keys()]
query = "UPDATE " + tableName + " SET " + ", ".join(genParams(valueDict)) + " WHERE " + " AND ".join(
genParams(keyDict))
2014-06-05 19:52:30 +00:00
self.action(query, valueDict.values() + keyDict.values())
2014-06-05 09:19:48 +00:00
if self.connection.total_changes == changesBefore:
query = "INSERT INTO " + tableName + " (" + ", ".join(valueDict.keys() + keyDict.keys()) + ")" + \
" VALUES (" + ", ".join(["?"] * len(valueDict.keys() + keyDict.keys())) + ")"
self.action(query, valueDict.values() + keyDict.values())
def tableInfo(self, tableName):
# FIXME ? binding is not supported here, but I cannot find a way to escape a string manually
sqlResult = self.select("PRAGMA table_info(%s)" % tableName)
columns = {}
for column in sqlResult:
columns[column['name']] = {'type': column['type']}
return columns
# http://stackoverflow.com/questions/3300464/how-can-i-get-dict-from-sqlite-query
def _dict_factory(self, cursor, row):
d = {}
for idx, col in enumerate(cursor.description):
d[col[0]] = row[idx]
return d
def hasTable(self, tableName):
return len(self.select("SELECT 1 FROM sqlite_master WHERE name = ?;", (tableName, ))) > 0
def hasColumn(self, tableName, column):
return column in self.tableInfo(tableName)
def addColumn(self, table, column, type="NUMERIC", default=0):
self.action("ALTER TABLE %s ADD %s %s" % (table, column, type))
self.action("UPDATE %s SET %s = ?" % (table, column), (default,))
def close(self):
"""Close database connection"""
if getattr(self, "connection", None) is not None:
self.connection.close()
self.connection = None
def sanityCheckDatabase(connection, sanity_check):
sanity_check(connection).check()
class DBSanityCheck(object):
def __init__(self, connection):
self.connection = connection
def check(self):
pass
2014-06-05 19:52:30 +00:00
# ===============
# = Upgrade API =
# ===============
def upgradeDatabase(connection, schema):
logger.log(u"Checking database structure...", logger.MESSAGE)
_processUpgrade(connection, schema)
def prettyName(class_name):
return ' '.join([x.group() for x in re.finditer("([A-Z])([a-z0-9]+)", class_name)])
def restoreDatabase(version):
logger.log(u"Restoring database before trying upgrade again")
if not sickbeard.helpers.restoreVersionedFile(dbFilename(suffix='v' + str(version)), version):
logger.log_error_and_exit(u"Database restore failed, abort upgrading database")
return False
else:
return True
def _processUpgrade(connection, upgradeClass):
instance = upgradeClass(connection)
logger.log(u"Checking " + prettyName(upgradeClass.__name__) + " database upgrade", logger.DEBUG)
if not instance.test():
logger.log(u"Database upgrade required: " + prettyName(upgradeClass.__name__), logger.MESSAGE)
try:
instance.execute()
2014-06-05 19:52:30 +00:00
except sqlite3.DatabaseError, e:
# attemping to restore previous DB backup and perform upgrade
try:
instance.execute()
except:
restored = False
result = connection.select("SELECT db_version FROM db_version")
if result:
version = int(result[0]["db_version"])
# close db before attempting restore
connection.close()
if restoreDatabase(version):
# initialize the main SB database
upgradeDatabase(DBConnection(), sickbeard.mainDB.InitialSchema)
restored = True
if not restored:
print "Error in " + str(upgradeClass.__name__) + ": " + ex(e)
raise
logger.log(upgradeClass.__name__ + " upgrade completed", logger.DEBUG)
else:
logger.log(upgradeClass.__name__ + " upgrade not required", logger.DEBUG)
for upgradeSubClass in upgradeClass.__subclasses__():
_processUpgrade(connection, upgradeSubClass)
# Base migration class. All future DB changes should be subclassed from this class
class SchemaUpgrade(object):
def __init__(self, connection):
self.connection = connection
def hasTable(self, tableName):
return len(self.connection.select("SELECT 1 FROM sqlite_master WHERE name = ?;", (tableName, ))) > 0
def hasColumn(self, tableName, column):
return column in self.connection.tableInfo(tableName)
def addColumn(self, table, column, type="NUMERIC", default=0):
self.connection.action("ALTER TABLE %s ADD %s %s" % (table, column, type))
self.connection.action("UPDATE %s SET %s = ?" % (table, column), (default,))
2014-12-09 12:30:00 +00:00
def dropColumn(self, table, column):
# get old table columns and store the ones we want to keep
result = self.connection.select('pragma table_info(%s)' % table)
keptColumns = [c for c in result if c['name'] != column]
keptColumnsNames = []
final = []
pk = []
# copy the old table schema, column by column
for column in keptColumns:
keptColumnsNames.append(column['name'])
cl = []
cl.append(column['name'])
cl.append(column['type'])
'''
To be implemented if ever required
if column['dflt_value']:
cl.append(str(column['dflt_value']))
if column['notnull']:
cl.append(column['notnull'])
'''
if int(column['pk']) != 0:
pk.append(column['name'])
b = ' '.join(cl)
final.append(b)
# join all the table column creation fields
final = ', '.join(final)
keptColumnsNames = ', '.join(keptColumnsNames)
# generate sql for the new table creation
if len(pk) == 0:
sql = 'CREATE TABLE %s_new (%s)' % (table, final)
else:
pk = ', '.join(pk)
sql = 'CREATE TABLE %s_new (%s, PRIMARY KEY(%s))' % (table, final, pk)
# create new temporary table and copy the old table data across, barring the removed column
self.connection.action(sql)
self.connection.action('INSERT INTO %s_new SELECT %s FROM %s' % (table, keptColumnsNames, table))
# copy the old indexes from the old table
result = self.connection.select('SELECT sql FROM sqlite_master WHERE tbl_name=? and type="index"', [table])
# remove the old table and rename the new table to take it's place
self.connection.action('DROP TABLE %s' % table)
self.connection.action('ALTER TABLE %s_new RENAME TO %s' % (table, table))
# write any indexes to the new table
if len(result) > 0:
for index in result:
self.connection.action(index['sql'])
# vacuum the db as we will have a lot of space to reclaim after dropping tables
self.connection.action("VACUUM")
def checkDBVersion(self):
2014-06-30 15:57:32 +00:00
return self.connection.checkDBVersion()
def incDBVersion(self):
new_version = self.checkDBVersion() + 1
self.connection.action("UPDATE db_version SET db_version = ?", [new_version])
return new_version
2014-12-09 12:30:00 +00:00
def setDBVersion(self, new_version):
self.connection.action("UPDATE db_version SET db_version = ?", [new_version])
return new_version
def MigrationCode(myDB):
schema = {
0: sickbeard.mainDB.InitialSchema, # 0->20000
9: sickbeard.mainDB.AddSizeAndSceneNameFields,
10: sickbeard.mainDB.RenameSeasonFolders,
11: sickbeard.mainDB.Add1080pAndRawHDQualities,
12: sickbeard.mainDB.AddShowidTvdbidIndex,
13: sickbeard.mainDB.AddLastUpdateTVDB,
14: sickbeard.mainDB.AddDBIncreaseTo15,
15: sickbeard.mainDB.AddIMDbInfo,
16: sickbeard.mainDB.AddProperNamingSupport,
17: sickbeard.mainDB.AddEmailSubscriptionTable,
18: sickbeard.mainDB.AddProperSearch,
19: sickbeard.mainDB.AddDvdOrderOption,
20: sickbeard.mainDB.AddSubtitlesSupport,
21: sickbeard.mainDB.ConvertTVShowsToIndexerScheme,
22: sickbeard.mainDB.ConvertTVEpisodesToIndexerScheme,
23: sickbeard.mainDB.ConvertIMDBInfoToIndexerScheme,
24: sickbeard.mainDB.ConvertInfoToIndexerScheme,
25: sickbeard.mainDB.AddArchiveFirstMatchOption,
26: sickbeard.mainDB.AddSceneNumbering,
27: sickbeard.mainDB.ConvertIndexerToInteger,
28: sickbeard.mainDB.AddRequireAndIgnoreWords,
29: sickbeard.mainDB.AddSportsOption,
30: sickbeard.mainDB.AddSceneNumberingToTvEpisodes,
31: sickbeard.mainDB.AddAnimeTVShow,
32: sickbeard.mainDB.AddAbsoluteNumbering,
33: sickbeard.mainDB.AddSceneAbsoluteNumbering,
34: sickbeard.mainDB.AddAnimeBlacklistWhitelist,
35: sickbeard.mainDB.AddSceneAbsoluteNumbering2,
36: sickbeard.mainDB.AddXemRefresh,
37: sickbeard.mainDB.AddSceneToTvShows,
38: sickbeard.mainDB.AddIndexerMapping,
39: sickbeard.mainDB.AddVersionToTvEpisodes,
40: sickbeard.mainDB.BumpDatabaseVersion,
41: sickbeard.mainDB.Migrate41,
10000: sickbeard.mainDB.SickGearDatabaseVersion,
10001: sickbeard.mainDB.RemoveDefaultEpStatusFromTvShows
#20000: sickbeard.mainDB.AddCoolSickGearFeature1,
#20001: sickbeard.mainDB.AddCoolSickGearFeature2,
#20002: sickbeard.mainDB.AddCoolSickGearFeature3,
}
db_version = myDB.checkDBVersion()
logger.log(u'Detected database version: v' + str(db_version), logger.DEBUG)
if not (db_version in schema):
if db_version == sickbeard.mainDB.MAX_DB_VERSION:
logger.log(u'Database schema is up-to-date, no upgrade required')
elif db_version < 10000:
logger.log_error_and_exit(u'SickGear does not currently support upgrading from this database version')
else:
logger.log_error_and_exit(u'Invalid database version')
else:
while db_version < sickbeard.mainDB.MAX_DB_VERSION:
try:
update = schema[db_version](myDB)
db_version = update.execute()
except Exception, e:
myDB.close()
logger.log(u'Failed to update database with error: ' + ex(e) + ' attempting recovery...', logger.ERROR)
if restoreDatabase(db_version):
# initialize the main SB database
logger.log_error_and_exit(u'Successfully restored database version:' + str(db_version))
else:
logger.log_error_and_exit(u'Failed to restore database version:' + str(db_version))